├── .gitignore ├── Makefile ├── README.md ├── config └── exp_configs │ ├── data_pubtabnet.yaml │ └── general_exp.yaml ├── dataset ├── data_preprocessing_config.json ├── preprocess_data.py └── preprocess_data_utils.py ├── evaluate_ted.py ├── figures └── pipeline.png ├── pyproject.toml ├── requirements.txt ├── scripts ├── preprocess_data │ └── preprocess_pubtabnet.sh ├── testing │ └── test_pubtabnet.sh └── training │ └── train_pubtabnet.sh ├── test.py ├── tflop ├── datamodule │ ├── datasets │ │ └── tflop.py │ └── preprocess │ │ ├── common_utils.py │ │ ├── hi_mul_con_table.py │ │ └── image_utils.py ├── evaluator.py ├── lightning_module │ └── lightning_module.py ├── loss.py ├── model │ ├── decoder │ │ ├── mbart_decoder.py │ │ └── utils.py │ ├── model │ │ ├── TFLOP.py │ │ └── TFLOP_Config.py │ └── visual_encoder │ │ ├── __init__.py │ │ ├── backbones │ │ └── SwinV2.py │ │ └── swin.py └── utils.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | pretrain_weights/* 165 | resources/* 166 | TFLOP-dataset/* 167 | pretrain_weights/* 168 | results/* -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Specify the names of all executables to make. 2 | PROG=update install style_check 3 | .PHONY: ${PROG} 4 | 5 | update: 6 | pip install --upgrade pip wheel 7 | pip install --upgrade -r requirements.txt 8 | 9 | install: 10 | pip install --upgrade pip wheel 11 | pip install -r requirements.txt 12 | git config core.hooksPath .github/hooks 13 | 14 | style_check: 15 | black . --config pyproject.toml 16 | isort . --gitignore --settings-path pyproject.toml 17 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TFLOP: Table Structure Recognition Framework with Layout Pointer Mechanism 2 | Official implemenetation of "TFLOP: Table Structure Recognition Framework with Layout Pointer Mechanism" (IJCAI 2024) 3 | 4 | ![](./figures/pipeline.png) 5 | 6 | ## 📣 Latest Updates 7 | 8 | - 💻 [17/01/2025] Release of TFLOP code! 9 | - 🚀 [15/10/2024] Try out the enterprise-grade integration of TFLOP within Upstage’s Document Parse -- [[Link](https://console.upstage.ai/playground/document-parse)] 10 | - ⚡️ [03/08/2024] Presentation of TFLOP in IJCAI 2024 -- [[Paper](https://arxiv.org/abs/2501.11800)] 11 | 12 | ## 🚀 Getting Started 13 | ### Installation 14 | ```bash 15 | # Create a new conda environment with Python 3.9 16 | conda create -n tflop python=3.9 17 | conda activate tflop 18 | 19 | # Clone the TFLOP repository 20 | git clone https://github.com/UpstageAI/TFLOP 21 | 22 | # Install required packages 23 | cd TFLOP 24 | pip install torch==2.0.1 torchmetrics==1.6.0 torchvision==0.15.2 25 | pip install -r requirements.txt 26 | ``` 27 | ### Download required files 28 | 1. 
install & login huggingface 29 | 30 | reference: https://huggingface.co/docs/huggingface_hub/en/guides/cli 31 | ```bash 32 | pip install -U "huggingface_hub[cli]" 33 | huggingface-cli login 34 | ``` 35 | 2. install git-lfs 36 | ```bash 37 | sudo apt install git-lfs 38 | git lfs install 39 | ``` 40 | 3. download dataset from [huggingface](https://huggingface.co/datasets/upstage/TFLOP-dataset) 41 | ```bash 42 | git clone https://huggingface.co/datasets/upstage/TFLOP-dataset 43 | ``` 44 | Directory Layout 45 | ```bash 46 | ├── images 47 | │ ├── test.tar.gz 48 | │ ├── train.tar.gz 49 | │ └── validation.tar.gz 50 | ├── meta_data 51 | │ ├── erroneous_pubtabnet_data.json 52 | │ ├── final_eval_v2.json 53 | │ └── PubTabNet_2.0.0.jsonl 54 | └── pse_results 55 | ├── test 56 | │ └── end2end_results.pkl 57 | ├── train 58 | │ ├── detection_results_0.pkl 59 | │ ├── detection_results_1.pkl 60 | │ ├── detection_results_2.pkl 61 | │ ├── detection_results_3.pkl 62 | │ ├── detection_results_4.pkl 63 | │ ├── detection_results_5.pkl 64 | │ ├── detection_results_6.pkl 65 | │ └── detection_results_7.pkl 66 | └── val 67 | └── detection_results_0.pkl 68 | ``` 69 | 4. unzip image files 70 | ```bash 71 | cd TFLOP-dataset 72 | cd images 73 | tar -xvzf train.tar.gz 74 | tar -xvzf validation.tar.gz 75 | tar -xvzf test.tar.gz 76 | ``` 77 | 5. download pretrained weights 78 | ```bash 79 | mkdir pretrain_weights 80 | cd pretrain_weights 81 | git clone --branch official https://huggingface.co/naver-clova-ix/donut-base-finetuned-cord-v2 82 | ``` 83 | ### Data preprocessing 84 | 1. preprocess dataset with pse result 85 | ```bash 86 | bash scripts/preprocess_data/preprocess_pubtabnet.sh 87 | ``` 88 | 2. You can get TFLOP-dataset/meta_data/dataset_train.jsonl, TFLOP-dataset/meta_data/validation.jsonl 89 | ```bash 90 | TFLOP-dataset 91 | ├── images 92 | │ ├── test 93 | │ ├── train 94 | │ ├── validation 95 | ├── meta_data 96 | │ ├── dataset_train.jsonl 97 | │ ├── dataset_validation.jsonl 98 | │ ├── erroneous_pubtabnet_data.json 99 | │ ├── final_eval_v2.json 100 | │ └── PubTabNet_2.0.0.jsonl 101 | └── pse_results 102 | ├── test 103 | ├── train 104 | └── val 105 | ``` 106 | 107 | ### Training 108 | ```bash 109 | bash scripts/training/train_pubtabnet.sh 110 | ``` 111 | 112 | ### Evaluation 113 | ```bash 114 | bash scripts/testing/test_pubtabnet.sh 115 | python evaluate_ted.py --model_inference_pathdir / \ 116 | --output_savepath / 117 | 118 | # Example 119 | bash scripts/testing/test_pubtabnet.sh 0 1 results/pubtabnet_experiment/expv1 epoch_29_step_231000 120 | ``` 121 | 122 | ### Contributors 123 | 124 | 125 | 133 | 141 | 149 | 150 |
126 | <a href="https://github.com/mckhang">mckhang</a><br />Khang, Minsoo
134 | <a href="https://github.com/SeHwanJoo">SeHwanJoo</a><br />Joo, SeHwan
142 | <a href="https://github.com/tghong">tghong</a><br />Hong, Teakgyu
151 | 152 | ## Acknowledgement 153 | We would like to express our gratitude for the outstanding works that have served as valuable references in this research: 154 | - [Donut](https://github.com/clovaai/donut) repository for architecture implementation 155 | - [SupContrast](https://github.com/HobbitLong/SupContrast) repository for Contrastive Learning implementation 156 | - [PubTabNet](https://github.com/ibm-aur-nlp/PubTabNet/blob/master/src/metric.py) repository for TED implementation 157 | 158 | 159 | ## Citation 160 | ``` 161 | @inproceedings{khang2024tflop, 162 | title={TFLOP: table structure recognition framework with layout pointer mechanism}, 163 | author={Khang, Minsoo and Hong, Teakgyu}, 164 | booktitle={Proceedings of the Thirty-Third International Joint Conference on Artificial Intelligence}, 165 | pages={947--955}, 166 | year={2024} 167 | } 168 | ``` -------------------------------------------------------------------------------- /config/exp_configs/data_pubtabnet.yaml: -------------------------------------------------------------------------------- 1 | # Dataset specific configurations 2 | #-------------------------------# 3 | image_path: TFLOP-dataset/images 4 | meta_data_path: TFLOP-dataset/meta_data 5 | 6 | input_size: 7 | height: 768 8 | width: 768 9 | window_size: 8 10 | align_along_axis: False 11 | 12 | max_length: 2700 13 | bbox_token_cnt: 864 14 | use_cell_bbox: False # Use dr coord by default for PubTabNet -------------------------------------------------------------------------------- /config/exp_configs/general_exp.yaml: -------------------------------------------------------------------------------- 1 | # General experiment related configurations 2 | #-----------------------------------------# 3 | # Data batch config 4 | train_batch_size: 12 # could vary by dataset depending on GPU Mem. 5 | val_batch_size: 16 6 | test_batch_size: 1 7 | accumulate_grad_batches: 1 8 | 9 | # Exp. mode config 10 | use_OTSL: False 11 | add_row_col_prefix: False 12 | add_row_cnt_supervision: False 13 | add_col_cnt_supervision: False 14 | use_ColConLoss: False 15 | use_imgRoiAlign: False 16 | 17 | # Contrast. Learning Config 18 | use_isEmptyFilled_contLearning: False 19 | use_isTheadTbody_contLearning: False 20 | use_RowWise_contLearning: False 21 | use_ColWise_contLearning: False 22 | use_CellWise_contLearning: False 23 | span_coeff_mode: "proportional" 24 | 25 | # Augmentation configs 26 | shuffle_cell_bbox_rate: 0.0 27 | add_watermark: False 28 | empty_cell_ptr_loss_coeff: 0.5 29 | non_empty_cell_ptr_loss_coeff: 0.5 30 | 31 | # Exp. 
training configs 32 | lr: 2e-5 33 | max_epochs: -1 34 | max_steps: 1000000 35 | warmup_steps: null 36 | 37 | # Placeholder configurations 38 | exp_name: 39 | exp_version: 40 | result_path: 41 | 42 | # Checkpoint & path configurations 43 | pretrained_tokenizer_name_or_path: "hyunwoongko/asian-bart-ecjk" 44 | pretrained_model_name_or_path: "pretrain_weights/donut-base-finetuned-cord-v2" 45 | resume_from_checkpoint_path: null 46 | 47 | #-----Fixed for all experiments-----# 48 | dataset_script_path: tflop.datamodule.datasets.tflop 49 | dataset_class_name: TFLOPDataset 50 | num_training_samples_per_epoch: 480346 # Does not seem to be used anymore 51 | drop_bbox_rate: 0.0 # Does not seem to be used anymore 52 | mask_text_box: False 53 | curriculum_stage: False 54 | max_num_row: 40 # TODO: Potentially needed for colconloss 55 | max_num_col: 40 # TODO: Potentially needed for colconloss 56 | use_fast_decoder: False 57 | use_ptr_decoder: True # NOTE: This needs to be set to false for non-pointer baseline models 58 | 59 | seed: 42 60 | num_workers: 8 61 | val_check_interval: 1.0 62 | check_val_every_n_epoch: 1 63 | gradient_clip_val: 1.0 64 | num_nodes: 1 65 | strategy: "deepspeed_stage_2" # choices: ["ddp", "deepspeed_stage_2"] 66 | 67 | special_chars: 68 | - rowspan=" 69 | - colspan=" 70 | - ' 72 | - 73 | - 74 | - 75 | - 76 | - 77 | - 78 | - 79 | - 80 | - 81 | - 82 | - C-tag 83 | - U-tag 84 | - L-tag 85 | - X-tag 86 | - NL-tag 87 | - R-tag 88 | - 89 | -------------------------------------------------------------------------------- /dataset/data_preprocessing_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "SPLITS": ["train", "validation"], 3 | "OTSL_TAG": { 4 | "C": "C-tag", 5 | "L": "L-tag", 6 | "U": "U-tag", 7 | "X": "X-tag", 8 | "NL": "NL-tag" 9 | }, 10 | "PUBTABNET_PATH": "TFLOP-dataset/meta_data/PubTabNet_2.0.0.jsonl", 11 | "AMBIGUOUS_DATA_PATH": "TFLOP-dataset/meta_data/erroneous_pubtabnet_data.json", 12 | "DR_COORD_PATH": { 13 | "train": { 14 | "0": "TFLOP-dataset/pse_results/train/detection_results_0.pkl", 15 | "1": "TFLOP-dataset/pse_results/train/detection_results_1.pkl", 16 | "2": "TFLOP-dataset/pse_results/train/detection_results_2.pkl", 17 | "3": "TFLOP-dataset/pse_results/train/detection_results_3.pkl", 18 | "4": "TFLOP-dataset/pse_results/train/detection_results_4.pkl", 19 | "5": "TFLOP-dataset/pse_results/train/detection_results_5.pkl", 20 | "6": "TFLOP-dataset/pse_results/train/detection_results_6.pkl", 21 | "7": "TFLOP-dataset/pse_results/train/detection_results_7.pkl" 22 | }, 23 | "validation": { 24 | "0": "TFLOP-dataset/pse_results/val/detection_results_0.pkl" 25 | } 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /dataset/preprocess_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import pickle 4 | import random 5 | 6 | import torch 7 | from tqdm import tqdm 8 | 9 | from .preprocess_data_utils import convert_html_to_otsl 10 | 11 | 12 | def format_pubtabnet_gold_coords(gold_bbox_collection): 13 | """Preprocess gold coordinate info for PubTabNet dataset. 14 | 15 | NOTE 16 | - In PubTabnet, empty cells are marked by absence of 'bbox' in the cell's dictionary 17 | 18 | Args: 19 | gold_bbox_collection List[Dict]: List of cell dictionaries 20 | Each dictionary has 'tokens' and 'bbox' (if the cell is filled) keys 21 | E.g. 
[{'tokens': ['', 'R', 'i', 's', 'k', ' ', 'F', 'a', 'c', 't', 'o', 'r', 's', ''], 'bbox': [28, 5, 77, 14]}, ... ] 22 | """ 23 | cells = [] 24 | for cell in gold_bbox_collection: 25 | if "bbox" in cell: 26 | # This is a cell with filledContent 27 | string_coords = ["%.2f" % c for c in cell["bbox"]] + ["2"] 28 | # Add serialised string 29 | text = "".join(cell["tokens"]) 30 | string_coords = string_coords + [text] 31 | cells.append(" ".join(string_coords)) 32 | else: 33 | # This is an empty cell 34 | string_coords = ["-1.0", "-1.0", "-1.0", "-1.0", "1"] 35 | text = "" 36 | string_coords = string_coords + [text] 37 | cells.append(" ".join(string_coords)) 38 | 39 | return cells 40 | 41 | 42 | def group_det_bbox( 43 | pred_bbox_tensor, gold_bbox_tensor, IOU_threshold=0.1, IOP_threshold=0.1 44 | ): 45 | """Map pred bbox to gold bbox based on IOU. 46 | 47 | Args: 48 | pred_bbox_tensor: torch.Tensor, (N, 4) 49 | gold_bbox_tensor: torch.Tensor, (M, 4) 50 | IOU_threshold: float, threshold for IOU 51 | IOP_threshold: float, threshold for IOP 52 | 53 | """ 54 | 55 | x_left_y_top_tensor = torch.max( 56 | pred_bbox_tensor.unsqueeze(1)[:, :, :2], gold_bbox_tensor.unsqueeze(0)[:, :, :2] 57 | ) # (N, M, 2) 58 | x_right_y_bottom_tensor = torch.min( 59 | pred_bbox_tensor.unsqueeze(1)[:, :, 2:], gold_bbox_tensor.unsqueeze(0)[:, :, 2:] 60 | ) # (N, M, 2) 61 | 62 | x_left = x_left_y_top_tensor[:, :, 0] 63 | y_top = x_left_y_top_tensor[:, :, 1] 64 | x_right = x_right_y_bottom_tensor[:, :, 0] 65 | y_bottom = x_right_y_bottom_tensor[:, :, 1] 66 | 67 | # Compute the intersection area 68 | intersection_area = torch.logical_or(x_right < x_left, y_bottom < y_top).float() 69 | intersection_area = 1 - intersection_area # (N, M) 70 | intersection_area = intersection_area * (x_right - x_left) * (y_bottom - y_top) 71 | 72 | # Compute the area of both bounding boxes 73 | bbox1_area = (pred_bbox_tensor[:, 2] - pred_bbox_tensor[:, 0]) * ( 74 | pred_bbox_tensor[:, 3] - pred_bbox_tensor[:, 1] 75 | ) # (N,) 76 | bbox2_area = (gold_bbox_tensor[:, 2] - gold_bbox_tensor[:, 0]) * ( 77 | gold_bbox_tensor[:, 3] - gold_bbox_tensor[:, 1] 78 | ) # (M,) 79 | 80 | # Compute the IOU 81 | iou = intersection_area / ( 82 | bbox1_area.unsqueeze(1) + bbox2_area.unsqueeze(0) - intersection_area + 1e-6 83 | ) # (N, M) 84 | 85 | # Map the pred bbox to gold bbox 86 | iou_max, iou_gold_bbox_idx = torch.max(iou, dim=1) # (N,) 87 | iou_pred_bbox_idx = torch.arange(pred_bbox_tensor.shape[0]) # (N,) 88 | iou_pred_bbox_idx = iou_pred_bbox_idx[iou_max > IOU_threshold] 89 | iou_gold_bbox_idx = iou_gold_bbox_idx[iou_max > IOU_threshold] 90 | 91 | # For preds not associated with any gold bbox, recheck and associate if the overlap of pred bbox is > 0.1 92 | iop = intersection_area / bbox1_area.unsqueeze(1) # (N, M) 93 | iop_max, iop_gold_bbox_idx = torch.max(iop, dim=1) # (N,) 94 | iop_pred_bbox_idx = torch.arange(pred_bbox_tensor.shape[0]) # (N,) 95 | bool_mask = torch.logical_and(iop_max > IOP_threshold, iou_max <= IOU_threshold) 96 | iop_pred_bbox_idx = iop_pred_bbox_idx[bool_mask] 97 | iop_gold_bbox_idx = iop_gold_bbox_idx[bool_mask] 98 | 99 | pred_bbox_idx = torch.cat([iou_pred_bbox_idx, iop_pred_bbox_idx], dim=0) 100 | gold_bbox_idx = torch.cat([iou_gold_bbox_idx, iop_gold_bbox_idx], dim=0) 101 | 102 | return pred_bbox_idx, gold_bbox_idx, iou, intersection_area 103 | 104 | 105 | def preprocess_det_bbox( 106 | pred_bbox_collection, gold_bbox_collection, IOU_threshold=0.1, IOP_threshold=0.1 107 | ): 108 | """Preprocess detected bbox and gold bbox. 
109 | 110 | Args: 111 | pred_bbox_collection: List[List[float]], list of detected bounding boxes 112 | Each detected bounding box is represented as [x1, y1, x2, y2, x3, y3, x4, y4] 113 | gold_bbox_collection: List[Dict], list of gold bounding boxes 114 | IOU_threshold: float, threshold for IOU (Intersection over Union) 115 | IOP_threshold: float, threshold for IOP (Intersection over Prediction) 116 | """ 117 | 118 | # Reformat bounding boxes to [x_left, y_top, x_right, y_bottom] 119 | pred_cell_bboxes = [ 120 | [ 121 | min(coord[0], coord[2], coord[4], coord[6]), 122 | min(coord[1], coord[3], coord[5], coord[7]), 123 | max(coord[0], coord[2], coord[4], coord[6]), 124 | max(coord[1], coord[3], coord[5], coord[7]), 125 | ] 126 | for coord in pred_bbox_collection 127 | ] 128 | 129 | gold_cell_bboxes = [x["bbox"] for x in gold_bbox_collection if "bbox" in x] 130 | gold_cell_contents = [ 131 | "".join(x["tokens"]) for x in gold_bbox_collection if "bbox" in x 132 | ] 133 | grp_to_filled_gold_idx_mapping = {} 134 | current_filled_idx = 0 135 | for gold_idx, gold_bbox in enumerate(gold_bbox_collection): 136 | if "bbox" in gold_bbox: 137 | grp_to_filled_gold_idx_mapping[current_filled_idx] = gold_idx 138 | current_filled_idx += 1 139 | 140 | pred_bbox_tensor = torch.tensor(pred_cell_bboxes) 141 | gold_bbox_tensor = torch.tensor(gold_cell_bboxes) 142 | 143 | pred_bbox_idx, gold_bbox_idx, iou, intersection_area = group_det_bbox( 144 | pred_bbox_tensor, 145 | gold_bbox_tensor, 146 | IOU_threshold=IOU_threshold, 147 | IOP_threshold=IOP_threshold, 148 | ) 149 | 150 | # Group up pred_idx by common gold_bbox_idx 151 | pred_bbox_idx_group = {} 152 | for pred_idx, gold_idx in zip(pred_bbox_idx, gold_bbox_idx): 153 | if gold_idx.item() not in pred_bbox_idx_group: 154 | pred_bbox_idx_group[gold_idx.item()] = [] 155 | pred_bbox_idx_group[gold_idx.item()].append(pred_idx.item()) 156 | 157 | def sort_bbox_coords(bbox_coord_list): 158 | if len(bbox_coord_list) == 1: 159 | return bbox_coord_list 160 | else: 161 | new_list = [] 162 | # 1. 
get min_height across all bboxes in the list 163 | min_height = 100000 164 | for bbox_coord in bbox_coord_list: 165 | min_height = min(min_height, bbox_coord[3] - bbox_coord[1]) 166 | 167 | sorted_by_height_interval = {} 168 | for bbox_coord in bbox_coord_list: 169 | height_interval = int(bbox_coord[1] / min_height) 170 | if height_interval not in sorted_by_height_interval: 171 | sorted_by_height_interval[height_interval] = [] 172 | sorted_by_height_interval[height_interval].append(bbox_coord) 173 | 174 | for height_interval in sorted(sorted_by_height_interval.keys()): 175 | bbox_coords = sorted_by_height_interval[height_interval] 176 | # sort bbox_coords by x_left 177 | bbox_coords = sorted(bbox_coords, key=lambda x: x[0]) 178 | new_list.extend(bbox_coords) 179 | 180 | return new_list 181 | 182 | # First, serialize bbox coords within each group 183 | for gold_idx in pred_bbox_idx_group.keys(): 184 | pred_bbox_idx_group[gold_idx] = sort_bbox_coords( 185 | [pred_cell_bboxes[x] for x in pred_bbox_idx_group[gold_idx]] 186 | ) 187 | 188 | # Next, serialize groups 189 | gold_idx_first_bbox_list = [(k, v[0]) for k, v in pred_bbox_idx_group.items()] 190 | minimum_group_height = 100000 191 | for gold_idx, bbox_coord in gold_idx_first_bbox_list: 192 | minimum_group_height = min(minimum_group_height, bbox_coord[3] - bbox_coord[1]) 193 | 194 | sorted_by_group_height_interval = {} 195 | for gold_idx, bbox_coord in gold_idx_first_bbox_list: 196 | height_interval = int(bbox_coord[1] / minimum_group_height) 197 | if height_interval not in sorted_by_group_height_interval: 198 | sorted_by_group_height_interval[height_interval] = [] 199 | sorted_by_group_height_interval[height_interval].append((gold_idx, bbox_coord)) 200 | 201 | sorted_gold_idx_first_bbox_list = [] 202 | for height_interval in sorted(sorted_by_group_height_interval.keys()): 203 | gold_idx_first_bbox_list = sorted_by_group_height_interval[height_interval] 204 | # sort bbox_coords by x_left 205 | gold_idx_first_bbox_list = sorted( 206 | gold_idx_first_bbox_list, key=lambda x: x[1][0] 207 | ) 208 | sorted_gold_idx_first_bbox_list.extend(gold_idx_first_bbox_list) 209 | 210 | # Finally, serialize the whole list 211 | serialized_pred_bbox = {} 212 | for group_idx, group_data in enumerate(sorted_gold_idx_first_bbox_list): 213 | serialized_pred_bbox[group_idx] = [ 214 | pred_bbox_idx_group[group_data[0]], 215 | grp_to_filled_gold_idx_mapping[group_data[0]], 216 | gold_cell_contents[group_data[0]], 217 | ] 218 | 219 | return serialized_pred_bbox 220 | 221 | 222 | if __name__ == "__main__": 223 | parser = argparse.ArgumentParser() 224 | parser.add_argument( 225 | "--data_config_path", type=str, help="Path to the data config file" 226 | ) 227 | parser.add_argument( 228 | "--output_dir", type=str, help="Output directory to save the preprocessed data" 229 | ) 230 | parser.add_argument( 231 | "--bin_idx", type=int, default=-1, help="Index of the bin to process" 232 | ) 233 | parser.add_argument( 234 | "--num_bins", 235 | type=int, 236 | default=0, 237 | help="Number of bins to split the dataset into", 238 | ) 239 | args = parser.parse_args() 240 | 241 | # sanity check 242 | if args.bin_idx != -1: 243 | assert args.num_bins > 0 244 | 245 | # config loading and setting up for preprocessing 246 | data_config = json.load(open(args.data_config_path, "r")) 247 | random.seed(42) 248 | generated_dataset = { 249 | k: [] for k in data_config["SPLITS"] 250 | } # This is to store all preprocessed data 251 | 252 | # Data loading 253 | meta_pubtabnet_data = 
open(data_config["PUBTABNET_PATH"], "r").readlines() 254 | pse_det_data = {"train": {}, "validation": {}} 255 | for split in data_config["SPLITS"]: 256 | for pickle_path in tqdm( 257 | data_config["DR_COORD_PATH"][split].values(), 258 | desc="Loading PSE Det result for %s" % split, 259 | ): 260 | pse_data_list_loaded = pickle.load(open(pickle_path, "rb")) 261 | for pse_data in pse_data_list_loaded: 262 | pse_det_data[split][pse_data["file_name"]] = pse_data 263 | 264 | # In PubTabNet, ambiguous HTML representations for training and validation datasets are removed. NOTE: this is not done for test dataset 265 | split_dataset = {"train": [], "validation": []} 266 | ambiguous_data_filenames = [] 267 | for split_type, split_amb_filenames in json.load( 268 | open(data_config["AMBIGUOUS_DATA_PATH"], "r") 269 | ).items(): 270 | ambiguous_data_filenames.extend(split_amb_filenames) 271 | for raw_data in tqdm(meta_pubtabnet_data, desc="Removing amgiguous data"): 272 | data = json.loads(raw_data) 273 | if data["filename"] in ambiguous_data_filenames: 274 | continue 275 | split = data["split"] 276 | if split == "val": 277 | split = "validation" 278 | assert split in ["train", "validation"], ( 279 | "Invalid split %s" % split 280 | ) # NOTE: Test dataset is not processed at this stage 281 | split_dataset[split].append(raw_data) 282 | 283 | # Preprocessing 284 | # Preprocess Train and Validation dataset 285 | for split in ["train", "validation"]: 286 | if args.bin_idx != -1: 287 | # Split the dataset into bins 288 | num_data = len(split_dataset[split]) 289 | bin_size = int(num_data / args.num_bins) 290 | start_idx = args.bin_idx * bin_size 291 | if args.bin_idx == args.num_bins - 1: 292 | end_idx = num_data 293 | else: 294 | end_idx = start_idx + bin_size 295 | sliced_dataset = split_dataset[split][start_idx:end_idx] 296 | else: 297 | sliced_dataset = split_dataset[split] 298 | 299 | for raw_data in tqdm(sliced_dataset, desc="Pre-processing split %s" % split): 300 | loaded_data = json.loads(raw_data) 301 | data_filename = loaded_data["filename"] 302 | 303 | # GET OTSL representation 304 | otsl_seq, num_rows, num_cols = convert_html_to_otsl( 305 | html_seq=loaded_data["html"]["structure"]["tokens"], 306 | otsl_tag_maps=data_config["OTSL_TAG"], 307 | ) 308 | gold_bbox_seq = format_pubtabnet_gold_coords(loaded_data["html"]["cells"]) 309 | 310 | # Get pse det result 311 | pse_det_result = pse_det_data[split][data_filename] 312 | 313 | # Get pred and gold bbox idx grouping 314 | # pred_bbox_idx is a dict mapping group_idx to list of bbox_coords 315 | pred_bbox_idx = preprocess_det_bbox( 316 | pse_det_result["bbox"], 317 | loaded_data["html"]["cells"], 318 | IOU_threshold=0.1, 319 | IOP_threshold=0.1, 320 | ) 321 | 322 | data_entry = { 323 | "file_name": loaded_data["filename"], 324 | "dr_coord": pred_bbox_idx, 325 | "gold_coord": gold_bbox_seq, 326 | "org_html": loaded_data["html"]["structure"]["tokens"], 327 | "otsl_seq": otsl_seq, 328 | "num_rows": num_rows, 329 | "num_cols": num_cols, 330 | "split": split, 331 | } 332 | 333 | generated_dataset[split].append(data_entry) 334 | 335 | list_of_data = generated_dataset[split] 336 | if args.bin_idx != -1: 337 | savename = "%s/dataset_%s_%d_%d.jsonl" % ( 338 | args.output_dir, 339 | split, 340 | args.bin_idx, 341 | args.num_bins, 342 | ) 343 | else: 344 | savename = "%s/dataset_%s.jsonl" % (args.output_dir, split) 345 | with open(savename, "w") as f: 346 | for d in list_of_data: 347 | f.write(json.dumps(d) + "\n") 348 | 
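Note on the matching step above: `preprocess_det_bbox` converts each PSE quadrilateral to an axis-aligned box, groups detections that overlap the same gold PubTabNet cell (IoU with an IoP fallback, both thresholded at 0.1 here), and serializes the groups into reading order (top-to-bottom by height interval, then left-to-right). A minimal, self-contained sketch of that matching on made-up boxes — the coordinates, cell text, and expected result below are illustrative only, not taken from the dataset:
```python
import torch

# One detected text box as an 8-point quadrilateral (x1, y1, ..., x4, y4), as in pse_det_result["bbox"]
pred_quads = [[10, 10, 60, 10, 60, 30, 10, 30]]
# Gold PubTabNet cells: a filled cell (has "bbox") and an empty one (no "bbox")
gold_cells = [
    {"tokens": ["R", "i", "s", "k"], "bbox": [8, 8, 62, 32]},
    {"tokens": []},
]

# Inside the repository this would be:
#   from dataset.preprocess_data import preprocess_det_bbox
#   preprocess_det_bbox(pred_quads, gold_cells)
# which should yield {0: [[[10, 10, 60, 30]], 0, "Risk"]} for this toy input.

# Equivalent stand-alone IoU check for the single pred/gold pair:
pred = torch.tensor([[10.0, 10.0, 60.0, 30.0]])  # axis-aligned box derived from the quad
gold = torch.tensor([[8.0, 8.0, 62.0, 32.0]])
inter = (torch.min(pred[:, 2:], gold[:, 2:]) - torch.max(pred[:, :2], gold[:, :2])).clamp(min=0).prod(dim=1)
union = (
    (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
    + (gold[:, 2] - gold[:, 0]) * (gold[:, 3] - gold[:, 1])
    - inter
)
print(f"IoU = {(inter / union).item():.2f}")  # ~0.77, above the 0.1 IOU_threshold, so the box maps to this cell
```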
-------------------------------------------------------------------------------- /dataset/preprocess_data_utils.py: -------------------------------------------------------------------------------- 1 | def convert_html_to_otsl(html_seq, otsl_tag_maps): 2 | """ 3 | Convert list of html tokens to OTSL format 4 | 5 | Args: 6 | html_seq List[str]: list of html tokens 7 | E.g. ['', '', '', '', ...] 8 | ost_tag_maps Dict[str, str]: mapping of otsl tag symbols 9 | E.g.{"C": "C-tag", "L": "L-tag", "U": "U-tag", "X": "X-tag", "NL": "NL-tag"} 10 | 11 | Returns: 12 | List[str]: Full OTSL sequence 13 | Int: Number of rows in the table 14 | Int: Number of columns in the table 15 | """ 16 | 17 | # 1. Split list of HTML tokens into head and body 18 | end_of_head_index = html_seq.index("") 19 | thead_seq = html_seq[:end_of_head_index] 20 | thead_seq = [ 21 | x for x in thead_seq if x not in ["", "", "", ""] 22 | ] 23 | tbody_seq = html_seq[(end_of_head_index + 1) :] 24 | tbody_seq = [ 25 | x for x in tbody_seq if x not in ["", "", "", ""] 26 | ] 27 | 28 | # 2. Format HTML tags into row-wise list of tokens 29 | thead_row_wise_seq = get_row_wise(thead_seq) 30 | tbody_row_wise_seq = get_row_wise(tbody_seq) 31 | 32 | # 2.1 Check if the thead section is empty 33 | is_head_empty = False 34 | if len(thead_row_wise_seq) == 0: 35 | is_head_empty = True 36 | 37 | # 3. Convert row-wise list of tokens into OTSL array -> 4. Convert OTSL array into OTSL sequence 38 | thead_OTSL_array, num_head_rows, num_cols = None, 0, None 39 | thead_OTSL_seq = [] 40 | if not is_head_empty: 41 | thead_OTSL_array, num_head_rows, num_cols = get_OTSL_array( 42 | thead_row_wise_seq, otsl_tag_maps 43 | ) 44 | thead_OTSL_seq = convert_OTSL_array_to_OTSL_seq( 45 | thead_OTSL_array, 46 | num_rows=num_head_rows, 47 | num_cols=num_cols, 48 | otsl_tag_maps=otsl_tag_maps, 49 | ) 50 | 51 | tbody_OTSL_array, num_body_rows, num_cols = get_OTSL_array( 52 | tbody_row_wise_seq, otsl_tag_maps, ref_num_cols=num_cols 53 | ) 54 | tbody_OTSL_seq = convert_OTSL_array_to_OTSL_seq( 55 | tbody_OTSL_array, 56 | num_rows=num_body_rows, 57 | num_cols=num_cols, 58 | otsl_tag_maps=otsl_tag_maps, 59 | ) 60 | 61 | # 5. Combine thead and tbody into one OTSL sequence 62 | combined_OTSL_seq = ( 63 | [""] 64 | + thead_OTSL_seq 65 | + ["", ""] 66 | + tbody_OTSL_seq 67 | + [""] 68 | ) 69 | num_rows = num_head_rows + num_body_rows 70 | 71 | return combined_OTSL_seq, num_rows, num_cols 72 | 73 | 74 | def get_OTSL_array(row_wise_html_tags, otsl_tag_maps, ref_num_cols=None): 75 | """Generate OTSL array from row-wise html tags. 76 | 77 | Args: 78 | row_wise_html_tags List[List[str]]: list of list of html tags, where each inner list is a row 79 | E.g. [['', ''], ['', '']] 80 | otsl_tag_maps Dict[str, str]: mapping of otsl tag symbols 81 | E.g.{"C": "C-tag", "L": "L-tag", "U": "U-tag", "X": "X-tag", "NL": "NL-tag"} 82 | ref_num_cols int: reference number of columns to use. If None, will derive from row_wise_html_tags 83 | - Used to sanity check if tbody's num_cols match that of thead 84 | 85 | Returns: 86 | Tuple[List[List[str]], int, int]: OTSL array, number of rows, number of columns 87 | """ 88 | num_rows, num_cols = get_num_rows_and_cols(row_wise_html_tags) 89 | if ref_num_cols is not None and ref_num_cols != num_cols: 90 | raise ValueError( 91 | "Number of columns in tbody does not match that of thead. Got %s but expected %s" 92 | % (num_cols, ref_num_cols) 93 | ) 94 | 95 | # 1. 
Initialize OTSL array 96 | otsl_array = [list([None] * num_cols) for _ in range(num_rows)] 97 | 98 | # 2. Fill in OTSL array 99 | curr_row_ind, curr_col_ind = 0, 0 100 | current_data = {"standard": 0, "rowspan": 0, "colspan": 0} 101 | for row_tokens in row_wise_html_tags: 102 | for tok_i, tok in enumerate(row_tokens): 103 | # 2.1 sanity check token 104 | if tok not in ["", "", ""] and ( 105 | "rowspan" not in tok and "colspan" not in tok 106 | ): 107 | raise ValueError("Invalid HTML %s" % tok) 108 | 109 | # 2.2 iter over tokens in the row 110 | if tok in [""]: 111 | continue 112 | elif tok == "": 113 | current_data["standard"] += 1 114 | elif "rowspan" in tok: 115 | current_data["rowspan"] += int(tok.split("=")[1].split('"')[1]) 116 | elif "colspan" in tok: 117 | current_data["colspan"] += int(tok.split("=")[1].split('"')[1]) 118 | elif tok == "": 119 | # End of cell -> i.e. Time to start updating OTSL array with current_data 120 | # 2.2.1 Find row & col ind to insert data 121 | while otsl_array[curr_row_ind][curr_col_ind] is not None: 122 | curr_col_ind += 1 123 | if curr_col_ind >= num_cols: 124 | curr_col_ind = 0 125 | curr_row_ind += 1 126 | assert ( 127 | curr_row_ind < num_rows 128 | ), "curr_row_ind %s >= num_rows %s" % (curr_row_ind, num_rows) 129 | 130 | # 2.2.2 Sanity check current_data before insertion 131 | sanity_check_move(current_data) 132 | 133 | # 2.2.3 Insert data 134 | otsl_array = insert_data_into_OTSL( 135 | current_data=current_data, 136 | OTSL_array=otsl_array, 137 | otsl_tag_maps=otsl_tag_maps, 138 | curr_row_ind=curr_row_ind, 139 | curr_col_ind=curr_col_ind, 140 | ) 141 | 142 | # 2.2.4 reset current_data 143 | current_data = {"standard": 0, "rowspan": 0, "colspan": 0} 144 | 145 | else: 146 | raise ValueError("Invalid HTML %s" % tok) 147 | 148 | return otsl_array, num_rows, num_cols 149 | 150 | 151 | def convert_OTSL_array_to_OTSL_seq(otsl_array, num_rows, num_cols, otsl_tag_maps): 152 | """Convert OTSL array to OTSL sequence. 153 | 154 | Args: 155 | otsl_array List[List[str]]: OTSL array 156 | num_rows int: number of rows in OTSL array 157 | num_cols int: number of columns in OTSL array 158 | otsl_tag_maps Dict[str, str]: mapping of otsl tag symbols 159 | E.g.{"C": "C-tag", "L": "L-tag", "U": "U-tag", "X": "X-tag", "NL": "NL-tag"} 160 | 161 | Returns: 162 | List[str]: OTSL sequence 163 | """ 164 | OTSL_seq = [] 165 | 166 | for row_ind in range(num_rows): 167 | for col_ind in range(num_cols): 168 | assert ( 169 | otsl_array[row_ind][col_ind] is not None 170 | ), "row_ind %s, col_ind %s" % (row_ind, col_ind) 171 | OTSL_seq.append(otsl_array[row_ind][col_ind]) 172 | 173 | OTSL_seq.append(otsl_tag_maps["NL"]) 174 | 175 | return OTSL_seq 176 | 177 | 178 | # -----Auxiliary Functions-----# 179 | def get_row_wise(tok_list): 180 | """Given list of HTML tokens, group them into row-wise format. 181 | 182 | NOTE: 183 | Raises error if there are tokens not encapsulated by 184 | 185 | Args: 186 | tok_list List[str]: list of html tokens 187 | E.g. ['', '', '', ...] 
188 | 189 | Returns: 190 | List[List[str]]: list of list of tokens, where each inner list is a row 191 | """ 192 | row_wise_tokens = [] 193 | 194 | is_within_row = False 195 | for tok in tok_list: 196 | if tok == "": 197 | is_within_row = True 198 | tmp_row = [] 199 | elif tok == "": 200 | is_within_row = False 201 | row_wise_tokens.append(tmp_row) 202 | else: 203 | assert is_within_row, "Token not encapsulated by " 204 | tmp_row.append(tok) 205 | 206 | return row_wise_tokens 207 | 208 | 209 | def get_num_rows_and_cols(row_wise_html_tags): 210 | """Given row-wise html tags, derive number of rows and columns. 211 | 212 | Args: 213 | row_wise_html_tags List[List[str]]: list of list of html tags, where each inner list is a row 214 | E.g. [['', ''], ['', '']] 215 | 216 | Returns: 217 | Tuple[int, int]: number of rows and columns 218 | """ 219 | 220 | # Derive the number of rows in this table 221 | num_rows = len(row_wise_html_tags) 222 | num_cols = 0 223 | col_span_tracker = 0 224 | 225 | # Derive the number of columns in this table 226 | for first_row_tok in row_wise_html_tags[0]: 227 | if first_row_tok == "": 228 | if col_span_tracker == 0: 229 | num_cols += 1 230 | else: 231 | num_cols += col_span_tracker 232 | col_span_tracker = 0 233 | else: 234 | if "colspan" in first_row_tok: 235 | col_span_tracker += int(first_row_tok.split("=")[1].split('"')[1]) 236 | 237 | return num_rows, num_cols 238 | 239 | 240 | def sanity_check_move(current_data): 241 | """Sanity checker of current move data prior to updating OTSL array. 242 | 243 | Args: 244 | current_data Dict: current data of move 245 | E.g. {'standard': 1, 'rowspan': 0, 'colspan': 0} 246 | 247 | Checks: 248 | 1. If standard (i.e. single cell), then rowspan and colspan must be 0 249 | 2. If not standard, then rowspan or colspan must be > 0 250 | """ 251 | 252 | if current_data["standard"] == 0: 253 | assert sum([current_data["rowspan"], current_data["colspan"]]) > 0 254 | else: 255 | assert current_data["standard"] == 1 256 | assert sum([current_data["rowspan"], current_data["colspan"]]) == 0 257 | 258 | 259 | def insert_data_into_OTSL( 260 | current_data, OTSL_array, otsl_tag_maps, curr_row_ind, curr_col_ind 261 | ): 262 | """Given current_data, insert data into OTSL array. 263 | 264 | Args: 265 | current_data Dict: current data of move 266 | E.g. {'standard': 1, 'rowspan': 0, 'colspan': 0} 267 | OTSL_array List[List[str]]: OTSL array 268 | otsl_tag_maps Dict: mapping of otsl tag symbols 269 | curr_row_ind int: current row index 270 | curr_col_ind int: current column index 271 | 272 | NOTE: 273 | This function updates the OTSL array based on the current_data. 274 | There are 4 cases in total: 275 | 1. Standard cell (i.e. single cell, no rowspan or colspan) 276 | 2. Colspan only 277 | 3. Rowspan only 278 | 4. 
Both rowspan and colspan 279 | 280 | Returns: 281 | List[List[str]]: updated OTSL array 282 | """ 283 | 284 | if current_data["standard"] == 1: 285 | assert OTSL_array[curr_row_ind][curr_col_ind] is None 286 | OTSL_array[curr_row_ind][curr_col_ind] = otsl_tag_maps[ 287 | "C" 288 | ] # single cell mapped as 'C' in OTSL 289 | else: 290 | # Colspan only 291 | if current_data["rowspan"] == 0: 292 | assert OTSL_array[curr_row_ind][curr_col_ind] is None 293 | OTSL_array[curr_row_ind][curr_col_ind] = otsl_tag_maps["C"] 294 | for i in range(1, current_data["colspan"]): 295 | assert OTSL_array[curr_row_ind][curr_col_ind + i] is None 296 | OTSL_array[curr_row_ind][curr_col_ind + i] = otsl_tag_maps[ 297 | "L" 298 | ] # All cells other than root for colspan mapped as 'L' in OTSL 299 | 300 | # Rowspan only 301 | elif current_data["colspan"] == 0: 302 | assert OTSL_array[curr_row_ind][curr_col_ind] is None 303 | OTSL_array[curr_row_ind][curr_col_ind] = otsl_tag_maps["C"] 304 | for i in range(1, current_data["rowspan"]): 305 | assert OTSL_array[curr_row_ind + i][curr_col_ind] is None 306 | OTSL_array[curr_row_ind + i][curr_col_ind] = otsl_tag_maps[ 307 | "U" 308 | ] # All cells other than root for rowspan mapped as 'U' in OTSL 309 | 310 | # Both rowspan and colspan 311 | else: 312 | assert OTSL_array[curr_row_ind][curr_col_ind] is None 313 | OTSL_array[curr_row_ind][curr_col_ind] = otsl_tag_maps["C"] 314 | 315 | for i in range(1, current_data["colspan"]): 316 | assert OTSL_array[curr_row_ind][curr_col_ind + i] is None 317 | OTSL_array[curr_row_ind][curr_col_ind + i] = otsl_tag_maps["L"] 318 | 319 | for i in range(1, current_data["rowspan"]): 320 | assert OTSL_array[curr_row_ind + i][curr_col_ind] is None 321 | OTSL_array[curr_row_ind + i][curr_col_ind] = otsl_tag_maps["U"] 322 | 323 | for i in range(1, current_data["rowspan"]): 324 | for j in range(1, current_data["colspan"]): 325 | assert OTSL_array[curr_row_ind + i][curr_col_ind + j] is None 326 | OTSL_array[curr_row_ind + i][curr_col_ind + j] = otsl_tag_maps["X"] 327 | 328 | return OTSL_array 329 | 330 | 331 | def calculate_pointer_index( 332 | curr_row_ind, curr_col_ind, row_offset, col_offset, is_table_body, num_cols 333 | ): 334 | """Given current row & col index, along with other info, calculate the index to point to for potsl.""" 335 | 336 | # Apply offset values 337 | point_index = (curr_row_ind - row_offset) * num_cols + (curr_col_ind - col_offset) 338 | 339 | # Add number of rows to pointer as each row ends with NL tag 340 | point_index += curr_row_ind - row_offset 341 | 342 | # If current table is tbody, offset by 3 since , , tags are added 343 | # else, offset by 1 since 344 | if is_table_body: 345 | point_index += 3 346 | else: 347 | point_index += 1 348 | 349 | return point_index 350 | -------------------------------------------------------------------------------- /evaluate_ted.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import multiprocessing 4 | import os 5 | import time 6 | 7 | from Levenshtein import distance 8 | from tqdm import tqdm 9 | from tflop.evaluator import TEDS 10 | 11 | 12 | def strip_html_contents(html_string): 13 | """ Strips the content in the tabular data (i.e. 
remove preceding and trailing whitespaces in the content) """ 14 | 15 | # get pre_thead 16 | if "" in html_string: 17 | pre_thead, remain_str = html_string.split("", 1) 18 | else: 19 | pre_thead = "" 20 | remain_str = html_string[len(pre_thead) :] 21 | 22 | # get thead 23 | if "" in remain_str: 24 | thead, remain_str = remain_str.split("", 1) 25 | if not remain_str.startswith(""): 26 | remain_str = "" + remain_str 27 | else: 28 | if "" in remain_str: 29 | thead, tmp_remain = remain_str.split("", 1) 30 | remain_str = "" + tmp_remain 31 | elif "" in remain_str: 32 | thead, remain_str = remain_str.split("", 1) 33 | remain_str = "" + remain_str 34 | else: 35 | thead = remain_str.split("
")[0] 36 | remain_str = "" 37 | 38 | # get tbody 39 | remain_str = remain_str.split("", 1)[1] 40 | if "" in remain_str: 41 | tbody, post_tbody = remain_str.split("", 1) 42 | else: 43 | if remain_str == "": 44 | tbody = "" 45 | post_tbody = remain_str 46 | else: 47 | tbody = remain_str.split("")[0] 48 | post_tbody = "" 49 | 50 | thead_stripped, tbody_stripped = [], [] 51 | # thead handling 52 | thead_rows = thead.split("") 53 | for row in thead_rows: 54 | if row == "": 55 | continue 56 | if "" in row: 57 | row_contents = row.split("")[1].split("") 58 | else: 59 | row_contents = row.split("") 60 | 61 | row_stripped = [] 62 | for row_content in row_contents: 63 | if row_content == "": 64 | continue 65 | td_header, td_content = row_content.split(">", 1) 66 | if td_content.startswith("") and td_content.endswith(""): 67 | td_content = td_content[3:-4].strip() 68 | td_content = "" + td_content + "" 69 | else: 70 | td_content = td_content.strip() 71 | new_td_entry = td_header + ">" + td_content + "" 72 | row_stripped.append(new_td_entry) 73 | thead_stripped.append("" + "".join(row_stripped) + "") 74 | 75 | # tbody handling 76 | tbody_rows = tbody.split("") 77 | for row in tbody_rows: 78 | if row == "": 79 | continue 80 | if "" in row: 81 | row_contents = row.split("")[1].split("") 82 | else: 83 | row_contents = row.split("") 84 | 85 | row_stripped = [] 86 | for row_content in row_contents: 87 | if row_content == "": 88 | continue 89 | td_header, td_content = row_content.split(">", 1) 90 | if td_content.startswith("") and td_content.endswith(""): 91 | td_content = td_content[3:-4].strip() 92 | td_content = "" + td_content + "" 93 | else: 94 | td_content = td_content.strip() 95 | new_td_entry = td_header + ">" + td_content + "" 96 | row_stripped.append(new_td_entry) 97 | tbody_stripped.append("" + "".join(row_stripped) + "") 98 | 99 | new_html = ( 100 | pre_thead 101 | + "" 102 | + "".join(thead_stripped) 103 | + "" 104 | + "".join(tbody_stripped) 105 | + "" 106 | + post_tbody 107 | ) 108 | 109 | return new_html 110 | 111 | 112 | def evaluate_distance(data_tuple): 113 | """ Evaluation of TEDs and S-TEDs scores""" 114 | ted_evaluator_structure_only = TEDS(structure_only=True, n_jobs=1) 115 | ted_evaluator = TEDS(structure_only=False, n_jobs=1, ignore_nodes=["b"]) 116 | 117 | file_name, pred_string, gold_string = data_tuple 118 | 119 | # edit-distance 120 | edit_distance = distance(pred_string, gold_string) / max( 121 | len(pred_string), len(gold_string) 122 | ) 123 | 124 | # TED 125 | refined_pred = pred_string 126 | refined_gold = gold_string 127 | if pred_string.startswith("") and pred_string.endswith("
"): 128 | refined_pred = "" + pred_string + "" 129 | elif not pred_string.startswith("") and not pred_string.endswith( 130 | "
" 131 | ): 132 | refined_pred = "" + refined_pred + "
" 133 | 134 | if gold_string.startswith("") and gold_string.endswith("
"): 135 | refined_gold = "" + gold_string + "" 136 | elif not gold_string.startswith("") and not gold_string.endswith( 137 | "
" 138 | ): 139 | refined_gold = "" + refined_gold + "
" 140 | 141 | # strip content in table data 142 | refined_pred = strip_html_contents(refined_pred) 143 | refined_gold = strip_html_contents(refined_gold) 144 | 145 | # tree-edit-distance (structure only) 146 | try: 147 | ted_score_structure_only = ted_evaluator_structure_only.evaluate( 148 | refined_pred, refined_gold 149 | ) 150 | except: 151 | ted_score_structure_only = 0.0 152 | 153 | # tree-edit-distance (structure + content) 154 | try: 155 | ted_score = ted_evaluator.evaluate( 156 | refined_pred, refined_gold 157 | ) 158 | except: 159 | ted_score = 0.0 160 | 161 | return ( 162 | file_name, 163 | pred_string, 164 | gold_string, 165 | edit_distance, 166 | ted_score_structure_only, 167 | ted_score, 168 | ) 169 | 170 | 171 | if __name__ == "__main__": 172 | parser = argparse.ArgumentParser() 173 | parser.add_argument("--model_inference_pathdir", type=str, required=True) 174 | parser.add_argument("--output_savepath", type=str, required=True) 175 | args = parser.parse_args() 176 | 177 | if "full_model_inference.json" in os.listdir(args.model_inference_pathdir): 178 | with open( 179 | os.path.join(args.model_inference_pathdir, "full_model_inference.json"), "r" 180 | ) as f: 181 | model_inference = json.load(f) 182 | else: 183 | split_files = [ 184 | os.path.join(args.model_inference_pathdir, f) 185 | for f in os.listdir(args.model_inference_pathdir) 186 | if (f.startswith("full_model_inference") and f.endswith(".json")) 187 | ] 188 | model_inference = {} 189 | for split_file in split_files: 190 | with open(split_file, "r") as f: 191 | model_inference.update(json.load(f)) 192 | 193 | data_collection = [ 194 | ( 195 | k, 196 | v["pred_string"], 197 | v["answer_string"] 198 | ) 199 | for k, v in model_inference.items() 200 | ] 201 | # sort data_collection by k, where k is the file name 202 | data_collection = sorted(data_collection, key=lambda x: x[0]) 203 | 204 | batch_size = 200 205 | num_processes = 8 206 | if len(data_collection) % batch_size == 0: 207 | num_batches = len(data_collection) // batch_size 208 | else: 209 | num_batches = (len(data_collection) // batch_size) + 1 210 | 211 | result_collection = [] 212 | for batch_idx in tqdm( 213 | range(num_batches), 214 | desc="Evaluating...", 215 | position=1, 216 | leave=False, 217 | ): 218 | batch_data_collection = data_collection[ 219 | batch_idx * batch_size : (batch_idx + 1) * batch_size 220 | ] 221 | pool = multiprocessing.Pool(processes=num_processes) 222 | outputs = pool.map(evaluate_distance, batch_data_collection) 223 | pool.close() 224 | pool.join() 225 | result_collection.extend(outputs) 226 | 227 | teds_score = sum([x[-1] for x in result_collection]) / len(result_collection) 228 | print(f"TEDs score: {teds_score}") 229 | 230 | # /ted_score_output.json 231 | with open( 232 | os.path.join( 233 | args.output_savepath, 234 | "ted_score_output.json", 235 | ), 236 | "w", 237 | encoding="utf-8", 238 | ) as f: 239 | json.dump(result_collection, f, ensure_ascii=False) 240 | -------------------------------------------------------------------------------- /figures/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UpstageAI/TFLOP/3351a1ca706f03cf27570c76ae0c27b649d68df1/figures/pipeline.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=42", 4 | "wheel", 5 | ] 6 | build-backend = 
"setuptools.build_meta" 7 | 8 | [tool.black] 9 | line-length = 88 10 | target-version = ['py310'] 11 | 12 | [tool.isort] 13 | py_version = 310 14 | line_length = 88 15 | atomic = true 16 | combine_as_imports = true 17 | force_sort_within_sections = true 18 | profile = "black" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | aiohappyeyeballs==2.4.4 3 | aiohttp==3.11.9 4 | aiosignal==1.3.1 5 | antlr4-python3-runtime==4.9.3 6 | apted==1.0.3 7 | async-timeout==5.0.1 8 | attrs==24.2.0 9 | black==24.10.0 10 | cachetools==5.5.0 11 | certifi==2024.8.30 12 | charset-normalizer==3.4.0 13 | click==8.1.8 14 | cmake==3.31.1 15 | deepspeed==0.9.5 16 | Distance==0.1.3 17 | filelock==3.16.1 18 | frozenlist==1.5.0 19 | fsspec==2024.10.0 20 | google-auth==2.36.0 21 | google-auth-oauthlib==0.4.6 22 | grpcio==1.68.1 23 | hjson==3.1.0 24 | huggingface-hub==0.26.3 25 | idna==3.10 26 | importlib_metadata==8.5.0 27 | isort==5.13.2 28 | Jinja2==3.1.4 29 | Levenshtein==0.20.9 30 | lightning-utilities==0.11.9 31 | lit==18.1.8 32 | lxml==4.9.3 33 | Markdown==3.7 34 | MarkupSafe==3.0.2 35 | mpmath==1.3.0 36 | multidict==6.1.0 37 | mypy-extensions==1.0.0 38 | networkx==3.2.1 39 | ninja==1.11.1.2 40 | numpy==1.22.1 41 | nvidia-cublas-cu11==11.10.3.66 42 | nvidia-cublas-cu12==12.4.5.8 43 | nvidia-cuda-cupti-cu11==11.7.101 44 | nvidia-cuda-cupti-cu12==12.4.127 45 | nvidia-cuda-nvrtc-cu11==11.7.99 46 | nvidia-cuda-nvrtc-cu12==12.4.127 47 | nvidia-cuda-runtime-cu11==11.7.99 48 | nvidia-cuda-runtime-cu12==12.4.127 49 | nvidia-cudnn-cu11==8.5.0.96 50 | nvidia-cudnn-cu12==9.1.0.70 51 | nvidia-cufft-cu11==10.9.0.58 52 | nvidia-cufft-cu12==11.2.1.3 53 | nvidia-curand-cu11==10.2.10.91 54 | nvidia-curand-cu12==10.3.5.147 55 | nvidia-cusolver-cu11==11.4.0.1 56 | nvidia-cusolver-cu12==11.6.1.9 57 | nvidia-cusparse-cu11==11.7.4.91 58 | nvidia-cusparse-cu12==12.3.1.170 59 | nvidia-ml-py==12.535.161 60 | nvidia-nccl-cu11==2.14.3 61 | nvidia-nccl-cu12==2.21.5 62 | nvidia-nvjitlink-cu12==12.4.127 63 | nvidia-nvtx-cu11==11.7.91 64 | nvidia-nvtx-cu12==12.4.127 65 | nvitop==1.3.2 66 | oauthlib==3.2.2 67 | omegaconf==2.3.0 68 | packaging==24.2 69 | pathspec==0.12.1 70 | Pillow==9.4.0 71 | pip==24.3.1 72 | platformdirs==4.3.6 73 | propcache==0.2.1 74 | protobuf==3.20.3 75 | psutil==6.1.0 76 | py-cpuinfo==9.0.0 77 | pyasn1==0.6.1 78 | pyasn1_modules==0.4.1 79 | pydantic==1.10.19 80 | PyPDF2==2.11.1 81 | pytorch-lightning==2.0.4 82 | PyYAML==6.0.2 83 | rapidfuzz==2.15.2 84 | regex==2024.11.6 85 | requests==2.32.3 86 | requests-oauthlib==2.0.0 87 | rsa==4.9 88 | safetensors==0.4.5 89 | setuptools==75.6.0 90 | six==1.16.0 91 | sympy==1.13.1 92 | tensorboard==2.11.2 93 | tensorboard-data-server==0.6.1 94 | tensorboard-plugin-wit==1.8.1 95 | termcolor==2.5.0 96 | timm==0.6.13 97 | tokenizers==0.13.3 98 | tomli==2.2.1 99 | torch==2.0.1 100 | torchmetrics==1.6.0 101 | torchvision==0.15.2 102 | tqdm==4.67.1 103 | transformers==4.31.0 104 | triton==2.0.0 105 | typing_extensions==4.12.2 106 | urllib3==2.2.3 107 | Werkzeug==3.1.3 108 | wheel==0.45.1 109 | xformers==0.0.22 110 | yarl==1.18.3 111 | zipp==3.21.0 112 | -------------------------------------------------------------------------------- /scripts/preprocess_data/preprocess_pubtabnet.sh: -------------------------------------------------------------------------------- 1 | data_config_path="dataset/data_preprocessing_config.json" 2 | 
output_dir="TFLOP-dataset/meta_data" 3 | 4 | python -m dataset.preprocess_data --data_config_path $data_config_path \ 5 | --output_dir $output_dir 6 | -------------------------------------------------------------------------------- /scripts/testing/test_pubtabnet.sh: -------------------------------------------------------------------------------- 1 | bin_idx=$1 2 | total_bin_cnt=$2 3 | exp_savepath=$3 4 | epoch_step_checkpoint=$4 5 | 6 | pubtabnet_aux_json_path="TFLOP-dataset/meta_data/final_eval_v2.json" 7 | pubtabnet_aux_img_path="TFLOP-dataset/images/test" 8 | pubtabnet_aux_rec_pkl_path="TFLOP-dataset/pse_results/test/end2end_results.pkl" 9 | 10 | tokenizer_name_or_path="${exp_savepath}/${epoch_step_checkpoint}" 11 | model_name_or_path="${tokenizer_name_or_path}" 12 | exp_config_path="${exp_savepath}/config.yaml" 13 | model_config_path="${exp_savepath}/${epoch_step_checkpoint}/config.json" 14 | 15 | echo "Test..." 16 | python test.py --tokenizer_name_or_path ${tokenizer_name_or_path} \ 17 | --model_name_or_path ${model_name_or_path} \ 18 | --exp_config_path ${exp_config_path} \ 19 | --model_config_path ${model_config_path} \ 20 | --aux_json_path $pubtabnet_aux_json_path \ 21 | --aux_img_path $pubtabnet_aux_img_path \ 22 | --aux_rec_pkl_path $pubtabnet_aux_rec_pkl_path \ 23 | --batch_size 12 \ 24 | --save_dir "${exp_savepath}/${epoch_step_checkpoint}" \ 25 | --current_bin $bin_idx \ 26 | --num_bins $total_bin_cnt 27 | -------------------------------------------------------------------------------- /scripts/training/train_pubtabnet.sh: -------------------------------------------------------------------------------- 1 | EXP_NAME="pubtabnet_experiment" 2 | EXP_VERSION="${EXP_NAME}" 3 | RESULT_PATH="results" 4 | TOKENIZER="hyunwoongko/asian-bart-ecjk" 5 | 6 | python train.py --exp_config config/exp_configs/general_exp.yaml \ 7 | --data_config config/exp_configs/data_pubtabnet.yaml \ 8 | exp_version=${EXP_VERSION} \ 9 | exp_name=${EXP_NAME} \ 10 | result_path=${RESULT_PATH} \ 11 | max_length=1376 \ 12 | bbox_token_cnt=864 \ 13 | train_batch_size=16 \ 14 | val_batch_size=12 \ 15 | use_OTSL=True \ 16 | num_workers=8 \ 17 | use_bbox_HiMulConET=True \ 18 | lr=0.00008 \ 19 | max_steps=250000 \ 20 | pretrained_tokenizer_name_or_path=${TOKENIZER} \ 21 | use_imgRoiAlign=True \ 22 | use_RowWise_contLearning=True \ 23 | use_ColWise_contLearning=True \ 24 | span_coeff_mode=proportional \ 25 | empty_cell_ptr_loss_coeff=0.5 \ 26 | non_empty_cell_ptr_loss_coeff=0.5 27 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from functools import partial 3 | import json 4 | from multiprocessing.pool import ThreadPool 5 | import os 6 | 7 | from Levenshtein import distance 8 | from omegaconf import OmegaConf 9 | import torch 10 | from torch.nn.utils.rnn import pad_sequence 11 | from torch.utils.data import DataLoader 12 | import tqdm 13 | from transformers import AutoTokenizer 14 | 15 | from tflop.datamodule.datasets.tflop import TFLOPTestDataset 16 | from tflop.model.model.TFLOP import TFLOP 17 | from tflop.model.model.TFLOP_Config import TFLOPConfig 18 | from tflop.utils import custom_format_html, decode_OTSL_seq, resolve_missing_config 19 | 20 | 21 | def evaluate_model( 22 | model, 23 | tokenizer, 24 | dataloader, 25 | config, 26 | model_dtype, 27 | current_bin=-1, 28 | num_bins=0, 29 | ): 30 | 31 | batch_lower_bound = 0 32 | batch_upper_bound = len(dataloader) 33 | 
if current_bin >= 0: 34 | assert num_bins > 0 35 | dataloader_binsize = len(dataloader) // num_bins 36 | if len(dataloader) % num_bins != 0: 37 | dataloader_binsize += 1 38 | batch_lower_bound = current_bin * dataloader_binsize 39 | batch_upper_bound = min((current_bin + 1) * dataloader_binsize, len(dataloader)) 40 | 41 | result_collection = {} 42 | batch_index = 0 43 | for batch in tqdm.tqdm(dataloader, desc="Evaluating"): 44 | if batch_index < batch_lower_bound or batch_index >= batch_upper_bound: 45 | batch_index += 1 46 | continue 47 | 48 | pointer_args = None 49 | if config.use_ptr_decoder: 50 | # img_tensor, input_ids, coords_int_padded, valid_coord_length, prompt_end_index, html_with_content, cell_text_collated 51 | image_tensors = batch[0] # (bsz, 3, height, width) 52 | decoder_input_ids = batch[1] # (bsz, text_token_length) 53 | coord_input_idx = batch[2] # (bsz, bbox_token_length, 4) 54 | coord_input_length = batch[3] # (bsz,) 55 | prompt_end_idxs = batch[4] # (bsz,) 56 | html_with_content = batch[5] # list of length==bsz 57 | cell_texts = batch[6] # list of length==bsz 58 | file_names = batch[7] # list of length==bsz 59 | 60 | pointer_args = { 61 | "coord_input_idx": coord_input_idx, 62 | "coord_input_length": coord_input_length, 63 | } 64 | else: 65 | raise NotImplementedError 66 | 67 | decoder_prompts = pad_sequence( 68 | [ 69 | input_id[: end_idx + 1] 70 | for input_id, end_idx in zip(decoder_input_ids, prompt_end_idxs) 71 | ], 72 | batch_first=True, 73 | ) 74 | # Cast tensors to same dtype as model 75 | if model_dtype == "float16": 76 | image_tensors = image_tensors.half() 77 | elif model_dtype == "bfloat16": 78 | image_tensors = image_tensors.bfloat16() 79 | else: 80 | raise ValueError(f"Invalid torch dtype: {model_dtype}") 81 | 82 | # Move data to cuda if model is cuda 83 | if torch.cuda.is_available(): 84 | image_tensors = image_tensors.cuda() 85 | decoder_prompts = decoder_prompts.cuda() 86 | 87 | if pointer_args is not None: 88 | pointer_args["coord_input_idx"] = pointer_args["coord_input_idx"].cuda() 89 | pointer_args["coord_input_length"] = pointer_args[ 90 | "coord_input_length" 91 | ].cuda() 92 | 93 | preds = model.inference( 94 | image_tensors=image_tensors, 95 | prompt_tensors=decoder_prompts, 96 | return_json=False, 97 | return_attentions=False, 98 | pointer_args=pointer_args, 99 | ) 100 | # preds content: 101 | # - output_sequences: (bsz, M), where M is max number of tokens generated within the batch (includes BOS and tokens) 102 | # - text_to_dr_coord: (bsz, M - 2, bbox_token_cnt) 103 | 104 | # Get html seq with content 105 | pred_collection = [] 106 | token_pred_collection = [] 107 | raw_collection = [] 108 | for data_i in range(preds["text_to_dr_coord"].shape[0]): 109 | token_id_seq = preds["output_sequences"][data_i] 110 | cell_text_data = cell_texts[data_i].split("") 111 | token_seq = tokenizer.convert_ids_to_tokens(token_id_seq) 112 | token_pred_collection.append(token_seq) 113 | 114 | output_seq_tokens = decode_OTSL_seq( 115 | otsl_token_seq=token_seq, 116 | pointer_tensor=preds["text_to_dr_coord"][data_i], 117 | cell_text_data=cell_text_data, 118 | ) 119 | pred_collection.append(output_seq_tokens) 120 | 121 | # Also collect raw output for sanity check 122 | raw_token_seq = [] 123 | for token_pred in token_seq: 124 | if token_pred == "▁": 125 | token_to_add = " " 126 | else: 127 | token_to_add = token_pred.replace("▁", "") 128 | raw_token_seq.append(token_to_add) 129 | raw_token_seq = "".join(raw_token_seq) 130 | raw_collection.append(raw_token_seq) 131 | 
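# At this point pred_collection holds the predicted tables rebuilt by decode_OTSL_seq,
# which uses the text_to_dr_coord pointer tensor to place each detected cell's text into
# the generated OTSL structure, while raw_collection keeps the plain detokenized
# sequences for sanity checks. The scoring block below pairs every prediction with its
# gold html_with_content, keyed by file name.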
132 | # Third, get scores 133 | current_batch_result = {} 134 | data_index = 0 135 | for pred, answer, token_pred in zip( 136 | pred_collection, html_with_content, token_pred_collection 137 | ): 138 | curr_filename = file_names[data_index] 139 | assert ( 140 | curr_filename not in current_batch_result 141 | ), f"Duplicate filename: {curr_filename}" 142 | 143 | # pred = re.sub(r"(?:(?<=>) | (?== 0: 195 | assert args.num_bins > 0 196 | 197 | # Load saved config 198 | exp_config = OmegaConf.load(args.exp_config_path) 199 | model_config = OmegaConf.load(args.model_config_path) 200 | print("Config file loaded") 201 | 202 | # Load tokenizer 203 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path) 204 | print("Tokenizer loaded") 205 | 206 | print("\nLoading model...") 207 | # Load pretrained model 208 | # 1. Initializing model 209 | model_config_dict = { 210 | k: v for k, v in exp_config.items() if k in TFLOPConfig.get_member_variables() 211 | } 212 | model_config_dict = resolve_missing_config(model_config_dict) 213 | data_ids = ["C-tag"] 214 | model = TFLOP( 215 | config=TFLOPConfig(**model_config_dict), 216 | tokenizer=tokenizer, 217 | data_ids=data_ids, 218 | ) 219 | 220 | # 2. Load pretrained weights 221 | saved_state_dict = torch.load( 222 | os.path.join(args.model_name_or_path, "pytorch_model.bin"), 223 | map_location=torch.device("cpu"), 224 | ) 225 | encoder_state_dict = { 226 | k[len("encoder.") :]: v 227 | for k, v in saved_state_dict.items() 228 | if k.startswith("encoder.") 229 | } 230 | decoder_state_dict = { 231 | k[len("decoder.") :]: v 232 | for k, v in saved_state_dict.items() 233 | if k.startswith("decoder.") 234 | } 235 | if len(saved_state_dict) != (len(encoder_state_dict) + len(decoder_state_dict)): 236 | raise ValueError("Invalid saved state dict") 237 | print("Loading state_dict into model...") 238 | with ThreadPool(2) as p: 239 | p.map( 240 | partial(custom_load_state_dict, model), 241 | [ 242 | {"key": "encoder", "value": encoder_state_dict}, 243 | {"key": "decoder", "value": decoder_state_dict}, 244 | ], 245 | ) 246 | print("Model weights loaded") 247 | 248 | # 3. Set model dtype, device and mode 249 | if model_config.torch_dtype == "float16": 250 | model.half() 251 | elif model_config.torch_dtype == "bfloat16": 252 | model.bfloat16() 253 | else: 254 | raise ValueError(f"Invalid torch dtype: {model_config.torch_dtype}") 255 | 256 | if torch.cuda.is_available(): 257 | model = model.cuda() 258 | model.eval() 259 | print("Model set-up complete") 260 | 261 | # Set-up data module 262 | dataset_split = "validation" if args.use_validation else "test" 263 | dataset = TFLOPTestDataset( 264 | tokenizer=tokenizer, 265 | split=dataset_split, 266 | config=exp_config, 267 | aux_json_path=args.aux_json_path, 268 | aux_img_path=args.aux_img_path, 269 | aux_rec_pkl_path=args.aux_rec_pkl_path, 270 | ) 271 | 272 | dataloader = DataLoader( 273 | dataset, batch_size=args.batch_size, shuffle=False, num_workers=0 274 | ) 275 | 276 | print("Dataset & loader setup complete. Evaluating...") 277 | import time 278 | 279 | t1 = time.time() 280 | result_collection = evaluate_model( 281 | model, 282 | tokenizer, 283 | dataloader, 284 | exp_config, 285 | model_config.torch_dtype, 286 | args.current_bin, 287 | args.num_bins, 288 | ) 289 | torch.cuda.synchronize() 290 | t2 = time.time() 291 | print(f"Evaluation complete. 
Time taken: {t2 - t1:.2f} seconds") 292 | 293 | if args.current_bin >= 0: 294 | save_path = os.path.join( 295 | args.save_dir, 296 | "full_model_inference_%s_%s.json" % (args.current_bin, args.num_bins), 297 | ) 298 | else: 299 | save_path = os.path.join(args.save_dir, "full_model_inference.json") 300 | 301 | with open(save_path, "w", encoding="utf-8") as ff: 302 | json.dump(result_collection, ff, ensure_ascii=False) 303 | -------------------------------------------------------------------------------- /tflop/datamodule/preprocess/common_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import List, Tuple 3 | 4 | import torch 5 | 6 | from tflop.datamodule.preprocess.hi_mul_con_table import ( 7 | get_columnwise_HiMulConET_Coeff, 8 | get_rowwise_HiMulConET_Coeff, 9 | ) 10 | 11 | 12 | def int_convert_and_pad_coords( 13 | coords: List[Tuple], padding_coord: int = -1, max_length: int = None 14 | ): 15 | """Returns int() coords padded with padding_coord to max_length. 16 | 17 | Args: 18 | coords (List[Tuple]): List of coords, where each coord is a tuple of 4 float values. 19 | padding_coord (int, optional): Padding value. Defaults to -1. 20 | max_length (int, optional): Maximum length of the output. Defaults to None. 21 | 22 | Returns: 23 | tensor of shape: (bbox_token_cnt, 4) 24 | e.g. (768, 4) 25 | """ 26 | int_coords = [[int(x) for x in coord] for coord in coords] 27 | if max_length is None: 28 | return torch.tensor(int_coords) 29 | else: 30 | if type(padding_coord) == list: 31 | padd_int_coord = padding_coord 32 | else: 33 | padd_int_coord = [ 34 | padding_coord, 35 | padding_coord, 36 | padding_coord, 37 | padding_coord, 38 | ] 39 | 40 | if len(int_coords) > max_length: 41 | return torch.tensor(int_coords[:max_length]) 42 | else: 43 | for _ in range(max_length - len(int_coords)): 44 | int_coords.append(padd_int_coord) 45 | return torch.tensor(int_coords) 46 | 47 | 48 | def convert_gold_coords(gold_coords: List[str]): 49 | """Converts gold_coords into dict of coords, isFilled, text. 50 | 51 | NOTE: 52 | - isFilled is True only if the cell is filled with data & bbox is valid 53 | - isFilled is changed to False if bbox is not valid (i.e. width or height is 0) 54 | 55 | Returns: 56 | - gold_coord_dict: dict of coords, isFilled, text 57 | """ 58 | if gold_coords is None: 59 | return None 60 | 61 | gold_coord_dict = {"coords": [], "isFilled": [], "text": []} 62 | 63 | def is_bbox_invalid(bbox_coord_value): 64 | """Check if the bounding box is invalid (i.e., width or height is near 0).""" 65 | x1, y1, x2, y2 = bbox_coord_value 66 | if abs(x1 - x2) < 0.00001 or abs(y1 - y2) < 0.00001: 67 | return True 68 | return False 69 | 70 | for coord in gold_coords: 71 | split_content = coord.split(" ") 72 | four_coord_values = [float(x) for x in split_content[:4]] 73 | gold_coord_dict["coords"].append(four_coord_values) 74 | 75 | if is_bbox_invalid(four_coord_values): 76 | gold_coord_dict["isFilled"].append(False) 77 | else: 78 | gold_coord_dict["isFilled"].append(int(split_content[4]) == 2) 79 | 80 | gold_coord_dict["text"].append(" ".join(split_content[5:])) 81 | 82 | return gold_coord_dict 83 | 84 | 85 | def generate_filled_html(gold_text_list, is_cell_filled, org_html_list): 86 | """Insert cell filled text within html tags to form filled html sequence. 
87 | 88 | Args: 89 | gold_text_list (List[str]): List of gold text for each cell 90 | is_cell_filled (List[bool]): List of bool indicating if the cell is filled with data 91 | org_html_list (List[str]): List of html tags 92 | """ 93 | filled_html_list = [] 94 | data_index = 0 95 | 96 | for org_html_tag in org_html_list: 97 | if org_html_tag != "": 98 | filled_html_list.append(org_html_tag) 99 | else: 100 | if len(gold_text_list[data_index]) > 0 and is_cell_filled[data_index]: 101 | filled_html_list.append(gold_text_list[data_index]) 102 | filled_html_list.append(org_html_tag) 103 | data_index += 1 104 | 105 | filled_html = "".join(filled_html_list) 106 | return filled_html 107 | 108 | 109 | def rescale_bbox( 110 | list_of_coords: List[Tuple], 111 | org_img_size: Tuple, 112 | new_img_size: Tuple, 113 | padding_dims: Tuple, 114 | ): 115 | """Rescale bounding box coordinates to new image size with padding dimensions. 116 | 117 | Args: 118 | list_of_coords (List[Tuple]): A list of bounding box coordinates as tuples (x1, y1, x2, y2). 119 | org_img_size (Tuple): The original size of the image as a tuple (width, height). 120 | new_img_size (Tuple): The new size of the image as a tuple (width, height). 121 | padding_dims (Tuple): The padding dimensions applied to the image as a tuple (left, top, right, bottom). 122 | 123 | NOTE: 124 | This function assumes quad-coord format (x1, y1, x2, y2) 125 | """ 126 | rescaled_coords = [] 127 | for coord in list_of_coords: 128 | x1, y1, x2, y2 = coord 129 | 130 | # Get width & height scales 131 | img_only_width = new_img_size[0] - padding_dims[0] - padding_dims[2] 132 | img_only_height = new_img_size[1] - padding_dims[1] - padding_dims[3] 133 | width_scale = img_only_width / org_img_size[0] 134 | height_scale = img_only_height / org_img_size[1] 135 | 136 | # Scale coords 137 | x1 = (x1 * width_scale) + padding_dims[0] 138 | x2 = (x2 * width_scale) + padding_dims[0] 139 | y1 = (y1 * height_scale) + padding_dims[1] 140 | y2 = (y2 * height_scale) + padding_dims[1] 141 | 142 | # Clipping coordinates to image size 143 | x1 = max(0, x1) 144 | y1 = max(0, y1) 145 | x2 = min(new_img_size[0], x2) 146 | y2 = min(new_img_size[1], y2) 147 | 148 | rescaled_coords.append([x1, y1, x2, y2]) 149 | 150 | return rescaled_coords 151 | 152 | 153 | def get_dr_pointer_label( 154 | dr_coords: dict, 155 | gold_coord_isFilled: List[bool], 156 | cell_shuffle_rate: float, 157 | input_ids: torch.IntTensor, 158 | data_ids: List[int], 159 | bbox_token_cnt: int, 160 | org_img_size: Tuple, 161 | new_img_size: Tuple, 162 | padding_dims: Tuple, 163 | coeff_tensor_args: dict, 164 | ): 165 | """ 166 | This function generates the pointer label when using detection coord information (i.e. dr_coords). 167 | 168 | Args: 169 | dr_coords dict: dict containing detection coord information 170 | gold_coord_isFilled list(bool): list of bool indicating if is filled with data 171 | cell_shuffle_rate float: rate of shuffling cells (for training mode only) 172 | input_ids torch.IntTensor: input_ids of shape (seq_length) 173 | data_ids list(int): list of token_ids which correspond to data tokens (e.g. for html) 174 | bbox_token_cnt int: number of tokens in sequence dedicated for bbox coord representation (e.g. 
768) 175 | org_img_size Tuple: original image size as a tuple (width, height) 176 | new_img_size Tuple: new image size as a tuple (width, height) 177 | padding_dims Tuple: padding dimensions as a tuple (left, top, right, bottom) 178 | coeff_tensor_args dict: arguments for coefficient tensor generation 179 | """ 180 | # shared variable set-up 181 | seq_length = input_ids.shape[0] 182 | 183 | # 1. First get bool tensor indicating which of the input_ids correspond to data_ids 184 | is_data_tensor = torch.zeros( 185 | (seq_length - bbox_token_cnt - 1), dtype=torch.bool 186 | ) # -1 for bos token (causal supervision) 187 | for data_id in data_ids: 188 | is_data_tensor = torch.logical_or( 189 | is_data_tensor, input_ids[1:-bbox_token_cnt] == data_id 190 | ) 191 | 192 | # 2. Processing detection coord information 193 | # 2.1 Set-up filled-text only coords, text and indices 194 | filtered_dr_coords = [] 195 | filtered_texts = [] 196 | filtered_org_index_list = [] 197 | 198 | dr_coord_keys = [int(x) for x in list(dr_coords.keys())] 199 | for tmp_idx in sorted(dr_coord_keys): 200 | rescaled_coords = rescale_bbox( 201 | list_of_coords=dr_coords[str(tmp_idx)][0], 202 | org_img_size=org_img_size, 203 | new_img_size=new_img_size, 204 | padding_dims=padding_dims, 205 | ) 206 | filtered_dr_coords.extend(rescaled_coords) 207 | filtered_texts.append(dr_coords[str(tmp_idx)][2]) 208 | filtered_texts += [""] * (len(rescaled_coords) - 1) 209 | filtered_org_index_list.extend( 210 | [int(dr_coords[str(tmp_idx)][1])] * len(rescaled_coords) 211 | ) 212 | 213 | # 2.2 Serialize filtered_dr_coords from top-left to bottom-right 214 | serialized_dr_cell_coords, serialized_dr_cell_texts, serialized_org_index_list = ( 215 | serialize_bbox_top_left_bottom_right(filtered_dr_coords, filtered_texts) 216 | ) 217 | filtered_org_index_list = [ 218 | filtered_org_index_list[i] for i in serialized_org_index_list 219 | ] 220 | 221 | # 2.3 In event that no. of coord entries exceed bbox_token_cnt, slice it. NOTE -1 as first bbox is used for empty cell pointing 222 | if len(serialized_dr_cell_coords) > (bbox_token_cnt - 1): 223 | filtered_dr_coords = serialized_dr_cell_coords[: bbox_token_cnt - 1] 224 | filtered_texts = serialized_dr_cell_texts[: bbox_token_cnt - 1] 225 | filtered_org_index_list = filtered_org_index_list[: bbox_token_cnt - 1] 226 | else: 227 | filtered_dr_coords = serialized_dr_cell_coords 228 | filtered_texts = serialized_dr_cell_texts 229 | 230 | # 2.4 Shuffle filtered_dr_coords at cell_shuffle_rate 231 | if random.random() < cell_shuffle_rate: 232 | zipped_tmp = list( 233 | zip(filtered_org_index_list, filtered_dr_coords, filtered_texts) 234 | ) 235 | random.shuffle(zipped_tmp) 236 | new_org_index_list, new_filtered_dr_coords, new_filtered_texts = zip( 237 | *zipped_tmp 238 | ) 239 | else: 240 | new_org_index_list = filtered_org_index_list 241 | new_filtered_dr_coords = filtered_dr_coords 242 | new_filtered_texts = filtered_texts 243 | 244 | # 3. Get pointer label 245 | pointer_label = [] 246 | dataIndex2bboxIndex = {} 247 | data_cell_index = 0 248 | for i in range(is_data_tensor.shape[0]): 249 | tmp_label = torch.zeros((bbox_token_cnt), dtype=input_ids.dtype) 250 | 251 | if is_data_tensor[i]: # this is a data token, i.e. 
either '' or ') 310 | pointer_label = pointer_label[ 311 | 1: 312 | ] # (seq_length - bbox_token_cnt - 2, bbox_token_cnt) 313 | is_data_tensor = is_data_tensor[1:] # (seq_length - bbox_token_cnt - 2) 314 | 315 | return ( 316 | pointer_label, 317 | is_data_tensor, 318 | new_filtered_dr_coords, 319 | new_filtered_texts, 320 | coeff_tensor, 321 | ) 322 | 323 | 324 | def get_cell_pointer_label( 325 | gold_cell_coords: List, 326 | gold_coord_isFilled: List[bool], 327 | gold_text: List[str], 328 | cell_shuffle_rate: float, 329 | input_ids: torch.IntTensor, 330 | data_ids: List[int], 331 | bbox_token_cnt: int, 332 | coeff_tensor_args: dict, 333 | ): 334 | """ 335 | This function generates the pointer label when using cell-level information (i.e. gold cell coords). 336 | 337 | Args: 338 | gold_cell_coords list(list(int)): list of gold cell coords, each cell coord is a list of 4 int (x1, y1, x2, y2) 339 | gold_coord_isFilled list(bool): list of bool indicating if is filled with data 340 | gold_text list(str): list of str containing text in 341 | cell_shuffle_rate float: rate of shuffling cells (for training mode only) 342 | input_ids torch.IntTensor: input_ids of shape (seq_length) 343 | data_ids list(int): list of token_ids which correspond to data tokens (e.g. for html) 344 | bbox_token_cnt int: number of tokens in sequence dedicated for bbox coord representation (e.g. 768) 345 | 346 | """ 347 | # shared variable set-up 348 | seq_length = input_ids.shape[0] 349 | 350 | # 1. First get bool tensor indicating which of the input_ids correspond to data_ids 351 | is_data_tensor = torch.zeros( 352 | (seq_length - bbox_token_cnt - 1), dtype=torch.bool 353 | ) # -1 for bos token (causal supervision) 354 | for data_id in data_ids: 355 | is_data_tensor = torch.logical_or( 356 | is_data_tensor, input_ids[1:-bbox_token_cnt] == data_id 357 | ) 358 | 359 | # 2. Processing gold cell information 360 | # 2.1 First filter out all gold_cell_coords which are not filled -- to make subsequent steps easier 361 | filtered_gold_cell_coords = [ 362 | x for i, x in enumerate(gold_cell_coords) if gold_coord_isFilled[i] 363 | ] 364 | filtered_gold_cell_texts = [ 365 | x for i, x in enumerate(gold_text) if gold_coord_isFilled[i] 366 | ] 367 | 368 | filtered_org_index_list = [ 369 | i for i, x in enumerate(gold_coord_isFilled) if x 370 | ] # recently added 371 | 372 | # 2.2 Serialize gold_cell_coords from top-left to bottom-right 373 | filtered_gold_cell_coords, filtered_gold_cell_texts, serialized_org_index_list = ( 374 | serialize_bbox_top_left_bottom_right( 375 | filtered_gold_cell_coords, filtered_gold_cell_texts 376 | ) 377 | ) 378 | filtered_org_index_list = [ 379 | filtered_org_index_list[i] for i in serialized_org_index_list 380 | ] 381 | 382 | # 2.3 In event that no. of gold cells exceed bbox_token_cnt, slice it. 
NOTE -1 as first bbox is used for empty cell pointing 383 | if len(filtered_gold_cell_coords) > (bbox_token_cnt - 1): 384 | filtered_gold_cell_coords = filtered_gold_cell_coords[: bbox_token_cnt - 1] 385 | filtered_gold_cell_texts = filtered_gold_cell_texts[: bbox_token_cnt - 1] 386 | filtered_org_index_list = filtered_org_index_list[: bbox_token_cnt - 1] 387 | 388 | # 2.4 Shuffle gold_cell_coords at cell_shuffle_rate 389 | if random.random() < cell_shuffle_rate: 390 | zipped_tmp = list( 391 | zip( 392 | filtered_org_index_list, 393 | filtered_gold_cell_coords, 394 | filtered_gold_cell_texts, 395 | ) 396 | ) 397 | random.shuffle(zipped_tmp) 398 | new_gold_cell_indices, new_gold_cell_coords, new_gold_cell_texts = zip( 399 | *zipped_tmp 400 | ) 401 | else: 402 | new_gold_cell_indices = filtered_org_index_list 403 | new_gold_cell_coords = filtered_gold_cell_coords 404 | new_gold_cell_texts = filtered_gold_cell_texts 405 | 406 | # 3. Get pointer label 407 | pointer_label = [] 408 | dataIndex2bboxIndex = {} 409 | data_cell_index = 0 410 | for i in range(is_data_tensor.shape[0]): # iter over (seq_len - bbox_token_cnt - 1) 411 | tmp_label = torch.zeros((bbox_token_cnt), dtype=input_ids.dtype) # seq of 0 412 | 413 | if is_data_tensor[ 414 | i 415 | ]: # this is a data token, i.e. either '' or ') 470 | pointer_label = pointer_label[ 471 | 1: 472 | ] # (seq_length - bbox_token_cnt - 2, bbox_token_cnt) 473 | is_data_tensor = is_data_tensor[1:] # (seq_length - bbox_token_cnt - 2) 474 | 475 | return ( 476 | pointer_label, 477 | is_data_tensor, 478 | new_gold_cell_coords, 479 | new_gold_cell_texts, 480 | coeff_tensor, 481 | ) 482 | 483 | 484 | def serialize_bbox_top_left_bottom_right(coord_list, cell_text_list): 485 | """Serialize bounding boxes from top-left to bottom-right. 486 | 487 | Args: 488 | coord_list list(list(int)): list of bbox coords, each bbox coord is a list of 4 int (x1, y1, x2, y2) 489 | cell_text_list list(str): list of str containing text in bbox 490 | 491 | NOTE: 492 | Both coord_list and cell_text_list are pertaining to FILLED cells ONLY. 493 | 494 | Returns: 495 | serialized_coord_list list(list(int)): list of bbox coords, each bbox coord is a list of 4 int (x1, y1, x2, y2) 496 | serialized_cell_text_list list(str): list of str containing text in bbox 497 | serialized_org_index_list list(int): list of index of bbox in original coord_list 498 | """ 499 | # 1. Get minimum height across bboxes 500 | min_height = min( 501 | [abs(x[1] - x[3]) for x in coord_list] 502 | ) # TODO Change to median height to see if it works better 503 | org_coord_indices = list(range(len(coord_list))) 504 | 505 | # 2. Group up bboxes by row 506 | row_groups = {} 507 | for coord, text, org_index in zip(coord_list, cell_text_list, org_coord_indices): 508 | top_height = coord[1] 509 | row_index = int(top_height // min_height) 510 | if row_index not in row_groups: 511 | row_groups[row_index] = [] 512 | row_groups[row_index].append((coord, text, org_index)) 513 | 514 | # 3. Sort each row by x1 515 | for row_index in row_groups: 516 | row_groups[row_index] = sorted(row_groups[row_index], key=lambda x: x[0][0]) 517 | 518 | # 4. 
Output serialized coord_list and cell_text_list 519 | serialized_coord_list = [] 520 | serialized_cell_text_list = [] 521 | serialized_org_index_list = [] 522 | for row_index in sorted(row_groups.keys()): 523 | for coord, text, org_index in row_groups[row_index]: 524 | serialized_coord_list.append(coord) 525 | serialized_cell_text_list.append(text) 526 | serialized_org_index_list.append(org_index) 527 | 528 | return ( 529 | serialized_coord_list, 530 | serialized_cell_text_list, 531 | serialized_org_index_list, 532 | ) 533 | 534 | 535 | def get_bbox_iou(bbox1, bbox2): 536 | """Calculate Intersection over Union (IoU) between two bounding boxes. 537 | 538 | Args: 539 | bbox1 (Tuple): Tuple of 4 float values (x1, y1, x2, y2). 540 | bbox2 (Tuple): Tuple of 4 float values (x1, y1, x2, y2). 541 | """ 542 | x1, y1, x2, y2 = bbox1 543 | x3, y3, x4, y4 = bbox2 544 | 545 | x5 = max(x1, x3) 546 | y5 = max(y1, y3) 547 | x6 = min(x2, x4) 548 | y6 = min(y2, y4) 549 | 550 | if x5 >= x6 or y5 >= y6: 551 | return 0 552 | 553 | intersection_area = (x6 - x5) * (y6 - y5) 554 | bbox1_area = (x2 - x1) * (y2 - y1) 555 | bbox2_area = (x4 - x3) * (y4 - y3) 556 | 557 | return intersection_area / (bbox1_area + bbox2_area - intersection_area + 1e-6) 558 | -------------------------------------------------------------------------------- /tflop/datamodule/preprocess/hi_mul_con_table.py: -------------------------------------------------------------------------------- 1 | # Script for implementing Hierarchical Multi-Label Contrastive Learning in Tables (HiMulConET) 2 | import torch 3 | 4 | 5 | def get_rowwise_HiMulConET_Coeff( 6 | sliced_input_ids, 7 | tokenizer, 8 | bbox_token_cnt, 9 | is_data_tensor, 10 | tag2coord_map, 11 | rep_mode="OTSL", 12 | rowspan_coeff_mode="constant", 13 | ): 14 | """ 15 | Function to get rowwise HiMulConET Coefficient Matrix 16 | 17 | Args: 18 | sliced_input_ids (torch.Tensor): [seq_len - bbox_token_cnt - 1], 19 | tokenizer (transformers.tokenization_bert.BertTokenizer): tokenizer 20 | bbox_token_cnt (int): number of bbox tokens used in the model 21 | tag2coord_map (dict): mapping of sliced_input_ids index to bbox token index 22 | rep_mode (str): representation mode, only "OTSL" is supported 23 | rowspan_coeff_mode (str): coefficient mode for rowspan, choice of ["constant", "proportional"] 24 | """ 25 | # sanity check 26 | assert rep_mode in [ 27 | "OTSL", 28 | ], "Non-OTSL modes are deprecated" 29 | 30 | rowwise_HiMulConET_Coeff = torch.zeros((bbox_token_cnt)) 31 | rowspan_HiMulConET_Coeff = torch.zeros((bbox_token_cnt, bbox_token_cnt)) 32 | table_breakdown = breakdown_otsl_seq(sliced_input_ids.tolist(), tokenizer) 33 | 34 | table_data_idx = (is_data_tensor).nonzero(as_tuple=True)[0] 35 | # First, associate each cell in table_breakdown with bbox_indices 36 | table_breakdown_row_id, table_breakdown_col_id = 0, 0 37 | for data_token_id in table_data_idx: 38 | while table_breakdown[table_breakdown_row_id][table_breakdown_col_id] is None: 39 | table_breakdown_col_id += 1 40 | if table_breakdown_col_id >= len(table_breakdown[table_breakdown_row_id]): 41 | table_breakdown_row_id += 1 42 | table_breakdown_col_id = 0 43 | 44 | if data_token_id.item() not in tag2coord_map: 45 | table_breakdown_col_id += 1 46 | if table_breakdown_col_id >= len(table_breakdown[table_breakdown_row_id]): 47 | table_breakdown_row_id += 1 48 | table_breakdown_col_id = 0 49 | continue 50 | 51 | bbox_indices = [ 52 | x + 1 for x in tag2coord_map[data_token_id.item()] 53 | ] # add 1 as the first bbox token is for empty cells 
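        # As a hypothetical illustration of the grid built by breakdown_otsl_seq (defined
        # further below in this file): the OTSL sequence C-tag C-tag NL-tag C-tag L-tag NL-tag,
        # i.e. a 2x2 layout whose second row is a single cell spanning both columns, breaks down to
        #     [[[1, 1], [1, 1]],
        #      [[1, 2], None  ]]
        # The append below then attaches the gathered bbox indices to the originating cell,
        # turning its [num_rows, num_cols] entry into [num_rows, num_cols, bbox_indices].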
54 | table_breakdown[table_breakdown_row_id][table_breakdown_col_id].append( 55 | bbox_indices 56 | ) 57 | 58 | table_breakdown_col_id += 1 59 | if table_breakdown_col_id >= len(table_breakdown[table_breakdown_row_id]): 60 | table_breakdown_row_id += 1 61 | table_breakdown_col_id = 0 62 | 63 | column_count = len(table_breakdown[0]) 64 | if len(table_breakdown[-1]) < column_count: 65 | table_breakdown[-1].extend( 66 | [None for _ in range(column_count - len(table_breakdown[-1]))] 67 | ) 68 | 69 | # Second, use table_breakdown to fill in rowwise_HiMulConET_Coeff 70 | for table_row_id, table_row in enumerate(table_breakdown): 71 | for table_col_id, table_cell in enumerate(table_row): 72 | if table_cell is None or len(table_cell) == 2: 73 | continue 74 | 75 | if table_cell[0] > 1: # this is rowspan 76 | rowspan_cnt = table_cell[0] 77 | span_cell_bbox_indices = table_cell[2] 78 | rowspan_HiMulConET_Coeff[ 79 | torch.tensor(span_cell_bbox_indices).unsqueeze(1), 80 | span_cell_bbox_indices, 81 | ] = 1 82 | 83 | for rowspan_id in range(rowspan_cnt): 84 | tmp_table_row = table_breakdown[table_row_id + rowspan_id] 85 | for tmp_col_id, tmp_cell in enumerate(tmp_table_row): 86 | if tmp_cell is None or len(tmp_cell) == 2: 87 | continue 88 | if tmp_col_id == table_col_id: 89 | continue 90 | # cells in a row within the rowspan 91 | num_rows_remaining = rowspan_cnt - rowspan_id 92 | row_overlap_cnt = min(num_rows_remaining, tmp_cell[0]) 93 | src2tar_coeff, tar2src_coeff = eval_rowspan_coeff( 94 | rowspan_id, 95 | src_row_span_cnt=rowspan_cnt, 96 | tar_row_span_cnt=tmp_cell[0], 97 | row_overlap_cnt=row_overlap_cnt, 98 | mode=rowspan_coeff_mode, 99 | ) 100 | 101 | rowspan_HiMulConET_Coeff[ 102 | torch.tensor(span_cell_bbox_indices).unsqueeze(1), 103 | tmp_cell[2], 104 | ] = src2tar_coeff 105 | rowspan_HiMulConET_Coeff[ 106 | torch.tensor(tmp_cell[2]).unsqueeze(1), 107 | span_cell_bbox_indices, 108 | ] = tar2src_coeff 109 | else: 110 | rowwise_HiMulConET_Coeff[table_cell[2]] = table_row_id + 1 111 | 112 | # Third, get coeff matrix 113 | rowwise_HiMulConET_Coeff = rowwise_HiMulConET_Coeff.unsqueeze( 114 | -1 115 | ) # [bbox_token_cnt, 1] 116 | match_index = torch.eq( 117 | rowwise_HiMulConET_Coeff, rowwise_HiMulConET_Coeff.T 118 | ) # [bbox_token_cnt, bbox_token_cnt] 119 | match_index[rowwise_HiMulConET_Coeff.squeeze(-1) == 0, :] = ( 120 | 0 # empty token should not be matched with any other tokens 121 | ) 122 | match_index = match_index.type(torch.float) 123 | 124 | # Lastly, combine match_index and rowspan_HiMulConET_Coeff 125 | match_index = match_index + rowspan_HiMulConET_Coeff 126 | 127 | return match_index 128 | 129 | 130 | def get_columnwise_HiMulConET_Coeff( 131 | sliced_input_ids, 132 | tokenizer, 133 | bbox_token_cnt, 134 | is_data_tensor, 135 | tag2coord_map, 136 | rep_mode="OTSL", 137 | colspan_coeff_mode="constant", 138 | ): 139 | """ 140 | Function to get columnwise HiMulConET Coefficient Matrix 141 | 142 | Args: 143 | sliced_input_ids (torch.Tensor): [seq_len - bbox_token_cnt - 1], 144 | tokenizer (transformers.tokenization_bert.BertTokenizer): tokenizer 145 | bbox_token_cnt (int): number of bbox tokens used in the model 146 | tag2coord_map (dict): mapping of sliced_input_ids index to bbox token index 147 | rep_mode (str): representation mode, only "OTSL" is supported 148 | colspan_coeff_mode (str): coefficient mode for colspan, choice of ["constant", "proportional"] 149 | 150 | """ 151 | 152 | # sanity check 153 | assert rep_mode in [ 154 | "OTSL", 155 | ], "Non-OTSL modes are deprecated" 156 
| 157 | columnwise_HiMulConET_Coeff = torch.zeros((bbox_token_cnt)) 158 | colspan_HiMulConET_Coeff = torch.zeros((bbox_token_cnt, bbox_token_cnt)) 159 | table_breakdown = breakdown_otsl_seq(sliced_input_ids.tolist(), tokenizer) 160 | 161 | table_data_idx = (is_data_tensor).nonzero(as_tuple=True)[0] 162 | # First, associate each cell in table_breakdown with bbox_indices 163 | table_breakdown_row_id, table_breakdown_col_id = 0, 0 164 | for data_token_id in table_data_idx: 165 | while table_breakdown[table_breakdown_row_id][table_breakdown_col_id] is None: 166 | table_breakdown_col_id += 1 167 | if table_breakdown_col_id >= len(table_breakdown[table_breakdown_row_id]): 168 | table_breakdown_row_id += 1 169 | table_breakdown_col_id = 0 170 | 171 | if data_token_id.item() not in tag2coord_map: 172 | table_breakdown_col_id += 1 173 | if table_breakdown_col_id >= len(table_breakdown[table_breakdown_row_id]): 174 | table_breakdown_row_id += 1 175 | table_breakdown_col_id = 0 176 | continue 177 | 178 | bbox_indices = [x + 1 for x in tag2coord_map[data_token_id.item()]] 179 | table_breakdown[table_breakdown_row_id][table_breakdown_col_id].append( 180 | bbox_indices 181 | ) 182 | 183 | table_breakdown_col_id += 1 184 | if table_breakdown_col_id >= len(table_breakdown[table_breakdown_row_id]): 185 | table_breakdown_row_id += 1 186 | table_breakdown_col_id = 0 187 | 188 | # Second, transpose table_breakdown to get columnwise information 189 | column_count = len(table_breakdown[0]) 190 | if len(table_breakdown[-1]) < column_count: 191 | table_breakdown[-1].extend( 192 | [None for _ in range(column_count - len(table_breakdown[-1]))] 193 | ) 194 | table_breakdown = list(map(list, zip(*table_breakdown))) 195 | 196 | # Third, use table_breakdown to fill in columnwise_HiMulConET_Coeff 197 | for table_col_id, table_col in enumerate(table_breakdown): 198 | for table_row_id, table_cell in enumerate(table_col): 199 | if table_cell is None or len(table_cell) == 2: 200 | continue 201 | 202 | if table_cell[1] > 1: # this is colspan 203 | colspan_cnt = table_cell[1] 204 | span_cell_bbox_indices = table_cell[2] 205 | colspan_HiMulConET_Coeff[ 206 | torch.tensor(span_cell_bbox_indices).unsqueeze(1), 207 | span_cell_bbox_indices, 208 | ] = 1 209 | 210 | for colspan_id in range(colspan_cnt): 211 | tmp_table_col = table_breakdown[table_col_id + colspan_id] 212 | for tmp_row_id, tmp_cell in enumerate(tmp_table_col): 213 | if tmp_cell is None or len(tmp_cell) == 2: 214 | continue 215 | if tmp_row_id == table_row_id: 216 | continue 217 | # cells in a column within the colspan 218 | num_cols_remaining = colspan_cnt - colspan_id 219 | col_overlap_cnt = min(num_cols_remaining, tmp_cell[1]) 220 | src2tar_coeff, tar2src_coeff = eval_colspan_coeff( 221 | colspan_id, 222 | src_col_span_cnt=colspan_cnt, 223 | tar_col_span_cnt=tmp_cell[1], 224 | col_overlap_cnt=col_overlap_cnt, 225 | mode=colspan_coeff_mode, 226 | ) 227 | colspan_HiMulConET_Coeff[ 228 | torch.tensor(span_cell_bbox_indices).unsqueeze(1), 229 | tmp_cell[2], 230 | ] = src2tar_coeff 231 | colspan_HiMulConET_Coeff[ 232 | torch.tensor(tmp_cell[2]).unsqueeze(1), 233 | span_cell_bbox_indices, 234 | ] = tar2src_coeff 235 | else: 236 | columnwise_HiMulConET_Coeff[table_cell[2]] = table_col_id + 1 237 | 238 | # Fourth, get coeff matrix 239 | columnwise_HiMulConET_Coeff = columnwise_HiMulConET_Coeff.unsqueeze(-1) 240 | match_index = torch.eq(columnwise_HiMulConET_Coeff, columnwise_HiMulConET_Coeff.T) 241 | match_index[columnwise_HiMulConET_Coeff.squeeze(-1) == 0, :] = 0 242 | 
match_index = match_index.type(torch.float) 243 | 244 | # Lastly, combine match_index and colspan_HiMulConET_Coeff 245 | match_index = match_index + colspan_HiMulConET_Coeff 246 | 247 | return match_index 248 | 249 | 250 | # ----------------- Auxiliary Functions -----------------# 251 | def breakdown_otsl_seq(tok_input_ids, tokenizer): 252 | """Convert OTSL sequence into a list of lists, where each list is a row of the table 253 | 254 | NOTE: 255 | # of list matches number of rows, and # of elements in each list matches number of columns 256 | Each element is either None or [num_rows, num_cols] 257 | 258 | Args: 259 | tok_input_ids (list): list of token ids 260 | tokenizer (transformers.tokenization_bert.BertTokenizer): tokenizer 261 | 262 | Returns: 263 | otsl_full_compilation (list): list of lists 264 | - each list is a row of the table 265 | - each element is either None or [num_rows, num_cols] 266 | """ 267 | otsl_token_list = tokenizer.convert_ids_to_tokens(tok_input_ids) 268 | otsl_full_compilation, otsl_row_compilation = [], [] 269 | curr_column_index = 0 270 | for tok_ind, tok in enumerate(otsl_token_list): 271 | if tok == "C-tag": 272 | otsl_row_compilation.append([1, 1]) # [num_rows, num_cols] 273 | curr_column_index += 1 274 | elif tok == "NL-tag": 275 | otsl_full_compilation.append(otsl_row_compilation) 276 | otsl_row_compilation = [] 277 | curr_column_index = 0 278 | elif tok == "L-tag": 279 | for col_i in range(len(otsl_row_compilation)): 280 | # traverse backwards 281 | if otsl_row_compilation[-1 - col_i] is not None: 282 | otsl_row_compilation[-1 - col_i][1] += 1 283 | break 284 | otsl_row_compilation.append(None) 285 | curr_column_index += 1 286 | elif tok == "U-tag": 287 | for row_i in range(len(otsl_full_compilation)): 288 | # traverse backwards 289 | if otsl_full_compilation[-1 - row_i][curr_column_index] is not None: 290 | otsl_full_compilation[-1 - row_i][curr_column_index][0] += 1 291 | break 292 | otsl_row_compilation.append(None) 293 | curr_column_index += 1 294 | elif tok == "X-tag": 295 | otsl_row_compilation.append(None) 296 | curr_column_index += 1 297 | else: 298 | continue 299 | 300 | if len(otsl_row_compilation) > 0: 301 | otsl_full_compilation.append(otsl_row_compilation) 302 | 303 | return otsl_full_compilation 304 | 305 | 306 | def eval_rowspan_coeff( 307 | curr_row_id, src_row_span_cnt, tar_row_span_cnt, row_overlap_cnt, mode="constant" 308 | ): 309 | # sanity check 310 | assert mode in ["constant", "proportional"] 311 | 312 | if mode == "constant": 313 | return 1, 1 314 | elif mode == "proportional": 315 | src2tar_coeff = row_overlap_cnt / src_row_span_cnt 316 | tar2src_coeff = row_overlap_cnt / tar_row_span_cnt 317 | coeff_val = src2tar_coeff * tar2src_coeff 318 | return coeff_val, coeff_val 319 | 320 | else: 321 | raise NotImplementedError 322 | 323 | 324 | def eval_colspan_coeff( 325 | curr_col_id, src_col_span_cnt, tar_col_span_cnt, col_overlap_cnt, mode="constant" 326 | ): 327 | # sanity check 328 | assert mode in ["constant", "proportional"] 329 | 330 | if mode == "constant": 331 | return 1, 1 332 | elif mode == "proportional": 333 | src2tar_coeff = col_overlap_cnt / src_col_span_cnt 334 | tar2src_coeff = col_overlap_cnt / tar_col_span_cnt 335 | coeff_val = src2tar_coeff * tar2src_coeff 336 | return coeff_val, coeff_val 337 | 338 | else: 339 | raise NotImplementedError 340 | -------------------------------------------------------------------------------- /tflop/datamodule/preprocess/image_utils.py: 
-------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import PIL 4 | from PIL import ImageOps 5 | import numpy as np 6 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 7 | import torch 8 | from torchvision import transforms 9 | from torchvision.transforms.functional import resize 10 | 11 | 12 | def convert_PIL_to_tensor( 13 | input_image: PIL.Image.Image, normalize_channels: bool = True 14 | ): 15 | """ 16 | Converts a PIL Image to a PyTorch tensor and optionally normalizes the channels. 17 | 18 | Args: 19 | input_image (PIL.Image.Image): The input image in PIL format. 20 | normalize_channels (bool, optional): If True, normalizes the channels using the ImageNet 21 | default mean and standard deviation. Defaults to True. 22 | 23 | Returns: 24 | torch.Tensor: The converted image as a PyTorch tensor. 25 | """ 26 | tensor_img = transforms.ToTensor()(input_image) 27 | if normalize_channels: 28 | mean_tensor = torch.tensor(IMAGENET_DEFAULT_MEAN).view(3, 1, 1) 29 | std_tensor = torch.tensor(IMAGENET_DEFAULT_STD).view(3, 1, 1) 30 | tensor_img = (tensor_img - mean_tensor) / std_tensor 31 | 32 | return tensor_img 33 | 34 | 35 | def prepare_image_tensor( 36 | input_image: PIL.Image.Image, 37 | target_img_size: tuple, 38 | random_padding: bool = False, 39 | normalize_channels: bool = True, 40 | ): 41 | """ 42 | Prepares an image tensor by resizing, padding, and converting a PIL Image to a PyTorch tensor. 43 | 44 | Args: 45 | input_image (PIL.Image.Image): The input image in PIL format. 46 | target_img_size (tuple): The target size of the image as a tuple (width, height). 47 | random_padding (bool, optional): If True, applies random padding to the image. If False, 48 | centers the image with padding. Defaults to False. 49 | normalize_channels (bool, optional): If True, normalizes the channels using the ImageNet 50 | default mean and standard deviation. Defaults to True. 51 | 52 | Returns: 53 | torch.Tensor: The prepared image as a PyTorch tensor. 54 | tuple: The original size of the input image as a tuple (width, height). 55 | tuple: The padding dimensions applied to the image as a tuple (left, top, right, bottom). 
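    Example (illustrative):
        For a hypothetical 200x100 RGB input with target_img_size=(768, 768) and
        random_padding=False, the image is resized to 768x384 (aspect ratio kept),
        centred so that padding_dims == (0, 192, 0, 192), and returned as a tensor
        of shape (3, 768, 768) together with the original size (200, 100).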
56 | """ 57 | original_size = input_image.size 58 | 59 | # Resize image 60 | target_width, target_height = target_img_size 61 | try: 62 | resized_img = resize( 63 | input_image.convert("RGB"), min(target_img_size) 64 | ) # Resized with smaller edge = min(target_img_size) 65 | except: 66 | print("Error resizing image: ", input_image.filename) 67 | raise ValueError("Error in resizing image.") 68 | resized_img.thumbnail( 69 | size=(target_width, target_height) 70 | ) # NOTE: thumbnail size is (width, height) 71 | 72 | # Pad image 73 | delta_width = target_width - resized_img.size[0] 74 | delta_height = target_height - resized_img.size[1] 75 | if random_padding: 76 | pad_width = np.random.randint(low=0, high=(delta_width + 1)) 77 | pad_height = np.random.randint(low=0, high=(delta_height + 1)) 78 | else: 79 | # Center image if not random padding 80 | pad_width = delta_width // 2 81 | pad_height = delta_height // 2 82 | padding_dims = ( 83 | pad_width, 84 | pad_height, 85 | delta_width - pad_width, 86 | delta_height - pad_height, 87 | ) 88 | padded_img = ImageOps.expand(resized_img, padding_dims) 89 | 90 | # Convert to tensor 91 | tensor_img = convert_PIL_to_tensor( 92 | padded_img, normalize_channels=normalize_channels 93 | ) 94 | 95 | return tensor_img, original_size, padding_dims 96 | -------------------------------------------------------------------------------- /tflop/evaluator.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script is based on the TED metric implementation code from the PubTabNet GitHub repository: 3 | https://github.com/ibm-aur-nlp/PubTabNet/blob/master/src/metric.py 4 | """ 5 | 6 | from collections import deque 7 | 8 | from apted import APTED, Config 9 | from apted.helpers import Tree 10 | import distance 11 | from lxml import etree, html 12 | 13 | 14 | class TableTree(Tree): 15 | """ 16 | A representation of a tree structure used for computing edit distances between tables. 17 | 18 | Args: 19 | tag (str): The HTML tag of the node (e.g., "td", "tr", etc.). 20 | colspan (int, optional): The colspan attribute of the cell. Defaults to None. 21 | rowspan (int, optional): The rowspan attribute of the cell. Defaults to None. 22 | content (list, optional): Tokenized content of the cell. Defaults to None. 23 | *children: Child nodes of the current node. 24 | 25 | Methods: 26 | bracket(): 27 | Returns a string representation of the tree in bracket notation. 28 | """ 29 | 30 | def __init__(self, tag, colspan=None, rowspan=None, content=None, *children): 31 | self.tag = tag 32 | self.colspan = colspan 33 | self.rowspan = rowspan 34 | self.content = content 35 | self.children = list(children) 36 | 37 | def bracket(self): 38 | """ 39 | Returns a string representation of the tree using bracket notation. 40 | 41 | Returns: 42 | str: Bracketed representation of the tree. 43 | """ 44 | if self.tag == "td": 45 | result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % ( 46 | self.tag, 47 | self.colspan, 48 | self.rowspan, 49 | self.content, 50 | ) 51 | else: 52 | result = '"tag": %s' % self.tag 53 | for child in self.children: 54 | result += child.bracket() 55 | return "{{{}}}".format(result) 56 | 57 | 58 | class CustomConfig(Config): 59 | """ 60 | Custom configuration for APTED (Tree Edit Distance). 61 | 62 | Methods: 63 | maximum(*sequences): 64 | Returns the maximum possible value of a sequence. 65 | normalized_distance(*sequences): 66 | Computes a normalized Levenshtein distance between sequences. 
67 | rename(node1, node2): 68 | Compares two nodes based on their attributes and content. 69 | """ 70 | 71 | @staticmethod 72 | def maximum(*sequences): 73 | """ 74 | Returns the maximum length among the given sequences. 75 | 76 | Args: 77 | *sequences: A variable number of sequences. 78 | 79 | Returns: 80 | int: The maximum length. 81 | """ 82 | return max(map(len, sequences)) 83 | 84 | def normalized_distance(self, *sequences): 85 | """ 86 | Computes a normalized Levenshtein distance between sequences. 87 | 88 | Args: 89 | *sequences: A variable number of sequences. 90 | 91 | Returns: 92 | float: A normalized distance between 0 and 1. 93 | """ 94 | return float(distance.levenshtein(*sequences)) / self.maximum(*sequences) 95 | 96 | def rename(self, node1, node2): 97 | """ 98 | Compares attributes and content of two tree nodes. 99 | 100 | Args: 101 | node1 (TableTree): The first node. 102 | node2 (TableTree): The second node. 103 | 104 | Returns: 105 | float: The cost of renaming the nodes. Returns 0.0 if identical, 1.0 otherwise. 106 | """ 107 | if ( 108 | (node1.tag != node2.tag) 109 | or (node1.colspan != node2.colspan) 110 | or (node1.rowspan != node2.rowspan) 111 | ): 112 | return 1.0 113 | if node1.tag == "td": 114 | if node1.content or node2.content: 115 | return self.normalized_distance(node1.content, node2.content) 116 | return 0.0 117 | 118 | 119 | class TEDS: 120 | """ 121 | Tree Edit Distance-based Similarity (TEDS) for evaluating table structure similarity. 122 | 123 | Args: 124 | structure_only (bool): If True, evaluates only the table structure, ignoring content. 125 | n_jobs (int): Number of parallel jobs for computation. Defaults to 1. 126 | ignore_nodes (list, optional): List of node tags to ignore during evaluation. 127 | 128 | Methods: 129 | tokenize(node): 130 | Tokenizes the text and structure of an HTML node. 131 | load_html_tree(node, parent=None): 132 | Converts an HTML tree into a tree structure compatible with APTED. 133 | evaluate(pred, true): 134 | Computes the TEDS score between predicted and ground truth table structures. 135 | """ 136 | 137 | def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None): 138 | assert isinstance(n_jobs, int) and ( 139 | n_jobs >= 1 140 | ), "n_jobs must be an integer greater than 1" 141 | self.structure_only = structure_only 142 | self.n_jobs = n_jobs 143 | self.ignore_nodes = ignore_nodes 144 | self.__tokens__ = [] 145 | 146 | def tokenize(self, node): 147 | """ 148 | Tokenizes an HTML node and its content into a list of tokens. 149 | 150 | Args: 151 | node (lxml.etree.Element): The HTML node to tokenize. 152 | """ 153 | self.__tokens__.append("<%s>" % node.tag) 154 | if node.text is not None: 155 | self.__tokens__ += list(node.text) 156 | for n in node.getchildren(): 157 | self.tokenize(n) 158 | if node.tag != "unk": 159 | self.__tokens__.append("" % node.tag) 160 | if node.tag != "td" and node.tail is not None: 161 | self.__tokens__ += list(node.tail) 162 | 163 | def load_html_tree(self, node, parent=None): 164 | """ 165 | Converts an HTML tree into a TableTree structure for APTED computation. 166 | 167 | Args: 168 | node (lxml.etree.Element): The root HTML node. 169 | parent (TableTree, optional): The parent node in the TableTree structure. 170 | 171 | Returns: 172 | TableTree: The converted TableTree structure. 
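        Example (illustrative):
            For a parsed table whose only cell is <td colspan="2">x</td>, the returned
            root TableTree has tag "table", a single "tr" child and a "td" leaf with
            colspan=2, rowspan=1 and content ["x"] (an empty content list when
            structure_only=True).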
173 | """ 174 | if node.tag == "td": 175 | if self.structure_only: 176 | cell = [] 177 | else: 178 | self.__tokens__ = [] 179 | self.tokenize(node) 180 | cell = self.__tokens__[1:-1].copy() 181 | new_node = TableTree( 182 | node.tag, 183 | int(node.attrib.get("colspan", "1")), 184 | int(node.attrib.get("rowspan", "1")), 185 | cell, 186 | *deque(), 187 | ) 188 | else: 189 | new_node = TableTree(node.tag, None, None, None, *deque()) 190 | if parent is not None: 191 | parent.children.append(new_node) 192 | if node.tag != "td": 193 | for n in node.getchildren(): 194 | self.load_html_tree(n, new_node) 195 | if parent is None: 196 | return new_node 197 | 198 | def evaluate(self, pred, true): 199 | """ 200 | Computes the TEDS score between a predicted table and a ground truth table. 201 | 202 | Args: 203 | pred (str): HTML string of the predicted table. 204 | true (str): HTML string of the ground truth table. 205 | 206 | Returns: 207 | float: TEDS score between 0.0 and 1.0, where 1.0 indicates perfect similarity. 208 | """ 209 | if (not pred) or (not true): 210 | return 0.0 211 | parser = html.HTMLParser(remove_comments=True, encoding="utf-8") 212 | pred = html.fromstring(pred, parser=parser) 213 | true = html.fromstring(true, parser=parser) 214 | 215 | if pred.xpath("body/table") and true.xpath("body/table"): 216 | pred = pred.xpath("body/table")[0] 217 | true = true.xpath("body/table")[0] 218 | if self.ignore_nodes: 219 | etree.strip_tags(pred, *self.ignore_nodes) 220 | etree.strip_tags(true, *self.ignore_nodes) 221 | n_nodes_pred = len(pred.xpath(".//*")) 222 | n_nodes_true = len(true.xpath(".//*")) 223 | n_nodes = max(n_nodes_pred, n_nodes_true) 224 | tree_pred = self.load_html_tree(pred) 225 | tree_true = self.load_html_tree(true) 226 | distance = APTED( 227 | tree_pred, tree_true, CustomConfig() 228 | ).compute_edit_distance() 229 | return 1.0 - (float(distance) / n_nodes) 230 | else: 231 | return 0.0 232 | -------------------------------------------------------------------------------- /tflop/lightning_module/lightning_module.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import os 4 | from pathlib import Path 5 | import random 6 | 7 | from Levenshtein import distance 8 | import numpy as np 9 | import omegaconf 10 | import pytorch_lightning as pl 11 | from pytorch_lightning.utilities import rank_zero_only 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.nn.utils.rnn import pad_sequence 15 | from torch.optim.lr_scheduler import LambdaLR 16 | from torch.utils.data import DataLoader 17 | from transformers import PreTrainedTokenizer 18 | 19 | from tflop.model.model.TFLOP import TFLOP 20 | from tflop.model.model.TFLOP_Config import TFLOPConfig 21 | from tflop.utils import custom_format_html, resolve_missing_config 22 | 23 | 24 | class TFLOPModelPLModule(pl.LightningModule): 25 | def __init__( 26 | self: "TFLOPModelPLModule", config, tokenizer: PreTrainedTokenizer, mode: str 27 | ): 28 | super().__init__() 29 | self.config = config 30 | self.tokenizer = tokenizer 31 | self.mode = mode 32 | 33 | model_config_dict = { 34 | k: v 35 | for k, v in self.config.items() 36 | if k in TFLOPConfig.get_member_variables() 37 | } 38 | model_config_dict = resolve_missing_config(model_config_dict) 39 | 40 | # Set-up data tag ids 41 | if not self.config.use_OTSL: 42 | raise NotImplementedError("Non-OTSL mode is deprecated") 43 | data_ids = ["C-tag"] 44 | 45 | # Set-up model 46 | self.model = TFLOP( 47 | 
config=TFLOPConfig(**model_config_dict), 48 | tokenizer=self.tokenizer, 49 | data_ids=data_ids, 50 | ) 51 | self.load_pretrained_weights() 52 | 53 | def training_step(self, batch, batch_idx): 54 | """Training step""" 55 | 56 | # Sanity check -- Pointer decoder is always used in TFLOP 57 | assert self.config.use_ptr_decoder, "Pointer decoder is always used in TFLOP" 58 | 59 | # Forward pass 60 | model_output = self.pointer_regular_train_forward(batch) 61 | 62 | # Losses 63 | loss = model_output.loss 64 | token_cls_loss = model_output.token_cls_loss 65 | tag2coord_pointer_loss = model_output.tag2coord_pointer_loss 66 | tag2coord_pointer_acc = model_output.tag2coord_pointer_acc 67 | 68 | bbox_TableCL_loss = model_output.bbox_TableCL_loss 69 | rowwise_loss = model_output.rowwise_loss 70 | colwise_loss = model_output.colwise_loss 71 | 72 | self.log_dict({"train_loss": loss}, sync_dist=True) 73 | self.log_dict({"train_token_cls_loss": token_cls_loss}, sync_dist=True) 74 | self.log_dict( 75 | {"train_tag2coord_pointer_loss": tag2coord_pointer_loss}, sync_dist=True 76 | ) 77 | self.log_dict( 78 | {"train_tag2coord_pointer_acc": tag2coord_pointer_acc}, sync_dist=True 79 | ) 80 | self.log_dict({"train_bbox_TableCL_loss": bbox_TableCL_loss}, sync_dist=True) 81 | self.log_dict({"train_rowwise_loss": rowwise_loss}, sync_dist=True) 82 | self.log_dict({"train_colwise_loss": colwise_loss}, sync_dist=True) 83 | 84 | return loss 85 | 86 | def on_validation_epoch_start(self) -> None: 87 | """Prepare for validation step""" 88 | 89 | super().on_validation_epoch_start() 90 | self.validation_step_outputs = [[]] 91 | return 92 | 93 | def validation_step(self, batch, batch_idx, dataloader_idx=0): 94 | """Validation step""" 95 | 96 | # Sanity check -- Pointer decoder is always used in TFLOP 97 | assert self.config.use_ptr_decoder, "Pointer decoder is always used in TFLOP" 98 | 99 | # Forward pass 100 | preds, answers, html_with_content, cell_texts = ( 101 | self.pointer_regular_validation_forward(batch) 102 | ) 103 | 104 | # Get html seq with content 105 | pred_collection = [] 106 | for data_i in range(preds["output_sequences"].shape[0]): 107 | token_id_seq = preds["output_sequences"][data_i] 108 | token_seq = self.tokenizer.convert_ids_to_tokens(token_id_seq) 109 | 110 | if self.config.get("use_OTSL", False): 111 | output_seq_tokens = [] 112 | for token_pred in token_seq: 113 | if token_pred == "▁": 114 | token_to_add = " " 115 | else: 116 | token_to_add = token_pred.replace("▁", "") 117 | output_seq_tokens.append(token_to_add) 118 | output_seq_tokens = "".join(output_seq_tokens) 119 | else: 120 | raise NotImplementedError("Non-OTSL mode is deprecated") 121 | 122 | pred_collection.append(output_seq_tokens) 123 | 124 | # Get scores 125 | scores = [] 126 | for pred, answer in zip(pred_collection, answers): 127 | score_set = [] 128 | 129 | pred_string, refined_pred = custom_format_html(pred, self.tokenizer) 130 | answer_string, refined_gold = custom_format_html(answer, self.tokenizer) 131 | 132 | score_set.append( 133 | distance(pred_string, answer_string) 134 | / max(len(pred_string), len(answer_string)) 135 | ) 136 | 137 | ted_score_structure_only, ted_score_full = 0.0, 0.0 138 | score_set.append(ted_score_structure_only) 139 | score_set.append(ted_score_full) 140 | 141 | scores.append(score_set) 142 | 143 | self.validation_step_outputs[dataloader_idx].append(scores) 144 | 145 | return scores 146 | 147 | def on_validation_epoch_end(self): 148 | """Validation epoch end""" 149 | # Sanity check 150 | assert 
len(self.validation_step_outputs) == 1 151 | cnt, edit_dist_metric, ted_no_support_metric, ted_metric, val_metric = ( 152 | [0], 153 | [0], 154 | [0], 155 | [0], 156 | [0], 157 | ) 158 | 159 | for scores in self.validation_step_outputs[0]: 160 | cnt[0] += len(scores) 161 | 162 | edit_dist_metric[0] += np.sum([x[0] for x in scores]) 163 | ted_no_support_metric[0] += np.sum([x[1] for x in scores]) 164 | ted_metric[0] += np.sum([x[2] for x in scores]) 165 | 166 | val_metric[0] = edit_dist_metric[0] / cnt[0] 167 | val_metric_name = f"val_metric_{0}th_dataset" 168 | self.log_dict({val_metric_name: val_metric[0]}, sync_dist=True) 169 | 170 | self.log_dict( 171 | { 172 | "val_metric": np.sum(edit_dist_metric) / np.sum(cnt), 173 | "ted_no_support": np.sum(ted_no_support_metric) / np.sum(cnt), 174 | "ted": np.sum(ted_metric) / np.sum(cnt), 175 | }, 176 | sync_dist=True, 177 | ) 178 | 179 | def configure_optimizers(self): 180 | """Prepare optimizer and scheduler""" 181 | 182 | max_iter = self.config.max_steps 183 | assert max_iter is not None 184 | optimizer = torch.optim.Adam(self.parameters(), lr=self.config.lr) 185 | 186 | # set warmup steps to 2% of max_iter 187 | num_warmup_steps = int(max_iter * 0.02) # ~5K if max_iter=250K 188 | scheduler = { 189 | "scheduler": self.cosine_scheduler(optimizer, max_iter, num_warmup_steps), 190 | "name": "learning_rate", 191 | "interval": "step", 192 | } 193 | return [optimizer], [scheduler] 194 | 195 | @staticmethod 196 | def cosine_scheduler(optimizer, training_steps, warmup_steps): 197 | """Create cosine scheduler with warmup""" 198 | 199 | def lr_lambda(current_step): 200 | if current_step < warmup_steps: 201 | return current_step / max(1, warmup_steps) 202 | progress = current_step - warmup_steps 203 | progress /= max(1, training_steps - warmup_steps) 204 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) 205 | 206 | return LambdaLR(optimizer, lr_lambda) 207 | 208 | @rank_zero_only 209 | def on_save_checkpoint(self, checkpoint): 210 | """Save model and tokenizer""" 211 | save_path = ( 212 | Path(self.config.result_path) 213 | / self.config.exp_name 214 | / self.config.exp_version 215 | ) 216 | save_path = save_path / ( 217 | "epoch_%s_step_%s" % (self.current_epoch, self.global_step) 218 | ) 219 | self.model.save_pretrained(save_path) 220 | self.model.decoder.tokenizer.save_pretrained(save_path) 221 | 222 | def load_pretrained_weights(self): 223 | """Load pretrained weights if available""" 224 | 225 | if self.config.get("pretrained_model_name_or_path", False): 226 | loaded_state_dict = torch.load( 227 | os.path.join( 228 | self.config.pretrained_model_name_or_path, "pytorch_model.bin" 229 | ) 230 | ) 231 | saved_config = json.load( 232 | open( 233 | os.path.join( 234 | self.config.pretrained_model_name_or_path, "config.json" 235 | ), 236 | "r", 237 | ) 238 | ) 239 | 240 | # First adjust saved state_dict to match that of current model, then load_state_dict 241 | 242 | # 1. 
truncate or interplolate position embeddings of donut decoder 243 | if self.config.max_length != saved_config["max_length"]: 244 | print( 245 | "NOTE: max_length of pretrained model differs max_length you want to train" 246 | ) 247 | weight_tensor = self.model.decoder.resize_bart_abs_pos_emb( 248 | loaded_state_dict[ 249 | "decoder.model.model.decoder.embed_positions.weight" 250 | ], 251 | self.config.max_length 252 | + 2, # https://github.com/huggingface/transformers/blob/v4.11.3/src/transformers/models/mbart/modeling_mbart.py#L118-L119 253 | ) 254 | weight_tensor = weight_tensor.contiguous() 255 | loaded_state_dict[ 256 | "decoder.model.model.decoder.embed_positions.weight" 257 | ] = weight_tensor 258 | 259 | # 2. adjust swin encoder if window size mismatch 260 | if type(self.config.input_size) == omegaconf.dictconfig.DictConfig: 261 | input_size_mismatch = [ 262 | self.config.input_size["width"], 263 | self.config.input_size["height"], 264 | ] != saved_config["input_size"] 265 | else: 266 | input_size_mismatch = ( 267 | self.config.input_size != saved_config["input_size"] 268 | ) 269 | window_size_mismatch = ( 270 | self.config.window_size != saved_config["window_size"] 271 | ) 272 | if input_size_mismatch or window_size_mismatch: 273 | print( 274 | "NOTE: input_size or window_size of pretrained model differs input_size or window_size you want to train" 275 | ) 276 | 277 | curr_state_dict = self.model.encoder.state_dict() 278 | for x in curr_state_dict: 279 | if x.endswith("relative_position_index") or x.endswith("attn_mask"): 280 | pass 281 | elif ( 282 | x.endswith("relative_position_bias_table") 283 | and self.model.encoder.model.layers[0] 284 | .blocks[0] 285 | .attn.window_size[0] 286 | != saved_config["window_size"] 287 | ): 288 | pos_bias = loaded_state_dict["encoder." + x].unsqueeze(0)[0] 289 | old_len = int(math.sqrt(len(pos_bias))) 290 | new_len = int(2 * self.config.window_size - 1) 291 | pos_bias = pos_bias.reshape(1, old_len, old_len, -1).permute( 292 | 0, 3, 1, 2 293 | ) 294 | pos_bias = F.interpolate( 295 | pos_bias, 296 | size=(new_len, new_len), 297 | mode="bicubic", 298 | align_corners=False, 299 | ) 300 | curr_state_dict[x] = ( 301 | pos_bias.permute(0, 2, 3, 1) 302 | .reshape(1, new_len**2, -1) 303 | .squeeze(0) 304 | ) 305 | else: 306 | curr_state_dict[x] = loaded_state_dict["encoder." + x] 307 | 308 | for swin_enc_key in curr_state_dict.keys(): 309 | loaded_state_dict["encoder." 
+ swin_enc_key] = curr_state_dict[ 310 | swin_enc_key 311 | ] 312 | 313 | # Now, load state dict 314 | encoder_state_dicts = { 315 | k[len("encoder.") :]: v 316 | for k, v in loaded_state_dict.items() 317 | if k.startswith("encoder.") 318 | } 319 | decoder_state_dicts = { 320 | k[len("decoder.") :]: v 321 | for k, v in loaded_state_dict.items() 322 | if k.startswith("decoder.") 323 | } 324 | 325 | # But first remove size_mismatched keys 326 | tmp_current_encoder_statedict = self.model.encoder.state_dict() 327 | tmp_current_decoder_statedict = self.model.decoder.state_dict() 328 | 329 | encoder_size_mismatched_keys = [] 330 | encoder_keys_to_be_deleted = [] 331 | for encoder_key in encoder_state_dicts.keys(): 332 | if ( 333 | encoder_key in tmp_current_encoder_statedict 334 | and tmp_current_encoder_statedict[encoder_key].shape 335 | != encoder_state_dicts[encoder_key].shape 336 | ): 337 | encoder_size_mismatched_keys.append( 338 | [ 339 | encoder_key, 340 | tmp_current_encoder_statedict[encoder_key].shape, 341 | encoder_state_dicts[encoder_key].shape, 342 | ] 343 | ) 344 | encoder_keys_to_be_deleted.append(encoder_key) 345 | 346 | decoder_size_mismatched_keys = [] 347 | decoder_keys_to_be_deleted = [] 348 | for decoder_key in decoder_state_dicts.keys(): 349 | if ( 350 | decoder_key in tmp_current_decoder_statedict 351 | and tmp_current_decoder_statedict[decoder_key].shape 352 | != decoder_state_dicts[decoder_key].shape 353 | ): 354 | decoder_size_mismatched_keys.append( 355 | [ 356 | decoder_key, 357 | tmp_current_decoder_statedict[decoder_key].shape, 358 | decoder_state_dicts[decoder_key].shape, 359 | ] 360 | ) 361 | 362 | decoder_keys_to_be_deleted.append(decoder_key) 363 | 364 | encoder_state_dicts = { 365 | k: v 366 | for k, v in encoder_state_dicts.items() 367 | if k not in encoder_keys_to_be_deleted 368 | } 369 | decoder_state_dicts = { 370 | k: v 371 | for k, v in decoder_state_dicts.items() 372 | if k not in decoder_keys_to_be_deleted 373 | } 374 | 375 | encoder_missing_keys, encoder_unexpected_keys = ( 376 | self.model.encoder.load_state_dict(encoder_state_dicts, strict=False) 377 | ) 378 | decoder_missing_keys, decoder_unexpected_keys = ( 379 | self.model.decoder.load_state_dict(decoder_state_dicts, strict=False) 380 | ) 381 | 382 | print("-----Size Mismatched Keys-----") 383 | print("Encoder:") 384 | if len(encoder_size_mismatched_keys) > 0: 385 | for key in encoder_size_mismatched_keys: 386 | mismatched_keyname, curr_shape, loaded_shape = key 387 | print( 388 | f"{mismatched_keyname}: trying to load: {loaded_shape} -> into: {curr_shape}" 389 | ) 390 | else: 391 | print("None") 392 | print("\nDecoder:") 393 | if len(decoder_size_mismatched_keys) > 0: 394 | for key in decoder_size_mismatched_keys: 395 | mismatched_keyname, curr_shape, loaded_shape = key 396 | print( 397 | f"{mismatched_keyname}: trying to load: {loaded_shape} -> into: {curr_shape}" 398 | ) 399 | else: 400 | print("None") 401 | print("-------------------------------") 402 | print("----------Missing Keys---------") 403 | print("Encoder:") 404 | if len(encoder_missing_keys) > 0: 405 | for key in encoder_missing_keys: 406 | print(key) 407 | else: 408 | print("None") 409 | print("\nDecoder:") 410 | if len(decoder_missing_keys) > 0: 411 | for key in decoder_missing_keys: 412 | print(key) 413 | else: 414 | print("None") 415 | print("-------------------------------") 416 | print("--------Unexpected Keys--------") 417 | print("Encoder:") 418 | if len(encoder_unexpected_keys) > 0: 419 | for key in encoder_unexpected_keys: 
420 | print(key) 421 | else: 422 | print("None") 423 | print("\nDecoder:") 424 | if len(decoder_unexpected_keys) > 0: 425 | for key in decoder_unexpected_keys: 426 | print(key) 427 | else: 428 | print("None") 429 | print("-------------------------------") 430 | 431 | def pointer_regular_train_forward(self, batch): 432 | """Forward pass for regular training stage""" 433 | image_tensors = batch[0] # (batch_size, 3, height, width) 434 | decoder_input_ids = batch[1] # (batch_size, text_token_length) 435 | coord_input_idx = batch[2] # (batch_size, bbox_token_length, 4) 436 | coord_input_length = batch[3] # (batch_size,) 437 | decoder_token_labels = batch[4] # (batch_size, text_token_length) 438 | pointer_labels = batch[ 439 | 5 440 | ] # (batch_size, text_token_length - 2, bbox_token_length) 441 | pointer_mask_labels = batch[6] # (batch_size, text_token_length - 2) 442 | bbox_coeff_tensor = batch[ 443 | 7 444 | ] # (batch_size, 5, bbox_token_length, bbox_token_length) 445 | 446 | pointer_args = { 447 | "coord_input_idx": coord_input_idx, 448 | "coord_input_length": coord_input_length, 449 | "pointer_labels": pointer_labels, 450 | "pointer_mask_labels": pointer_mask_labels, 451 | "bbox_coeff_tensor": bbox_coeff_tensor, 452 | } 453 | 454 | model_output = self.model( 455 | image_tensors=image_tensors, 456 | decoder_input_ids=decoder_input_ids, 457 | decoder_labels=decoder_token_labels, 458 | pointer_args=pointer_args, 459 | ) 460 | 461 | return model_output 462 | 463 | def pointer_regular_validation_forward(self, batch): 464 | """Forward pass for regular validation stage""" 465 | image_tensors = batch[0] # (bsz, 3, height, width) 466 | decoder_input_ids = batch[1] # (bsz, text_token_length) 467 | coord_input_idx = batch[2] # (bsz, bbox_token_length, 4) 468 | coord_input_length = batch[3] # (bsz,) 469 | prompt_end_idxs = batch[4] # (bsz,) 470 | answers = batch[5] # list of length==bsz 471 | pointer_labels = batch[6] # (bsz, text_token_length - 2, bbox_token_length) 472 | pointer_mask_labels = batch[7] # (bsz, text_token_length - 2) 473 | html_with_content = batch[8] # list of length==bsz 474 | cell_texts = batch[9] # list of length==bsz 475 | file_names = batch[10] # list of length==bsz 476 | bbox_coeff_tensor = batch[11] # (bsz, 5, bbox_token_length, bbox_token_length) 477 | 478 | pointer_args = { 479 | "coord_input_idx": coord_input_idx, 480 | "coord_input_length": coord_input_length, 481 | "pointer_labels": pointer_labels, 482 | "pointer_mask_labels": pointer_mask_labels, 483 | } 484 | 485 | decoder_prompts = pad_sequence( 486 | [ 487 | input_id[: end_idx + 1] 488 | for input_id, end_idx in zip(decoder_input_ids, prompt_end_idxs) 489 | ], 490 | batch_first=True, 491 | ) 492 | preds = self.model.inference( 493 | image_tensors=image_tensors, 494 | prompt_tensors=decoder_prompts, 495 | return_json=False, 496 | return_attentions=False, 497 | pointer_args=pointer_args, 498 | ) 499 | 500 | return preds, answers, html_with_content, cell_texts 501 | 502 | 503 | class DataPLModule(pl.LightningDataModule): 504 | def __init__(self: "DataPLModule", config): 505 | super().__init__() 506 | self.config = config 507 | 508 | self.train_dataset = None 509 | self.val_dataset = None 510 | self.test_dataset = None 511 | 512 | self.train_batch_size = self.config.train_batch_size 513 | self.val_batch_size = self.config.val_batch_size 514 | self.test_batch_size = self.config.test_batch_size 515 | 516 | self.g = torch.Generator() 517 | self.g.manual_seed(self.config.seed) 518 | 519 | for ds in [self.train_dataset, 
self.val_dataset, self.test_dataset]: 520 | if ds is not None: 521 | assert isinstance(ds, torch.utils.data.Dataset) 522 | 523 | def train_dataloader(self): 524 | dataloader = DataLoader( 525 | self.train_dataset, 526 | batch_size=self.train_batch_size, 527 | num_workers=self.config.num_workers, 528 | pin_memory=True, 529 | shuffle=True, 530 | ) 531 | return dataloader 532 | 533 | def val_dataloader(self): 534 | dataloader = DataLoader( 535 | self.val_dataset, 536 | batch_size=self.val_batch_size, 537 | num_workers=self.config.num_workers, 538 | pin_memory=True, 539 | shuffle=False, 540 | ) 541 | return dataloader 542 | 543 | def test_dataloader(self): 544 | dataloader = DataLoader( 545 | self.test_dataset, 546 | batch_size=self.test_batch_size, 547 | num_workers=self.config.num_workers, 548 | pin_memory=True, 549 | shuffle=False, 550 | ) 551 | return dataloader 552 | 553 | @staticmethod 554 | def seed_worker(worker_id): 555 | worker_seed = torch.initial_seed() % 2**32 556 | np.random.seed(worker_seed) 557 | random.seed(worker_seed) 558 | -------------------------------------------------------------------------------- /tflop/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Following code is referenced from the pytorch implementation of the Supervised Contrastive Loss: 3 | https://github.com/HobbitLong/SupContrast 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class TableCL(nn.Module): 11 | """ 12 | A contrastive learning model for table data with masks. 13 | 14 | Attributes: 15 | temperature (float): Temperature scaling factor for the loss function. 16 | sup_con_loss (SupConLoss): Instance of the supervised contrastive loss. 17 | """ 18 | 19 | def __init__(self, temperature=0.1): 20 | """ 21 | Initialize the TableCL module. 22 | 23 | Args: 24 | temperature (float): Temperature scaling factor for the contrastive loss. 25 | """ 26 | super(TableCL, self).__init__() 27 | self.temperature = temperature 28 | self.sup_con_loss = SupConLoss(temperature) 29 | 30 | def forward(self, features, masks, input_coords_length): 31 | """ 32 | Compute the batch loss for the given features and masks. 33 | 34 | Args: 35 | features (torch.Tensor): Feature representations, shape [batch_size, bbox_token_length, d_model]. 36 | masks (torch.Tensor): Masks, shape [batch_size, num_layers, bbox_token_length, bbox_token_length]. 37 | input_coords_length (torch.Tensor): Lengths of input coordinates, shape [batch_size]. 38 | 39 | Returns: 40 | torch.Tensor: Average batch loss.
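        Example:
            Illustrative sketch only (toy shapes and a toy positive grouping,
            not taken from the actual training pipeline); a single mask layer
            per sample is assumed, as this method expects::

                import torch

                criterion = TableCL(temperature=0.1)
                features = torch.randn(2, 8, 16)            # [batch_size, bbox_token_length, d_model]
                masks = torch.zeros(2, 1, 8, 8)              # one mask layer per sample
                masks[:, :, :4, :4] = 1                      # toy positive group 1
                masks[:, :, 4:, 4:] = 1                      # toy positive group 2
                input_coords_length = torch.tensor([5, 6])   # valid bbox tokens per sample
                loss = criterion(features, masks, input_coords_length)  # scalar tensor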
41 | """ 42 | batch_loss, valid_batch_size = 0, 0 43 | batch_size, _, _, _ = masks.shape 44 | for data_i in range(batch_size): 45 | selected_mask = masks[data_i][ 46 | :, 47 | : (input_coords_length[data_i] + 1), 48 | : (input_coords_length[data_i] + 1), 49 | ] # [1, bbox_tok_cnt + 1, bbox_tok_cnt + 1] 50 | assert selected_mask.shape[0] == 1 51 | selected_feature = features[data_i][ 52 | : (input_coords_length[data_i] + 1) 53 | ] # [bbox_tok_cnt + 1, d_model] 54 | selected_feature = selected_feature.unsqueeze( 55 | 0 56 | ) # [1, bbox_tok_cnt + 1, d_model] 57 | 58 | batch_loss += self.sup_con_loss(selected_feature, mask=selected_mask) 59 | 60 | # check if the data is valid 61 | float_selected_mask = selected_mask[0].to( 62 | torch.float 63 | ) # [bbox_tok_cnt + 1, bbox_tok_cnt + 1] 64 | sanity_tensor = torch.eye( 65 | float_selected_mask.shape[0], 66 | dtype=float_selected_mask.dtype, 67 | device=float_selected_mask.device, 68 | ) 69 | sanity_tensor[0, 0] = 0 70 | if torch.sum(float_selected_mask != sanity_tensor) != 0: 71 | valid_batch_size += 1 72 | 73 | valid_batch_size = max(valid_batch_size, 1) 74 | batch_loss = batch_loss / valid_batch_size 75 | 76 | return batch_loss 77 | 78 | 79 | class SupConLoss(nn.Module): 80 | """ 81 | A PyTorch implementation of a modified version of Supervised Contrastive Loss. 82 | 83 | Args: 84 | temperature (float): Temperature scaling factor for contrastive loss. Default is 0.1. 85 | 86 | Methods: 87 | forward(features, mask): 88 | Computes the modified supervised contrastive loss for the given features and mask. 89 | """ 90 | 91 | def __init__(self, temperature=0.1): 92 | super(SupConLoss, self).__init__() 93 | self.temperature = temperature 94 | 95 | def forward(self, features, mask): 96 | """ 97 | Forward pass to compute the supervised contrastive loss. 98 | 99 | Args: 100 | features (torch.Tensor): Feature representations, shape [batch_size, bbox_token_length, d_model]. 101 | mask (torch.Tensor): Positive-pair mask, shape [batch_size, bbox_token_length, bbox_token_length]. 102 | 103 | Returns: 104 | torch.Tensor: A scalar tensor representing the computed contrastive loss.
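        Notes:
            Summarising the computation below: for a binary mask ``m`` and
            temperature-scaled similarities ``s_ij = z_i . z_j / temperature``
            (shifted by their row-wise maximum for numerical stability), the
            returned value is

                loss = - mean_i [ 1 / (sum_{j != i} m_ij + eps)
                        * sum_{j != i} m_ij * ( s_ij
                          - log( sum_{k != i} (1 - m_ik) * exp(s_ik)
                                 + stop_grad( sum_{k != i} m_ik * exp(s_ik) )
                                 + eps ) ) ]

            i.e. the positive portion of the denominator is detached, which
            appears to be the main modification relative to the referenced
            SupContrast implementation.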
105 | """ 106 | batch_size, bbox_token_length, d_model = features.shape 107 | 108 | # compute logits 109 | dot_contrast = torch.div( 110 | torch.matmul(features, features.transpose(1, 2)), self.temperature 111 | ) # [batch_size, bbox_token_length, bbox_token_length] 112 | 113 | # for numerical stability 114 | logits_max, _ = torch.max( 115 | dot_contrast, dim=-1, keepdim=True 116 | ) # [batch_size, bbox_token_length, 1] 117 | logits = ( 118 | dot_contrast - logits_max.detach() 119 | ) # [batch_size, bbox_token_length, bbox_token_length] 120 | 121 | # logits mask (diagonal) 122 | logits_mask = 1 - torch.eye( 123 | bbox_token_length, dtype=logits.dtype, device=logits.device 124 | ).unsqueeze( 125 | 0 126 | ) # [1, bbox_token_length, bbox_token_length] 127 | 128 | # compute log_prob 129 | exp_logits = torch.exp(logits) * logits_mask 130 | 131 | negative_mask = 1 - mask 132 | negative_mask[negative_mask < 1] = 0 133 | negative_denom = torch.sum( 134 | exp_logits * negative_mask, dim=-1, keepdim=True 135 | ) # [batch_size, bbox_token_length, 1] 136 | 137 | positive_mask = mask.clone() 138 | positive_mask[positive_mask > 0] = 1 139 | positive_denom = torch.sum( 140 | exp_logits * positive_mask, dim=-1, keepdim=True 141 | ) # [batch_size, bbox_token_length, 1] 142 | 143 | denominator = negative_denom + positive_denom.detach() + 1e-6 144 | log_prob = logits - torch.log(denominator) 145 | 146 | # compute mean of log-likelihood over positive 147 | mask = mask * logits_mask 148 | mean_log_prob_pos = (mask * log_prob).sum(-1) / ( 149 | mask.sum(-1) + 1e-6 150 | ) # [batch_size, bbox_token_length] 151 | 152 | # loss 153 | loss = -1 * mean_log_prob_pos.mean() 154 | 155 | return loss 156 | -------------------------------------------------------------------------------- /tflop/model/decoder/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """ PyTorch MBART model.""" 16 | from functools import partial 17 | import random 18 | from typing import Any, Callable, Optional, Tuple, Union 19 | 20 | import torch 21 | from torch import nn 22 | import torch.utils.checkpoint 23 | from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions 24 | from transformers.models.mbart.modeling_mbart import ( 25 | MBartAttention, 26 | MBartDecoder, 27 | MBartForCausalLM, 28 | _expand_mask, 29 | logger, 30 | ) 31 | import xformers.ops as xops 32 | 33 | 34 | def apply_fast_mbart_decoder(model: MBartForCausalLM) -> None: 35 | for module in model.model.modules(): 36 | if isinstance(module, MBartDecoder): 37 | module.forward = partial(mbart_decoder_fast_forward, module) 38 | if isinstance(module, MBartAttention): 39 | module.forward = partial(mbart_attention_fast_forward, module) 40 | 41 | 42 | def mbart_attention_fast_forward( 43 | mbart_attention_module: MBartAttention, 44 | hidden_states: torch.Tensor, 45 | key_value_states: Optional[torch.Tensor] = None, 46 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 47 | attention_mask: Optional[torch.Tensor] = None, 48 | layer_head_mask: Optional[torch.Tensor] = None, 49 | output_attentions: bool = False, 50 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 51 | """Input shape: Batch x Time x Channel""" 52 | 53 | # if key_value_states are provided this layer is used as a cross-attention layer 54 | # for the decoder 55 | is_cross_attention = key_value_states is not None 56 | 57 | bsz, tgt_len, _ = hidden_states.size() 58 | 59 | # get query proj 60 | query_states = mbart_attention_module.q_proj(hidden_states) 61 | # get key, value proj 62 | # `past_key_value[0].shape[2] == key_value_states.shape[1]` 63 | # is checking that the `sequence_length` of the `past_key_value` is the same as 64 | # the provided `key_value_states` to support prefix tuning 65 | if ( 66 | is_cross_attention 67 | and past_key_value is not None 68 | and past_key_value[0].shape[2] == key_value_states.shape[1] 69 | ): 70 | # reuse k,v, cross_attentions 71 | key_states = past_key_value[0] 72 | value_states = past_key_value[1] 73 | elif is_cross_attention: 74 | # cross_attentions 75 | key_states = mbart_attention_module._shape( 76 | mbart_attention_module.k_proj(key_value_states), -1, bsz 77 | ) 78 | value_states = mbart_attention_module._shape( 79 | mbart_attention_module.v_proj(key_value_states), -1, bsz 80 | ) 81 | elif past_key_value is not None: 82 | # reuse k, v, self_attention 83 | key_states = mbart_attention_module._shape( 84 | mbart_attention_module.k_proj(hidden_states), -1, bsz 85 | ) 86 | value_states = mbart_attention_module._shape( 87 | mbart_attention_module.v_proj(hidden_states), -1, bsz 88 | ) 89 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 90 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 91 | else: 92 | # self_attention 93 | key_states = mbart_attention_module._shape( 94 | mbart_attention_module.k_proj(hidden_states), -1, bsz 95 | ) 96 | value_states = mbart_attention_module._shape( 97 | mbart_attention_module.v_proj(hidden_states), -1, bsz 98 | ) 99 | 100 | if mbart_attention_module.is_decoder: 101 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
102 | # Further calls to cross_attention layer can then reuse all cross-attention 103 | # key/value_states (first "if" case) 104 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 105 | # all previous decoder key/value_states. Further calls to uni-directional self-attention 106 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 107 | # if encoder bi-directional self-attention `past_key_value` is always `None` 108 | past_key_value = (key_states, value_states) 109 | 110 | proj_shape = ( 111 | bsz, 112 | mbart_attention_module.num_heads, 113 | -1, 114 | mbart_attention_module.head_dim, 115 | ) 116 | query_states = ( 117 | mbart_attention_module._shape(query_states, tgt_len, bsz) 118 | .view(*proj_shape) 119 | .transpose(1, 2) 120 | ) 121 | key_states = key_states.reshape(*proj_shape).transpose(1, 2) 122 | value_states = value_states.reshape(*proj_shape).transpose(1, 2) 123 | 124 | # xformers memory efficient attention with no mask or causal mask 125 | attn_bias = None 126 | if attention_mask is not None: 127 | seq_len = attention_mask.shape[3] 128 | if seq_len % 8 != 0: 129 | strided_seq_len = (seq_len // 8 + 1) * 8 130 | attn_bias = torch.zeros( 131 | attention_mask.size()[:-1] + (strided_seq_len,), 132 | device=attention_mask.device, 133 | dtype=query_states.dtype, 134 | ) 135 | attn_bias[:, :, :, :seq_len] = attention_mask 136 | attn_bias = attn_bias[:, :, :, :seq_len] 137 | else: 138 | attn_bias = attention_mask 139 | if attn_bias.shape[1] == 1: 140 | attn_bias = attn_bias.expand( 141 | -1, query_states.shape[2], -1, -1 142 | ) # expand by num_heads 143 | 144 | # if not is_cross_attention: 145 | # if mbart_attention_module.training: 146 | # attn_bias = xops.LowerTriangularMask() 147 | attn_output = xops.memory_efficient_attention( 148 | query_states, 149 | key_states, 150 | value_states, 151 | p=mbart_attention_module.dropout if mbart_attention_module.training else 0.0, 152 | attn_bias=attn_bias, 153 | ) 154 | 155 | if attn_output.size() != ( 156 | bsz, 157 | tgt_len, 158 | mbart_attention_module.num_heads, 159 | mbart_attention_module.head_dim, 160 | ): 161 | raise ValueError( 162 | "`attn_output` should be of size " 163 | f"{(bsz, tgt_len, mbart_attention_module.num_heads, mbart_attention_module.head_dim)}, " 164 | f"but is {attn_output.size()}" 165 | ) 166 | 167 | attn_output = attn_output.view( 168 | bsz, tgt_len, mbart_attention_module.num_heads, mbart_attention_module.head_dim 169 | ) 170 | # attn_output = attn_output.transpose(1, 2) 171 | 172 | # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be 173 | # partitioned across GPUs when using tensor-parallelism. 
174 | attn_output = attn_output.reshape(bsz, tgt_len, mbart_attention_module.embed_dim) 175 | 176 | attn_output = mbart_attention_module.out_proj(attn_output) 177 | 178 | return attn_output, None, past_key_value 179 | 180 | 181 | def mbart_decoder_fast_forward( 182 | mbart_decoder_module: MBartDecoder, 183 | input_ids: torch.LongTensor = None, 184 | attention_mask: Optional[torch.Tensor] = None, 185 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 186 | encoder_attention_mask: Optional[torch.LongTensor] = None, 187 | head_mask: Optional[torch.Tensor] = None, 188 | cross_attn_head_mask: Optional[torch.Tensor] = None, 189 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, 190 | inputs_embeds: Optional[torch.FloatTensor] = None, 191 | use_cache: Optional[bool] = None, 192 | output_attentions: Optional[bool] = None, 193 | output_hidden_states: Optional[bool] = None, 194 | return_dict: Optional[bool] = None, 195 | ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: 196 | r""" 197 | Args: 198 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 199 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you 200 | provide it. 201 | 202 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 203 | [`PreTrainedTokenizer.__call__`] for details. 204 | 205 | [What are input IDs?](../glossary#input-ids) 206 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 207 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 208 | 209 | - 1 for tokens that are **not masked**, 210 | - 0 for tokens that are **masked**. 211 | 212 | [What are attention masks?](../glossary#attention-mask) 213 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): 214 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention 215 | of the decoder. 216 | encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): 217 | Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values 218 | selected in `[0, 1]`: 219 | 220 | - 1 for tokens that are **not masked**, 221 | - 0 for tokens that are **masked**. 222 | 223 | [What are attention masks?](../glossary#attention-mask) 224 | head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): 225 | Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: 226 | 227 | - 1 indicates the head is **not masked**, 228 | - 0 indicates the head is **masked**. 229 | 230 | cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): 231 | Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing 232 | cross-attention on hidden heads. Mask values selected in `[0, 1]`: 233 | 234 | - 1 indicates the head is **not masked**, 235 | - 0 indicates the head is **masked**. 
236 | 237 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, 238 | returned when `use_cache=True` is passed or when `config.use_cache=True`): 239 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of 240 | shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of 241 | shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 242 | 243 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the 244 | cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 245 | 246 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those 247 | that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of 248 | all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of 249 | shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing 250 | `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more 251 | control over how to convert `input_ids` indices into associated vectors than the model's internal 252 | embedding lookup matrix. 253 | output_attentions (`bool`, *optional*): 254 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 255 | returned tensors for more detail. 256 | output_hidden_states (`bool`, *optional*): 257 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 258 | for more detail. 259 | return_dict (`bool`, *optional*): 260 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
261 | """ 262 | output_attentions = ( 263 | output_attentions 264 | if output_attentions is not None 265 | else mbart_decoder_module.config.output_attentions 266 | ) 267 | output_hidden_states = ( 268 | output_hidden_states 269 | if output_hidden_states is not None 270 | else mbart_decoder_module.config.output_hidden_states 271 | ) 272 | use_cache = ( 273 | use_cache if use_cache is not None else mbart_decoder_module.config.use_cache 274 | ) 275 | return_dict = ( 276 | return_dict 277 | if return_dict is not None 278 | else mbart_decoder_module.config.use_return_dict 279 | ) 280 | 281 | # retrieve input_ids and inputs_embeds 282 | if input_ids is not None and inputs_embeds is not None: 283 | raise ValueError( 284 | "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" 285 | ) 286 | elif input_ids is not None: 287 | input = input_ids 288 | input_shape = input.size() 289 | input_ids = input_ids.view(-1, input_shape[-1]) 290 | elif inputs_embeds is not None: 291 | input_shape = inputs_embeds.size()[:-1] 292 | input = inputs_embeds[:, :, -1] 293 | else: 294 | raise ValueError( 295 | "You have to specify either decoder_input_ids or decoder_inputs_embeds" 296 | ) 297 | 298 | # past_key_values_length 299 | past_key_values_length = ( 300 | past_key_values[0][0].shape[2] if past_key_values is not None else 0 301 | ) 302 | 303 | if inputs_embeds is None: 304 | inputs_embeds = ( 305 | mbart_decoder_module.embed_tokens(input_ids) 306 | * mbart_decoder_module.embed_scale 307 | ) 308 | 309 | # No need to make attention_mask for fast attention -> revived to allow custom attention masking 310 | attention_mask = mbart_decoder_module._prepare_decoder_attention_mask( 311 | attention_mask, input_shape, inputs_embeds, past_key_values_length 312 | ) 313 | 314 | # expand encoder attention mask 315 | if encoder_hidden_states is not None and encoder_attention_mask is not None: 316 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 317 | encoder_attention_mask = _expand_mask( 318 | encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] 319 | ) 320 | 321 | # embed positions 322 | positions = mbart_decoder_module.embed_positions(input, past_key_values_length) 323 | 324 | hidden_states = inputs_embeds + positions.to(inputs_embeds.device) 325 | hidden_states = mbart_decoder_module.layernorm_embedding(hidden_states) 326 | 327 | hidden_states = nn.functional.dropout( 328 | hidden_states, 329 | p=mbart_decoder_module.dropout, 330 | training=mbart_decoder_module.training, 331 | ) 332 | 333 | if mbart_decoder_module.gradient_checkpointing and mbart_decoder_module.training: 334 | if use_cache: 335 | logger.warning_once( 336 | "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." 
337 | ) 338 | use_cache = False 339 | 340 | # decoder layers 341 | all_hidden_states = () if output_hidden_states else None 342 | all_self_attns = () if output_attentions else None 343 | all_cross_attentions = ( 344 | () if (output_attentions and encoder_hidden_states is not None) else None 345 | ) 346 | next_decoder_cache = () if use_cache else None 347 | 348 | # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired 349 | for attn_mask, mask_name in zip( 350 | [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"] 351 | ): 352 | if attn_mask is not None: 353 | if attn_mask.size()[0] != len(mbart_decoder_module.layers): 354 | raise ValueError( 355 | f"The `{mask_name}` should be specified for {len(mbart_decoder_module.layers)} layers, but it is for" 356 | f" {attn_mask.size()[0]}." 357 | ) 358 | for idx, decoder_layer in enumerate(mbart_decoder_module.layers): 359 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) 360 | if output_hidden_states: 361 | all_hidden_states += (hidden_states,) 362 | dropout_probability = random.uniform(0, 1) 363 | if mbart_decoder_module.training and ( 364 | dropout_probability < mbart_decoder_module.layerdrop 365 | ): 366 | continue 367 | 368 | past_key_value = past_key_values[idx] if past_key_values is not None else None 369 | 370 | if ( 371 | mbart_decoder_module.gradient_checkpointing 372 | and mbart_decoder_module.training 373 | ): 374 | 375 | def create_custom_forward(module: nn.Module) -> Callable: 376 | def custom_forward(*inputs: Any) -> Any: 377 | # None for past_key_value 378 | return module(*inputs, output_attentions, use_cache) 379 | 380 | return custom_forward 381 | 382 | layer_outputs = torch.utils.checkpoint.checkpoint( 383 | create_custom_forward(decoder_layer), 384 | hidden_states, 385 | attention_mask, 386 | encoder_hidden_states, 387 | encoder_attention_mask, 388 | head_mask[idx] if head_mask is not None else None, 389 | cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, 390 | None, 391 | ) 392 | else: 393 | layer_outputs = decoder_layer( 394 | hidden_states, 395 | attention_mask=attention_mask, 396 | encoder_hidden_states=encoder_hidden_states, 397 | encoder_attention_mask=encoder_attention_mask, 398 | layer_head_mask=(head_mask[idx] if head_mask is not None else None), 399 | cross_attn_layer_head_mask=( 400 | cross_attn_head_mask[idx] 401 | if cross_attn_head_mask is not None 402 | else None 403 | ), 404 | past_key_value=past_key_value, 405 | output_attentions=output_attentions, 406 | use_cache=use_cache, 407 | ) 408 | hidden_states = layer_outputs[0] 409 | 410 | if use_cache: 411 | next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) 412 | 413 | if output_attentions: 414 | all_self_attns += (layer_outputs[1],) 415 | 416 | if encoder_hidden_states is not None: 417 | all_cross_attentions += (layer_outputs[2],) 418 | 419 | hidden_states = mbart_decoder_module.layer_norm(hidden_states) 420 | 421 | # add hidden states from the last decoder layer 422 | if output_hidden_states: 423 | all_hidden_states += (hidden_states,) 424 | 425 | next_cache = next_decoder_cache if use_cache else None 426 | if not return_dict: 427 | return tuple( 428 | v 429 | for v in [ 430 | hidden_states, 431 | next_cache, 432 | all_hidden_states, 433 | all_self_attns, 434 | all_cross_attentions, 435 | ] 436 | if v is not None 437 | ) 438 | return BaseModelOutputWithPastAndCrossAttentions( 439 | last_hidden_state=hidden_states, 440 | 
past_key_values=next_cache, 441 | hidden_states=all_hidden_states, 442 | attentions=all_self_attns, 443 | cross_attentions=all_cross_attentions, 444 | ) 445 | -------------------------------------------------------------------------------- /tflop/model/model/TFLOP.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import PreTrainedModel, PreTrainedTokenizer 3 | from transformers.file_utils import ModelOutput 4 | 5 | from tflop.model.decoder.mbart_decoder import MBARTDecoder 6 | from tflop.model.model.TFLOP_Config import TFLOPConfig 7 | from tflop.model.visual_encoder.swin import SwinEncoder 8 | 9 | 10 | class TFLOP(PreTrainedModel): 11 | config_class = TFLOPConfig 12 | base_model_prefix = "tflop" 13 | 14 | def __init__( 15 | self: "TFLOP", 16 | config: TFLOPConfig, 17 | tokenizer: PreTrainedTokenizer, 18 | data_ids: list = None, 19 | ): 20 | super().__init__(config) 21 | self.config = config 22 | self.tokenizer = tokenizer 23 | 24 | # Setup Encoder 25 | swin_input_size = ( 26 | self.config.input_size[1], 27 | self.config.input_size[0], 28 | ) # NOTE: Swin input is (height, width) 29 | self.encoder = SwinEncoder( 30 | input_size=swin_input_size, 31 | align_long_axis=self.config.align_along_axis, 32 | window_size=self.config.window_size, 33 | name_or_path=self.config.name_or_path, 34 | encoder_layer=self.config.encoder_layer, 35 | ) 36 | 37 | # Setup Decoder 38 | contrastive_loss_config = { 39 | "use_RowWise_contLearning": self.config.use_RowWise_contLearning, 40 | "use_ColWise_contLearning": self.config.use_ColWise_contLearning, 41 | } 42 | self.decoder = MBARTDecoder( 43 | tokenizer=self.tokenizer, 44 | decoder_layer=self.config.decoder_layer, 45 | max_length=self.config.max_length, 46 | name_or_path=self.config.name_or_path, 47 | max_position_embeddings=self.config.max_position_embeddings, 48 | use_fast=self.config.use_fast_decoder, 49 | input_size=self.config.input_size, 50 | bbox_token_cnt=self.config.bbox_token_cnt, 51 | max_num_row=self.config.max_num_row, 52 | max_num_col=self.config.max_num_col, 53 | use_bbox_HiMulConET=self.config.use_bbox_HiMulConET, 54 | use_imgRoiAlign=self.config.use_imgRoiAlign, 55 | contrastive_loss_config=contrastive_loss_config, 56 | empty_cell_ptr_loss_coeff=self.config.empty_cell_ptr_loss_coeff, 57 | non_empty_cell_ptr_loss_coeff=self.config.non_empty_cell_ptr_loss_coeff, 58 | ) 59 | 60 | self.data_ids = [ 61 | self.tokenizer.convert_tokens_to_ids(token) for token in data_ids 62 | ] 63 | 64 | def forward( 65 | self: "TFLOP", 66 | image_tensors: torch.Tensor, 67 | decoder_input_ids: torch.Tensor, 68 | decoder_labels: torch.Tensor, 69 | pointer_args: dict = None, 70 | ) -> ModelOutput: 71 | # vision encoding 72 | encoder_outputs = self.encoder( 73 | image_tensors 74 | ) # image_tensors: (bsz, 3, 768, 768), encoder_outptus: (bsz, 24*24, 1024) 75 | 76 | # text decoding 77 | decoder_outputs = self.decoder( 78 | input_ids=decoder_input_ids, 79 | input_coords=pointer_args["coord_input_idx"], 80 | input_coords_length=pointer_args["coord_input_length"], 81 | encoder_hidden_states=encoder_outputs, 82 | labels=decoder_labels, 83 | pointer_labels=pointer_args["pointer_labels"], 84 | pointer_mask_labels=pointer_args["pointer_mask_labels"], 85 | bbox_coeff_tensor=pointer_args["bbox_coeff_tensor"], 86 | ) 87 | 88 | return decoder_outputs 89 | 90 | def inference( 91 | self: "TFLOP", 92 | image_tensors: torch.Tensor, 93 | prompt_tensors: torch.Tensor, 94 | return_json: bool = True, 95 | 
return_attentions: bool = False, 96 | pointer_args: dict = None, 97 | return_last_hidden_state: bool = False, 98 | ): 99 | """ 100 | Perform inference using the TFLOP model. 101 | 102 | This method processes input image and prompt tensors through the TFLOP model's 103 | vision encoder and text decoder to produce the desired output. 104 | 105 | Args: 106 | image_tensors (torch.Tensor): Tensor representing the input images to the vision encoder. 107 | prompt_tensors (torch.Tensor): Tensor representing the prompt sequences for the text decoder. 108 | return_attentions (bool, optional): If True, the method returns attention maps from the decoder. Default is False. 109 | pointer_args (dict, optional): 110 | A dictionary containing arguments required for pointer-based decoding. 111 | Must include "coord_input_idx" and "coord_input_length". 112 | return_last_hidden_state (bool, optional): 113 | If True, the last hidden state of the decoder will be included in the output. Default is False. 114 | 115 | Returns: 116 | dict: A dictionary containing the following keys: 117 | - "output_sequences": Decoded output sequences (torch.Tensor). 118 | - "text_to_dr_coord": Text-to-detection-region coordinate predictions (torch.Tensor). 119 | - "last_hidden_state" (optional): The last hidden state of the decoder, if requested. 120 | - "attention" (optional): A dictionary with attention maps, if requested, containing: 121 | - "self_attentions": Self-attention maps from the decoder. 122 | - "cross_attentions": Cross-attention maps between encoder and decoder. 123 | 124 | Raises: 125 | ValueError: If either `image_tensors` or `prompt_tensors` is None. 126 | ValueError: If `pointer_args` is not provided or missing required keys. 127 | 128 | """ 129 | 130 | if image_tensors is None: 131 | raise ValueError("image_tensors must be provided.") 132 | if self.device.type == "cuda": 133 | # image_tensors = image_tensors.half() 134 | image_tensors = image_tensors.to(self.device) 135 | 136 | if prompt_tensors is None: 137 | raise ValueError("prompt_tensors must be provided.") 138 | prompt_tensors = prompt_tensors.to(self.device) 139 | 140 | # vision encoding 141 | last_hidden_state = self.encoder(image_tensors) 142 | if self.device.type != "cuda": 143 | last_hidden_state = last_hidden_state.to(torch.float32) 144 | encoder_outputs = ModelOutput( 145 | last_hidden_state=last_hidden_state, attentions=None 146 | ) 147 | 148 | # Set up vision encoder & prompt tensor for decoding 149 | if len(encoder_outputs.last_hidden_state.size()) == 1: 150 | encoder_outputs.last_hidden_state = ( 151 | encoder_outputs.last_hidden_state.unsqueeze(0) 152 | ) 153 | if len(prompt_tensors.size()) == 1: 154 | prompt_tensors = prompt_tensors.unsqueeze(0) 155 | 156 | # text decoding 157 | if pointer_args is not None: 158 | decoder_output = self.decoder.model.generate( 159 | decoder_input_ids=prompt_tensors, 160 | encoder_outputs=encoder_outputs, 161 | max_length=(self.config.max_length - self.config.bbox_token_cnt), 162 | early_stopping=True, 163 | pad_token_id=self.tokenizer.pad_token_id, 164 | eos_token_id=self.tokenizer.eos_token_id, 165 | use_cache=True, 166 | num_beams=1, 167 | bad_words_ids=[[self.tokenizer.unk_token_id]], 168 | return_dict_in_generate=True, 169 | output_attentions=return_attentions, 170 | output_scores=True, 171 | output_hidden_states=True, 172 | input_coords=pointer_args["coord_input_idx"], 173 | input_coords_length=pointer_args["coord_input_length"], 174 | ) 175 | last_hidden_state_collection = [ 176 | tok_pos_output[-1] 177 | 
for tok_pos_output in decoder_output.decoder_hidden_states 178 | ] 179 | last_hidden_state_collection = torch.cat( 180 | last_hidden_state_collection, dim=1 181 | ) # (bsz, seq_len-1, d_model) 182 | 183 | text_to_dr_point_pred = self.get_dr_point_pred( 184 | last_hidden_state_collection, decoder_output.sequences 185 | ) # (bsz, text_token_cnt-2, bbox_token_cnt) 186 | output = { 187 | "output_sequences": decoder_output.sequences, 188 | "text_to_dr_coord": text_to_dr_point_pred, 189 | } 190 | 191 | if return_last_hidden_state: 192 | output["last_hidden_state"] = last_hidden_state_collection 193 | 194 | else: 195 | raise ValueError("pointer_args must be provided.") 196 | 197 | if return_attentions: 198 | output["attention"] = { 199 | "self_attentions": decoder_output.decoder_attentions, 200 | "cross_attentions": decoder_output.cross_attentions, 201 | } 202 | 203 | return output 204 | 205 | def get_dr_point_pred(self, last_hidden_state_collection, decoder_output_sequences): 206 | """ 207 | get pointer tensor from hidden state and decoder output 208 | Args: 209 | last_hidden_state_collection(batch_size, seq_len-1, d_model): 210 | decoder_output_sequences(batch_size, seq_len-1): 211 | Returns: 212 | combined_feature(bsz, text_token_cnt-2, bbox_token_cnt): 213 | """ 214 | 215 | query_feature = self.decoder.q_linear( 216 | last_hidden_state_collection[:, (1 + self.config.bbox_token_cnt) :] 217 | ) # (bsz, text_token_cnt - 2, d_model) 218 | key_feature = self.decoder.k_linear( 219 | last_hidden_state_collection[:, : self.config.bbox_token_cnt] 220 | ) # (bsz, bbox_token_cnt, d_model) 221 | 222 | combined_feature = [] 223 | 224 | batch_size = query_feature.shape[0] 225 | iter_size = 8 226 | for i in range(0, batch_size, iter_size): 227 | norm_query_feature = torch.nn.functional.normalize( 228 | query_feature[i : i + iter_size], dim=-1 229 | ) # (*bsz, text_token_cnt-2, d_model) 230 | norm_key_feature = torch.nn.functional.normalize( 231 | key_feature[i : i + iter_size], dim=-1 232 | ) # (*bsz, bbox_token_cnt, d_model) 233 | tmp_feature = torch.bmm( 234 | norm_query_feature, norm_key_feature.transpose(1, 2) 235 | ) # (*bsz, text_token_cnt-2, bbox_token_cnt) 236 | combined_feature.append(tmp_feature) 237 | 238 | combined_feature = torch.cat( 239 | combined_feature, dim=0 240 | ) # (bsz, text_token_cnt-2, bbox_token_cnt) 241 | 242 | coord_logits = combined_feature[ 243 | :, :, 1: 244 | ] # (bsz, text_token_cnt-2, bbox_token_cnt-1) 245 | text_token_seq = decoder_output_sequences[:, 2:] # (bsz, text_token_cnt-2) 246 | is_data_tensor = torch.zeros( 247 | (text_token_seq.shape[0], text_token_seq.shape[1]), 248 | dtype=torch.bool, 249 | device=text_token_seq.device, 250 | ) # (bsz, text_token_cnt-2) 251 | for data_id in self.data_ids: 252 | is_data_tensor = torch.logical_or( 253 | is_data_tensor, text_token_seq == data_id 254 | ) # (bsz, text_token_cnt-2) 255 | 256 | coord_logits[~is_data_tensor] = float( 257 | "-inf" 258 | ) # (bsz, text_token_cnt-2, bbox_token_cnt-1) 259 | coord_one_hot = torch.nn.functional.one_hot( 260 | torch.argmax(coord_logits, dim=1), num_classes=coord_logits.shape[1] 261 | ).transpose( 262 | 1, 2 263 | ) # (bsz, text_token_cnt-2, bbox_token_cnt-1) 264 | 265 | # Find data where it is empty 266 | is_empty = torch.sum(coord_one_hot, dim=-1) == 0 # (bsz, text_token_cnt-2) 267 | is_empty = is_empty.unsqueeze(-1).to( 268 | coord_one_hot.dtype 269 | ) # (bsz, text_token_cnt-2, 1) 270 | 271 | combined_feature = torch.cat( 272 | [is_empty, coord_one_hot], dim=-1 273 | ) # (bsz, 
text_token_cnt-2, bbox_token_cnt) 274 | 275 | return combined_feature 276 | -------------------------------------------------------------------------------- /tflop/model/model/TFLOP_Config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Tuple, Union 3 | 4 | import omegaconf 5 | from transformers import PretrainedConfig 6 | 7 | 8 | class TFLOPConfig(PretrainedConfig): 9 | model_type = "tflop" 10 | 11 | def __init__( 12 | self: "TFLOPConfig", 13 | input_size: dict = {"height": 1280, "width": 960}, 14 | align_along_axis: bool = False, 15 | window_size: int = 10, 16 | encoder_layer: Tuple[int] = (2, 2, 14, 2), 17 | decoder_layer: int = 4, 18 | max_position_embeddings: int = None, 19 | max_length: int = 768, 20 | name_or_path: Union[str, bytes, os.PathLike] = "", 21 | use_fast_decoder: bool = False, 22 | use_ptr_decoder: bool = False, 23 | bbox_token_cnt: int = None, 24 | use_cell_bbox: bool = False, 25 | max_num_row: int = 40, 26 | max_num_col: int = 40, 27 | use_bbox_HiMulConET: bool = False, 28 | use_imgRoiAlign: bool = False, 29 | use_RowWise_contLearning: bool = False, 30 | use_ColWise_contLearning: bool = False, 31 | empty_cell_ptr_loss_coeff: float = 0.5, 32 | non_empty_cell_ptr_loss_coeff: float = 0.5, 33 | **kwargs, 34 | ): 35 | super().__init__() 36 | 37 | if type(input_size) in [dict, omegaconf.dictconfig.DictConfig]: 38 | self.input_size = ( 39 | input_size["width"], 40 | input_size["height"], 41 | ) # Set to default (width, height) 42 | else: 43 | self.input_size = input_size 44 | self.align_along_axis = align_along_axis 45 | self.window_size = window_size 46 | self.encoder_layer = encoder_layer 47 | self.decoder_layer = decoder_layer 48 | self.max_position_embeddings = ( 49 | max_length if max_position_embeddings is None else max_position_embeddings 50 | ) 51 | self.max_length = max_length 52 | self.name_or_path = name_or_path 53 | self.use_fast_decoder = use_fast_decoder 54 | self.use_ptr_decoder = use_ptr_decoder 55 | self.bbox_token_cnt = bbox_token_cnt 56 | self.use_cell_bbox = use_cell_bbox 57 | self.max_num_row = max_num_row 58 | self.max_num_col = max_num_col 59 | self.use_bbox_HiMulConET = use_bbox_HiMulConET 60 | self.use_imgRoiAlign = use_imgRoiAlign 61 | self.use_RowWise_contLearning = use_RowWise_contLearning 62 | self.use_ColWise_contLearning = use_ColWise_contLearning 63 | self.empty_cell_ptr_loss_coeff = empty_cell_ptr_loss_coeff 64 | self.non_empty_cell_ptr_loss_coeff = non_empty_cell_ptr_loss_coeff 65 | 66 | @classmethod 67 | def get_member_variables(cls): 68 | return [ 69 | "input_size", 70 | "align_along_axis", 71 | "window_size", 72 | "encoder_layer", 73 | "decoder_layer", 74 | "max_position_embeddings", 75 | "max_length", 76 | "name_or_path", 77 | "use_fast_decoder", 78 | "use_ptr_decoder", 79 | "bbox_token_cnt", 80 | "use_cell_bbox", 81 | "max_num_row", 82 | "max_num_col", 83 | "use_bbox_HiMulConET", 84 | "use_imgRoiAlign", 85 | "use_RowWise_contLearning", 86 | "use_ColWise_contLearning", 87 | "empty_cell_ptr_loss_coeff", 88 | "non_empty_cell_ptr_loss_coeff", 89 | ] 90 | -------------------------------------------------------------------------------- /tflop/model/visual_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from tflop.model.visual_encoder.swin import SwinEncoder 2 | -------------------------------------------------------------------------------- /tflop/model/visual_encoder/swin.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from typing import Tuple, Union 4 | 5 | import timm 6 | from timm.models.swin_transformer import SwinTransformer 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class SwinEncoder(nn.Module): 13 | r""" 14 | Donut encoder based on SwinTransformer 15 | Set the initial weights and configuration with a pretrained SwinTransformer and then 16 | modify the detailed configurations as a Donut Encoder 17 | 18 | Args: 19 | input_size: Input image size (height, width) 20 | align_long_axis: Whether to rotate image if height is greater than width 21 | window_size: Window size(=patch size) of SwinTransformer 22 | encoder_layer: Number of layers of SwinTransformer encoder 23 | name_or_path: Name of a pretrained model name either registered in huggingface.co. or saved in local. 24 | otherwise, `swin_base_patch4_window12_384` will be set (using `timm`). 25 | """ 26 | 27 | def __init__( 28 | self: "SwinEncoder", 29 | input_size: Tuple[int], 30 | align_long_axis: bool, 31 | window_size: int, 32 | name_or_path: Union[str, bytes, os.PathLike] = None, 33 | encoder_layer: Tuple[int] = (2, 2, 14, 2), 34 | ) -> None: 35 | super().__init__() 36 | self.input_size = input_size 37 | self.align_long_axis = align_long_axis 38 | self.window_size = window_size 39 | self.encoder_layer = encoder_layer 40 | self.name_or_path = name_or_path 41 | 42 | self.model = SwinTransformer( 43 | img_size=self.input_size, 44 | depths=self.encoder_layer, 45 | window_size=self.window_size, 46 | patch_size=4, 47 | embed_dim=128, 48 | num_heads=[4, 8, 16, 32], 49 | num_classes=0, 50 | ) 51 | 52 | self.model.norm = None 53 | 54 | # weight init with swin 55 | if self.name_or_path is None: 56 | raise NotImplementedError 57 | swin_state_dict = timm.create_model( 58 | "swin_base_patch4_window12_384", pretrained=True 59 | ).state_dict() 60 | state_dict = self.model.state_dict() 61 | for x in state_dict: 62 | if x.endswith("relative_position_index") or x.endswith("attn_mask"): 63 | pass 64 | elif ( 65 | x.endswith("relative_position_bias_table") 66 | and self.model.layers[0].blocks[0].attn.window_size[0] != 12 67 | ): 68 | pos_bias = swin_state_dict[x].unsqueeze(0)[0] 69 | old_len = int(math.sqrt(len(pos_bias))) 70 | new_len = int(2 * self.window_size - 1) 71 | pos_bias = pos_bias.reshape(1, old_len, old_len, -1).permute( 72 | 0, 3, 1, 2 73 | ) 74 | pos_bias = F.interpolate( 75 | pos_bias, 76 | size=(new_len, new_len), 77 | mode="bicubic", 78 | align_corners=False, 79 | ) 80 | state_dict[x] = ( 81 | pos_bias.permute(0, 2, 3, 1) 82 | .reshape(1, new_len**2, -1) 83 | .squeeze(0) 84 | ) 85 | else: 86 | state_dict[x] = swin_state_dict[x] 87 | self.model.load_state_dict(state_dict) 88 | 89 | def forward(self: "SwinEncoder", image_tensors: torch.Tensor) -> torch.Tensor: 90 | """ 91 | Args: 92 | x: (batch_size, num_channels, height, width) 93 | """ 94 | image_tensors = self.model.patch_embed(image_tensors) 95 | # image_tensors = self.model.pos_drop(image_tensors) # can be removed as long as drop_rate is not initialized or set to 0.0 in SwinTransformer 96 | 97 | return self.model.layers(image_tensors) 98 | -------------------------------------------------------------------------------- /tflop/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import omegaconf 5 | from omegaconf import DictConfig, OmegaConf 6 | 
import pytorch_lightning as pl 7 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 8 | from pytorch_lightning.loggers.tensorboard import TensorBoardLogger 9 | from pytorch_lightning.utilities import rank_zero_only 10 | import torch 11 | from transformers import AutoTokenizer 12 | 13 | 14 | @rank_zero_only 15 | def save_config_file(config: DictConfig, path: str) -> None: 16 | """ 17 | Save a configuration file in YAML format to the specified path. 18 | 19 | This function takes a DictConfig object, optionally resolves its interpolations, 20 | and saves it as a YAML file in the given directory. If the directory does not exist, 21 | it will be created. The function is decorated with @rank_zero_only, which means it 22 | should only be executed by the process with rank 0 in a distributed setting. 23 | 24 | Args: 25 | config (DictConfig): The configuration object to be saved. 26 | path (str): The directory path where the configuration file should be saved. 27 | 28 | Returns: 29 | None 30 | 31 | Raises: 32 | OmegaConfException: If there is an error in resolving the interpolations. 33 | 34 | """ 35 | if not Path(path).exists(): 36 | os.makedirs(path) 37 | save_path = Path(path) / "config.yaml" 38 | with open(save_path, "w") as f: 39 | OmegaConf.save(config=config, f=f) 40 | print(f"Config is saved at {save_path}") 41 | 42 | 43 | def set_up_tokenizer( 44 | pretrained_tokenizer_name_or_path, bbox_special_tokens, other_special_tokens 45 | ): 46 | """Set up tokenizer and add bbox & other special tokens 47 | 48 | Args: 49 | pretrained_tokenizer_name_or_path (str): pretrained tokenizer name or path 50 | bbox_special_tokens (ListConfig): bbox special tokens 51 | other_special_tokens (ListConfig): other special tokens 52 | 53 | Returns: 54 | tokenizer (AutoTokenizer): tokenizer with bbox & other special tokens added 55 | """ 56 | # 1. Instantiate tokenizer 57 | tokenizer = AutoTokenizer.from_pretrained(pretrained_tokenizer_name_or_path) 58 | 59 | # 2. Add bbox special tokens 60 | added_bbox_token_cnt = tokenizer.add_special_tokens( 61 | {"additional_special_tokens": sorted(set(bbox_special_tokens))} 62 | ) 63 | 64 | # 3. Add other special tokens 65 | added_special_tok_cnt = tokenizer.add_special_tokens( 66 | {"additional_special_tokens": sorted(set(other_special_tokens))} 67 | ) 68 | 69 | print( 70 | "Added %s bbox tokens and %s special tokens to tokenizer" 71 | % (added_bbox_token_cnt, added_special_tok_cnt) 72 | ) 73 | 74 | return tokenizer 75 | 76 | 77 | class ProgressBar(pl.callbacks.TQDMProgressBar): 78 | def __init__(self, config): 79 | super().__init__() 80 | self.enable = True 81 | self.config = config 82 | 83 | def disable(self): 84 | self.enable = False 85 | 86 | def get_metrics(self, trainer, model): 87 | items = super().get_metrics(trainer, model) 88 | items.pop("v_num", None) 89 | items["exp_name"] = f"{self.config.get('exp_name', '')}" 90 | items["exp_version"] = f"{self.config.get('exp_version', '')}" 91 | return items 92 | 93 | 94 | def set_up_logger_and_callbacks(config): 95 | """Set up tensorboard logger, LR & ckpt callbacks + progress bar for experiment.""" 96 | result_path = config.result_path 97 | exp_name = config.exp_name 98 | exp_version = config.exp_version 99 | 100 | # 1. Setup tensorboard logger 101 | logger = TensorBoardLogger( 102 | save_dir=result_path, 103 | name=exp_name, 104 | version=exp_version, 105 | default_hp_metric=False, 106 | ) 107 | 108 | # 2. 
Setup callbacks 109 | # 2.1 lr callback 110 | lr_callback = LearningRateMonitor(logging_interval="step") 111 | # 2.2 checkpoint callback 112 | checkpoint_callback = ModelCheckpoint( 113 | monitor="val_metric", 114 | dirpath=Path(result_path) / exp_name / exp_version, 115 | filename="artifacts-{epoch:02d}-{step:07d}", 116 | save_top_k=1, 117 | save_last=True, 118 | mode="min", 119 | ) 120 | 121 | # 3. Setup progress bar 122 | bar = ProgressBar(config) 123 | 124 | return logger, lr_callback, checkpoint_callback, bar 125 | 126 | 127 | def set_seed(seed_value): 128 | """Set seed for reproducibility.""" 129 | pytorch_lightning_version = int(pl.__version__[0]) 130 | assert pytorch_lightning_version == 2, "Only PL version 2.x is supported." 131 | if pytorch_lightning_version < 2: 132 | pl.utilities.seed.seed_everything(seed_value, workers=True) 133 | else: 134 | import lightning_fabric 135 | 136 | lightning_fabric.utilities.seed.seed_everything(seed_value, workers=True) 137 | 138 | 139 | def resolve_missing_config(model_config_dict): 140 | """Function to handle cases where certain config values are missing""" 141 | 142 | # 1. Filling in max_position_embeddings if absent while max_length is present 143 | if ( 144 | "max_position_embeddings" not in model_config_dict 145 | and "max_length" in model_config_dict 146 | ): 147 | model_config_dict["max_position_embeddings"] = model_config_dict["max_length"] 148 | 149 | # 2. Set use_imgRoiAlign to False if absent 150 | if "use_imgRoiAlign" not in model_config_dict: 151 | model_config_dict["use_imgRoiAlign"] = False 152 | 153 | # 3. Fixing input_size format 154 | if type(model_config_dict["input_size"]) in [ 155 | dict, 156 | omegaconf.dictconfig.DictConfig, 157 | ]: 158 | model_config_dict["input_size"] = ( 159 | model_config_dict["input_size"]["width"], 160 | model_config_dict["input_size"]["height"], 161 | ) 162 | 163 | return model_config_dict 164 | 165 | 166 | def decode_OTSL_seq(otsl_token_seq, pointer_tensor, cell_text_data): 167 | """Decode otsl token seq from token seq and pointer prediction 168 | 169 | Args: 170 | otsl_token_seq List[str]: token sequence 171 | point_prediction torch.Tensor: pointer prediction 172 | cell_text_data List[str]: cell text data 173 | 174 | Returns: 175 | output_seq_tokens str: html sequence 176 | """ 177 | # decode OTSL seq prediction output to html 178 | cell_text = None 179 | OTSL_full_compilation = [] 180 | OTSL_row_compilation = [] 181 | curr_column_index = 0 182 | 183 | for data_ind, token in enumerate( 184 | otsl_token_seq[2:] 185 | ): # ignore the first two tokens as they are [ and ] 186 | if token == "C-tag": 187 | mapping_mask = pointer_tensor[data_ind] # (bbox_token_cnt,) 188 | 189 | # mapping_mask is a boolean mask. 
Get all indices where value is True 190 | coord_indices = torch.nonzero(mapping_mask).squeeze(-1) # (num_of_coords,) 191 | if len(coord_indices) == 0: # No coordinate mapping predicted 192 | cell_text = None 193 | else: 194 | indices_list = coord_indices.tolist() 195 | for coord_ind in indices_list: 196 | if coord_ind == 0: 197 | continue 198 | elif coord_ind > len(cell_text_data): 199 | continue 200 | else: 201 | if cell_text is None: 202 | cell_text = cell_text_data[coord_ind - 1] 203 | else: 204 | cell_text += " " + cell_text_data[coord_ind - 1] 205 | 206 | OTSL_row_compilation.append([1, 0, 0, cell_text]) 207 | curr_column_index += 1 208 | cell_text = None 209 | elif token == "NL-tag": 210 | # new line 211 | OTSL_full_compilation.append(OTSL_row_compilation) 212 | OTSL_row_compilation = [] 213 | curr_column_index = 0 214 | elif token == "L-tag": 215 | # column span 216 | for col_i in range(len(OTSL_row_compilation)): 217 | # traverse backwards 218 | col_i_value = OTSL_row_compilation[-1 - col_i] 219 | if col_i_value is not None: 220 | col_i_value[2] += 1 221 | break 222 | OTSL_row_compilation.append(None) 223 | curr_column_index += 1 224 | 225 | elif token == "U-tag": 226 | # row span 227 | for row_i in range(len(OTSL_full_compilation)): 228 | # traverse backwards 229 | row_i_value = OTSL_full_compilation[-1 - row_i] # row_i_value is list 230 | if ( 231 | curr_column_index < len(row_i_value) 232 | and row_i_value[curr_column_index] is not None 233 | ): 234 | row_i_value[curr_column_index][1] += 1 235 | break 236 | 237 | OTSL_row_compilation.append(None) 238 | curr_column_index += 1 239 | elif token == "X-tag": 240 | OTSL_row_compilation.append(None) 241 | curr_column_index += 1 242 | continue 243 | else: 244 | continue 245 | 246 | if len(OTSL_row_compilation) > 0: 247 | OTSL_full_compilation.append(OTSL_row_compilation) 248 | 249 | # unravel 250 | OTSL_full_compilation = [ 251 | item for sublist in OTSL_full_compilation for item in sublist 252 | ] 253 | output_html_seq = "" 254 | current_data_index = 0 255 | for data_ind, token in enumerate( 256 | otsl_token_seq[2:] 257 | ): # ignore the first two tokens as they are [ and ] 258 | if token in ["L-tag", "U-tag", "X-tag"]: 259 | current_data_index += 1 260 | continue 261 | elif token == "C-tag": 262 | cell_info = OTSL_full_compilation[current_data_index] 263 | if cell_info is not None: 264 | if cell_info[1] == 0 and cell_info[2] == 0: # This is NOT a span cell 265 | if cell_info[3] is None: 266 | output_html_seq += "" 267 | else: 268 | output_html_seq += "" + cell_info[3] + "" 269 | elif cell_info[1] == 0: # This is column span 270 | if cell_info[3] is None: 271 | output_html_seq += '' % (cell_info[2] + 1) 272 | else: 273 | output_html_seq += ( 274 | '' % (cell_info[2] + 1) 275 | + cell_info[3] 276 | + "" 277 | ) 278 | elif cell_info[2] == 0: # This is row span 279 | if cell_info[3] is None: 280 | output_html_seq += '' % (cell_info[1] + 1) 281 | else: 282 | output_html_seq += ( 283 | '' % (cell_info[1] + 1) 284 | + cell_info[3] 285 | + "" 286 | ) 287 | else: # This is both column and row span 288 | if cell_info[3] is None: 289 | output_html_seq += '' % ( 290 | cell_info[1] + 1, 291 | cell_info[2] + 1, 292 | ) 293 | else: 294 | output_html_seq += ( 295 | '' 296 | % (cell_info[1] + 1, cell_info[2] + 1) 297 | + cell_info[3] 298 | + "" 299 | ) 300 | 301 | current_data_index += 1 302 | 303 | elif token == "NL-tag": 304 | output_html_seq += "" 305 | else: 306 | if token == "▁": 307 | token_to_add = " " 308 | else: 309 | token_to_add = 
token.replace("▁", "") 310 | output_html_seq += token_to_add 311 | 312 | # Formatting refinement 313 | if not output_html_seq.startswith(""): 314 | if "" in output_html_seq: 315 | output_html_seq = "" + output_html_seq.split("", 1)[1] 316 | else: 317 | output_html_seq = "" + output_html_seq 318 | else: 319 | output_html_seq = "" + output_html_seq.split("", 1)[1] 320 | 321 | # Remove the last tag 322 | tmp_split = output_html_seq.rsplit("", 1) 323 | output_html_seq = tmp_split[0] + tmp_split[1] 324 | output_html_seq = output_html_seq.replace("", "") 325 | output_html_seq = output_html_seq.replace(" token 328 | output_html_seq = output_html_seq.replace("", "") 329 | 330 | return output_html_seq 331 | 332 | 333 | def custom_format_html(html_string, tokenizer): 334 | """Custom format html string 335 | 336 | Args: 337 | html_string str: html string 338 | """ 339 | tokens_to_remove = [ 340 | tokenizer.bos_token, 341 | tokenizer.eos_token, 342 | tokenizer.pad_token, 343 | "", 344 | "", 345 | ] 346 | for token in tokens_to_remove: 347 | html_string = html_string.replace(token, "") 348 | 349 | html_seq = "" + html_string + "
" 350 | 351 | return html_string, html_seq 352 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import importlib 4 | import os 5 | from pathlib import Path 6 | 7 | from omegaconf import OmegaConf 8 | import pytorch_lightning as pl 9 | from pytorch_lightning.plugins import CheckpointIO 10 | import torch 11 | 12 | from tflop.lightning_module.lightning_module import DataPLModule, TFLOPModelPLModule 13 | from tflop.utils import ( 14 | save_config_file, 15 | set_seed, 16 | set_up_logger_and_callbacks, 17 | set_up_tokenizer, 18 | ) 19 | 20 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 21 | 22 | 23 | class CustomCheckpointIO(CheckpointIO): 24 | def save_checkpoint(self, checkpoint, path, storage_options=None): 25 | del checkpoint["state_dict"] 26 | torch.save(checkpoint, path) 27 | 28 | def load_checkpoint(self, path, storage_options=None): 29 | checkpoint = torch.load(path + "artifacts.ckpt") 30 | state_dict = torch.load(path + "pytorch_model.bin") 31 | checkpoint["state_dict"] = { 32 | "model." + key: value for key, value in state_dict.items() 33 | } 34 | return checkpoint 35 | 36 | def remove_checkpoint(self, path) -> None: 37 | return super().remove_checkpoint(path) 38 | 39 | 40 | def train(config): 41 | # Seed everything 42 | set_seed(config.get("seed", 42)) 43 | 44 | # sanity check on Contrastive Learning setting 45 | assert any( 46 | [ 47 | config.get("use_RowWise_contLearning", False), 48 | config.get("use_ColWise_contLearning", False), 49 | ] 50 | ), "Contrastive Learning setting is not correct." 51 | 52 | # Setup tokenizer 53 | tokenizer = set_up_tokenizer( 54 | pretrained_tokenizer_name_or_path=config.get( 55 | "pretrained_tokenizer_name_or_path", "hyunwoongko/asian-bart-ecjk" 56 | ), 57 | bbox_special_tokens=config.bbox_special_tokens, 58 | other_special_tokens=config.special_chars, 59 | ) 60 | 61 | # Setup model PL module 62 | model_module = TFLOPModelPLModule(config=config, tokenizer=tokenizer, mode="train") 63 | # Setup data PL module 64 | data_module = DataPLModule(config=config) 65 | 66 | # Instantiate dataset 67 | dataset_collection = {} 68 | dataset_class = getattr( 69 | importlib.import_module(config.dataset_script_path), config.dataset_class_name 70 | ) 71 | for split in ["train", "validation"]: 72 | dataset_collection[split] = dataset_class( 73 | tokenizer=tokenizer, split=split, config=config 74 | ) 75 | data_module.train_dataset = dataset_collection["train"] 76 | data_module.val_dataset = dataset_collection["validation"] 77 | 78 | # Setup logger, callbacks and progressbar 79 | logger, lr_callback, checkpoint_callback, bar = set_up_logger_and_callbacks(config) 80 | 81 | custom_ckpt = CustomCheckpointIO() 82 | trainer = pl.Trainer( 83 | num_nodes=config.get("num_nodes", 1), 84 | devices=4, 85 | strategy=config.get("strategy", "ddp"), 86 | accelerator="gpu", 87 | plugins=custom_ckpt, 88 | max_epochs=config.max_epochs, 89 | max_steps=config.max_steps, 90 | val_check_interval=config.val_check_interval, 91 | check_val_every_n_epoch=config.check_val_every_n_epoch, 92 | gradient_clip_val=config.gradient_clip_val, 93 | precision="bf16", 94 | num_sanity_val_steps=1, 95 | logger=logger, 96 | accumulate_grad_batches=config.get("accumulate_grad_batches", 1), 97 | callbacks=[lr_callback, checkpoint_callback, bar], 98 | ) 99 | trainer.fit( 100 | model_module, 101 | data_module, 102 | 
ckpt_path=config.get("resume_from_checkpoint_path", None), 103 | ) 104 | 105 | 106 | if __name__ == "__main__": 107 | parser = argparse.ArgumentParser() 108 | parser.add_argument("--exp_config", type=str, required=True) 109 | parser.add_argument("--data_config", type=str, required=True) 110 | args, left_argv = parser.parse_known_args() 111 | 112 | exp_config = OmegaConf.load(args.exp_config) 113 | data_config = OmegaConf.load(args.data_config) 114 | cli_config = OmegaConf.from_cli( 115 | left_argv 116 | ) # config from cli in the form of key=value 117 | 118 | config = OmegaConf.unsafe_merge(exp_config, data_config, cli_config) 119 | config.exp_version = ( 120 | datetime.datetime.now().strftime("%Y%m%d_%H%M%S") 121 | if not config.exp_version 122 | else config.exp_version 123 | ) 124 | # Load bbox tokens into config 125 | bbox_special_tokens = [ 126 | f"" for i in range(max(config.input_size.values()) + 1) 127 | ] 128 | config.bbox_special_tokens = bbox_special_tokens 129 | 130 | OmegaConf.resolve(config) 131 | for sanity_config in ["result_path", "exp_name", "exp_version"]: 132 | assert ( 133 | config.get(sanity_config, None) is not None 134 | ), f"{sanity_config} is not set in config." 135 | 136 | save_config_file( 137 | config, Path(config.result_path) / config.exp_name / config.exp_version 138 | ) 139 | 140 | # print config on console 141 | for k, v in config.items(): 142 | if k != "bbox_special_tokens": 143 | print(f"{k}: {v}") 144 | else: 145 | print(f"{k}: {len(v)}") 146 | train(config) 147 | --------------------------------------------------------------------------------
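
A note on configuration: train.py builds its final config by merging the experiment YAML (--exp_config), the data YAML (--data_config), and any trailing key=value arguments picked up by OmegaConf.from_cli, with later sources taking precedence; after the merge, result_path, exp_name, and exp_version must all be set (exp_version falls back to a timestamp). Below is a minimal, self-contained sketch of that merge behaviour -- the keys are config keys referenced by train.py, but the values stand in for the actual files under config/exp_configs and are purely illustrative:

```python
# Sketch of the OmegaConf merge performed in train.py's __main__ block.
# The two dicts stand in for the experiment and data YAML files; the dotlist
# mimics trailing "key=value" overrides passed on the command line.
from omegaconf import OmegaConf

exp_config = OmegaConf.create({"exp_name": "tflop_pubtabnet", "max_epochs": 30})
data_config = OmegaConf.create({"dataset_script_path": "tflop.datamodule.datasets.tflop"})
cli_config = OmegaConf.from_dotlist(["max_epochs=50", "result_path=./results"])

config = OmegaConf.unsafe_merge(exp_config, data_config, cli_config)
print(config.max_epochs)   # 50 -- the command-line override wins
print(config.result_path)  # ./results
```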
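
A note on resuming: CustomCheckpointIO.load_checkpoint treats resume_from_checkpoint_path as a directory path (with a trailing separator) containing artifacts.ckpt (the trainer/checkpoint state without model weights) and pytorch_model.bin (the bare model state_dict), and it re-attaches those weights under a "model." prefix, presumably matching the attribute name used inside TFLOPModelPLModule. A small sketch of that key re-mapping, using a dummy linear layer in place of the real model:

```python
# Sketch of the "model." prefix re-mapping done in CustomCheckpointIO.load_checkpoint.
# A tiny linear layer stands in for the TFLOP model; in practice the two dicts are
# loaded from <resume_dir>/pytorch_model.bin and <resume_dir>/artifacts.ckpt.
import torch

weights = torch.nn.Linear(4, 2).state_dict()      # what pytorch_model.bin would hold
checkpoint = {"epoch": 3, "global_step": 1000}    # trainer state, minus the weights
checkpoint["state_dict"] = {"model." + k: v for k, v in weights.items()}
print(sorted(checkpoint["state_dict"]))           # ['model.bias', 'model.weight']
```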