├── .gitignore ├── CONFIG.mk ├── LICENSE ├── Makefile ├── README.md ├── configs ├── dataset │ ├── augmentation │ │ ├── beit.yaml │ │ ├── resize_normalize.yaml │ │ └── vqvae.yaml │ ├── concat_dataset.yaml │ ├── fintabnet │ │ ├── test_dataset.yaml │ │ ├── train_dataset.yaml │ │ └── valid_dataset.yaml │ ├── icdar │ │ ├── test_dataset.yaml │ │ ├── train_dataset.yaml │ │ └── valid_dataset.yaml │ ├── mini_pubtabnet │ │ ├── test_dataset.yaml │ │ ├── train_dataset.yaml │ │ └── valid_dataset.yaml │ ├── pubtables1m │ │ ├── test_dataset.yaml │ │ ├── train_dataset.yaml │ │ └── valid_dataset.yaml │ ├── pubtabnet │ │ ├── test_dataset.yaml │ │ ├── train_dataset.yaml │ │ └── valid_dataset.yaml │ ├── single_dataset.yaml │ ├── synthtabnet_fintabnet │ │ ├── test_dataset.yaml │ │ ├── train_dataset.yaml │ │ └── valid_dataset.yaml │ ├── synthtabnet_marketing │ │ ├── test_dataset.yaml │ │ ├── train_dataset.yaml │ │ └── valid_dataset.yaml │ ├── synthtabnet_pubtabnet │ │ ├── test_dataset.yaml │ │ ├── train_dataset.yaml │ │ └── valid_dataset.yaml │ ├── synthtabnet_sparse │ │ ├── test_dataset.yaml │ │ ├── train_dataset.yaml │ │ └── valid_dataset.yaml │ └── tablebank │ │ ├── test_dataset.yaml │ │ ├── train_dataset.yaml │ │ └── valid_dataset.yaml ├── main.yaml ├── model │ ├── beit.yaml │ ├── encoderdecoder.yaml │ ├── model │ │ ├── backbone │ │ │ ├── imgcnn.yaml │ │ │ ├── imgconvstem.yaml │ │ │ └── imglinear.yaml │ │ ├── decoder │ │ │ └── transformer.yaml │ │ └── encoder │ │ │ └── transformer.yaml │ └── vqvae.yaml ├── trainer │ ├── beit.yaml │ ├── table.yaml │ ├── train │ │ ├── lr_scheduler │ │ │ ├── cosine.yaml │ │ │ ├── exponential.yaml │ │ │ └── step.yaml │ │ └── optimizer │ │ │ ├── adam.yaml │ │ │ └── adamw.yaml │ └── vqvae.yaml └── vocab │ ├── bbox.yaml │ ├── cell.yaml │ ├── empty.yaml │ └── html.yaml ├── dataset └── mini_pubtabnet │ ├── mini_pubtabnet_examples.jsonl │ ├── train │ ├── PMC1626454_002_00.png │ ├── PMC2753619_002_00.png │ ├── PMC2759935_007_01.png │ ├── PMC2838834_005_00.png │ ├── PMC3519711_003_00.png │ ├── PMC3826085_003_00.png │ ├── PMC3907710_006_00.png │ ├── PMC4003957_018_00.png │ ├── PMC4172848_007_00.png │ ├── PMC4517499_004_00.png │ ├── PMC4682394_003_00.png │ ├── PMC4776821_005_00.png │ ├── PMC4840965_004_00.png │ ├── PMC5134617_013_00.png │ ├── PMC5198506_004_00.png │ ├── PMC5332562_005_00.png │ ├── PMC5402779_004_00.png │ ├── PMC5577841_001_00.png │ ├── PMC5679144_002_01.png │ └── PMC5897438_004_00.png │ └── val │ ├── PMC1626454_002_00.png │ ├── PMC2753619_002_00.png │ ├── PMC2759935_007_01.png │ ├── PMC2838834_005_00.png │ ├── PMC3519711_003_00.png │ ├── PMC3826085_003_00.png │ ├── PMC3907710_006_00.png │ ├── PMC4003957_018_00.png │ ├── PMC4172848_007_00.png │ ├── PMC4517499_004_00.png │ ├── PMC4682394_003_00.png │ ├── PMC4776821_005_00.png │ ├── PMC4840965_004_00.png │ ├── PMC5134617_013_00.png │ ├── PMC5198506_004_00.png │ ├── PMC5332562_005_00.png │ ├── PMC5402779_004_00.png │ ├── PMC5577841_001_00.png │ ├── PMC5679144_002_01.png │ └── PMC5897438_004_00.png ├── experiments └── .gitignore ├── notebooks ├── .gitignore └── full_pipeline.ipynb ├── requirements.txt ├── setup.py ├── src ├── datamodule │ ├── __init__.py │ ├── augmentation.py │ ├── dataloader.py │ ├── fintabnet.py │ ├── pubtables1m.py │ ├── pubtabnet.py │ ├── synthtabnet.py │ └── tablebank.py ├── main.py ├── model │ ├── __init__.py │ ├── beit.py │ ├── components.py │ ├── encoderdecoder.py │ └── vqvae.py ├── trainer │ ├── __init__.py │ ├── train_beit.py │ ├── train_table.py │ ├── train_vqvae.py │ └── utils.py ├── utils │ ├── 
__init__.py │ ├── coco_map.py │ ├── data.py │ ├── engine.py │ ├── mask_generator.py │ ├── misc.py │ ├── teds.py │ └── visualization.py └── vocab │ ├── .gitignore │ ├── __init__.py │ └── constant.py ├── vocab ├── .gitignore ├── vocab_bbox.json ├── vocab_cell_6k.json └── vocab_html.json └── website ├── unitable-demo.gif ├── unitable-demo.mp4 └── wandb_screenshot.png /.gitignore: -------------------------------------------------------------------------------- 1 | # VS Code 2 | .history 3 | 4 | # make targets & code outputs 5 | .done* 6 | outputs 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | cover/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | .pybuilder/ 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | # For a library or package, you might want to ignore these files since the code is 94 | # intended to run in multiple environments; otherwise, check them in: 95 | # .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # poetry 105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 106 | # This is especially recommended for binary packages to ensure reproducibility, and is more 107 | # commonly ignored for libraries. 108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 109 | #poetry.lock 110 | 111 | # pdm 112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 113 | #pdm.lock 114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 115 | # in version control. 116 | # https://pdm.fming.dev/#use-with-ide 117 | .pdm.toml 118 | 119 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 120 | __pypackages__/ 121 | 122 | # Celery stuff 123 | celerybeat-schedule 124 | celerybeat.pid 125 | 126 | # SageMath parsed files 127 | *.sage.py 128 | 129 | # Environments 130 | .env 131 | .venv 132 | env/ 133 | venv/ 134 | ENV/ 135 | env.bak/ 136 | venv.bak/ 137 | 138 | # Spyder project settings 139 | .spyderproject 140 | .spyproject 141 | 142 | # Rope project settings 143 | .ropeproject 144 | 145 | # mkdocs documentation 146 | /site 147 | 148 | # mypy 149 | .mypy_cache/ 150 | .dmypy.json 151 | dmypy.json 152 | 153 | # Pyre type checker 154 | .pyre/ 155 | 156 | # pytype static type analyzer 157 | .pytype/ 158 | 159 | # Cython debug symbols 160 | cython_debug/ 161 | 162 | # PyCharm 163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 165 | # and can be added to the global gitignore or merged into this file. For a more nuclear 166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 167 | #.idea/ 168 | -------------------------------------------------------------------------------- /CONFIG.mk: -------------------------------------------------------------------------------- 1 | ################################################## 2 | # Configurations # 3 | ################################################## 4 | 5 | # 6 | # Datasets 7 | # 8 | 9 | # label type 10 | LABEL_IMAGE = ++trainer.label_type="image" 11 | LABEL_HTML = ++trainer.label_type="html" "++trainer.train.loss_weights.html=1" 12 | LABEL_CELL = ++trainer.label_type="cell" "++trainer.train.loss_weights.cell=1" 13 | LABEL_BBOX = ++trainer.label_type="bbox" "++trainer.train.loss_weights.bbox=1" 14 | MEAN = [0.86597056,0.88463002,0.87491087] 15 | STD = [0.20686628,0.18201602,0.18485524] 16 | 17 | # augmentation 18 | AUG_VQVAE = dataset/augmentation=vqvae 19 | AUG_BEIT = dataset/augmentation=beit \ 20 | ++dataset.augmentation.mean=$(MEAN) ++dataset.augmentation.std=$(STD) 21 | AUG_RESIZE_NORM = dataset/augmentation=resize_normalize \ 22 | ++dataset.augmentation.transforms.2.mean=$(MEAN) ++dataset.augmentation.transforms.2.std=$(STD) 23 | 24 | # single dataset 25 | DATA_SINGLE = dataset=single_dataset 26 | PUBTABNET = $(DATA_SINGLE) \ 27 | +dataset/pubtabnet@dataset.train_dataset=train_dataset \ 28 | +dataset/pubtabnet@dataset.valid_dataset=valid_dataset \ 29 | +dataset/pubtabnet@dataset.test_dataset=test_dataset 30 | MINIPUBTABNET = $(DATA_SINGLE) \ 31 | +dataset/mini_pubtabnet@dataset.train_dataset=train_dataset \ 32 | +dataset/mini_pubtabnet@dataset.valid_dataset=valid_dataset \ 33 | +dataset/mini_pubtabnet@dataset.test_dataset=test_dataset 34 | 35 | # multiple datasets 36 | DATA_MULTI = dataset=concat_dataset 37 | PUBTABNET_M = +dataset/pubtabnet@dataset.train.d1=train_dataset \ 38 | +dataset/pubtabnet@dataset.valid.d1=valid_dataset \ 39 | +dataset/pubtabnet@dataset.test.d1=test_dataset 40 | SYN_MARKET_M = +dataset/synthtabnet_marketing@dataset.train.d2=train_dataset \ 41 | +dataset/synthtabnet_marketing@dataset.valid.d2=valid_dataset \ 42 | +dataset/synthtabnet_marketing@dataset.test.d2=test_dataset 43 | SYN_FIN_M = +dataset/synthtabnet_fintabnet@dataset.train.d3=train_dataset \ 44 | +dataset/synthtabnet_fintabnet@dataset.valid.d3=valid_dataset \ 45 | +dataset/synthtabnet_fintabnet@dataset.test.d3=test_dataset 46 | SYN_SPARSE_M = +dataset/synthtabnet_sparse@dataset.train.d4=train_dataset \ 47 | 
+dataset/synthtabnet_sparse@dataset.valid.d4=valid_dataset \ 48 | +dataset/synthtabnet_sparse@dataset.test.d4=test_dataset 49 | SYN_PUB_M = +dataset/synthtabnet_pubtabnet@dataset.train.d5=train_dataset \ 50 | +dataset/synthtabnet_pubtabnet@dataset.valid.d5=valid_dataset \ 51 | +dataset/synthtabnet_pubtabnet@dataset.test.d5=test_dataset 52 | PUBTABLES_M = +dataset/pubtables1m@dataset.train.d7=train_dataset \ 53 | +dataset/pubtables1m@dataset.valid.d7=valid_dataset \ 54 | +dataset/pubtables1m@dataset.test.d7=test_dataset 55 | TABLEBANK_M = +dataset/tablebank@dataset.train.d8=train_dataset \ 56 | +dataset/tablebank@dataset.valid.d8=valid_dataset \ 57 | +dataset/tablebank@dataset.test.d8=test_dataset 58 | FINTABNET_M = +dataset/fintabnet@dataset.train.d9=train_dataset \ 59 | +dataset/fintabnet@dataset.valid.d9=valid_dataset \ 60 | +dataset/fintabnet@dataset.test.d9=test_dataset 61 | 62 | DATA_VQVAE_1M = $(DATA_MULTI) \ 63 | $(PUBTABNET_M) $(SYN_MARKET_M) $(SYN_FIN_M) $(SYN_SPARSE_M) 64 | DATA_VQVAE_2M = $(DATA_MULTI) \ 65 | $(PUBTABNET_M) $(SYN_MARKET_M) $(SYN_FIN_M) $(SYN_SPARSE_M) $(SYN_PUB_M) \ 66 | $(PUBTABLES_M) $(TABLEBANK_M) 67 | 68 | PUBTABLES1M = $(DATA_MULTI) $(PUBTABLES_M) 69 | FINTABNET = $(DATA_MULTI) $(FINTABNET_M) 70 | 71 | PUB_SYN = $(DATA_MULTI) \ 72 | $(PUBTABNET_M) $(SYN_MARKET_M) $(SYN_FIN_M) $(SYN_SPARSE_M) $(SYN_PUB_M) 73 | 74 | PUB_SYN_FIN = $(DATA_MULTI) $(PUBTABNET_M) $(FINTABNET_M) \ 75 | $(SYN_MARKET_M) $(SYN_FIN_M) $(SYN_SPARSE_M) $(SYN_PUB_M) 76 | 77 | PUB_SYN_PUB1M = $(DATA_MULTI) $(PUBTABNET_M) $(PUBTABLES_M) \ 78 | $(SYN_MARKET_M) $(SYN_FIN_M) $(SYN_SPARSE_M) $(SYN_PUB_M) 79 | 80 | SYN = $(DATA_MULTI) $(SYN_MARKET_M) $(SYN_FIN_M) $(SYN_SPARSE_M) $(SYN_PUB_M) 81 | 82 | SYN_fin = $(DATA_MULTI) $(SYN_FIN_M) 83 | SYN_market = $(DATA_MULTI) $(SYN_MARKET_M) 84 | SYN_pub = $(DATA_MULTI) $(SYN_PUB_M) 85 | SYN_sparse = $(DATA_MULTI) $(SYN_SPARSE_M) 86 | 87 | # 88 | # Vocab 89 | # 90 | VOCAB_NONE = vocab=empty 91 | VOCAB_HTML = vocab=html 92 | VOCAB_BBOX = vocab=bbox 93 | VOCAB_CELL = vocab=cell 94 | 95 | 96 | # 97 | # Trainer 98 | # 99 | 100 | # trainer type 101 | TRAINER_VQVAE = trainer=vqvae 102 | TRAINER_BEIT = trainer=beit 103 | TRAINER_TABLE = trainer=table 104 | 105 | # input image size 106 | I224 = ++trainer.img_size=[224,224] 107 | I448 = ++trainer.img_size=[448,448] 108 | I112_448 = ++trainer.img_size=[112,448] 109 | 110 | # max sequence length 111 | SEQ200 = trainer.max_seq_len=200 112 | SEQ512 = trainer.max_seq_len=512 113 | SEQ1024 = trainer.max_seq_len=1024 114 | 115 | # batch size + epoch 116 | BATCH24 = ++trainer.train.dataloader.batch_size=24 ++trainer.valid.dataloader.batch_size=24 117 | BATCH48 = ++trainer.train.dataloader.batch_size=48 ++trainer.valid.dataloader.batch_size=48 118 | BATCH72 = ++trainer.train.dataloader.batch_size=72 ++trainer.valid.dataloader.batch_size=72 119 | BATCH80 = ++trainer.train.dataloader.batch_size=80 ++trainer.valid.dataloader.batch_size=80 120 | BATCH96 = ++trainer.train.dataloader.batch_size=96 ++trainer.valid.dataloader.batch_size=96 121 | BATCH256 = ++trainer.train.dataloader.batch_size=256 ++trainer.valid.dataloader.batch_size=256 122 | BATCH384 = ++trainer.train.dataloader.batch_size=384 ++trainer.valid.dataloader.batch_size=384 123 | 124 | EPOCH24 = ++trainer.train.epochs=24 125 | EPOCH30 = ++trainer.train.epochs=30 126 | EPOCH48 = ++trainer.train.epochs=48 127 | 128 | # optimizer 129 | OPT_ADAMW = trainer/train/optimizer=adamw 130 | OPT_WD5e2 = ++trainer.train.optimizer.weight_decay=5e-2 131 | 132 | # lr + scheduler 
133 | LR_5e4 = ++trainer.train.optimizer.lr=5e-4 134 | LR_3e4 = ++trainer.train.optimizer.lr=3e-4 135 | LR_1e4 = ++trainer.train.optimizer.lr=1e-4 136 | LR_8e5 = ++trainer.train.optimizer.lr=8e-5 137 | 138 | LR_cosine = trainer/train/lr_scheduler=cosine ++trainer.train.lr_scheduler.lr_lambda.min_ratio=5e-3 139 | LR_cosine93k_warm6k = $(LR_cosine) ++trainer.train.lr_scheduler.lr_lambda.total_step=93400 ++trainer.train.lr_scheduler.lr_lambda.warmup=5800 140 | LR_cosine77k_warm8k = $(LR_cosine) ++trainer.train.lr_scheduler.lr_lambda.total_step=76600 ++trainer.train.lr_scheduler.lr_lambda.warmup=7660 141 | LR_cosine30k_warm4k = $(LR_cosine) ++trainer.train.lr_scheduler.lr_lambda.total_step=30500 ++trainer.train.lr_scheduler.lr_lambda.warmup=4000 142 | LR_cosine8k_warm1k = $(LR_cosine) ++trainer.train.lr_scheduler.lr_lambda.total_step=7600 ++trainer.train.lr_scheduler.lr_lambda.warmup=800 143 | LR_cosine44k_warm6k = $(LR_cosine) ++trainer.train.lr_scheduler.lr_lambda.total_step=44100 ++trainer.train.lr_scheduler.lr_lambda.warmup=5500 144 | LR_cosine118k_warm15k = $(LR_cosine) ++trainer.train.lr_scheduler.lr_lambda.total_step=117800 ++trainer.train.lr_scheduler.lr_lambda.warmup=14700 145 | LR_cosine216k_warm27k = $(LR_cosine) ++trainer.train.lr_scheduler.lr_lambda.total_step=216000 ++trainer.train.lr_scheduler.lr_lambda.warmup=27000 146 | LR_cosine32k = $(LR_cosine) ++trainer.train.lr_scheduler.lr_lambda.total_step=32000 ++trainer.train.lr_scheduler.lr_lambda.warmup=0 147 | LR_cosine118k = $(LR_cosine) ++trainer.train.lr_scheduler.lr_lambda.total_step=118000 ++trainer.train.lr_scheduler.lr_lambda.warmup=0 148 | 149 | GRAD_CLIP12 = ++trainer.train.grad_clip=12 150 | 151 | # vqvae 152 | VQVAE_TEMP_1M = ++trainer.train.starting_temp=1. \ 153 | ++trainer.train.temp_min=5e-3 ++trainer.train.temp_anneal_rate=1e-3 154 | VQVAE_TEMP_2M = ++trainer.train.starting_temp=1. 
\ 155 | ++trainer.train.temp_min=1e-3 ++trainer.train.temp_anneal_rate=2e-4 156 | 157 | # pretraining specific 158 | TRANS448_VQVAE224_GRID28_MASK300 = ++trainer.trans_size=[448,448] ++trainer.vqvae_size=[224,224] ++trainer.grid_size=28 ++trainer.num_mask_patches=300 159 | VQVAE1M_WEIGHTS = $(MODEL_VQVAE) ++trainer.vqvae_weights="../unitable_weights/vqvae_1m.pt" 160 | VQVAE2M_WEIGHTS = $(MODEL_VQVAE_L) ++trainer.vqvae_weights="../unitable_weights/vqvae_2m.pt" 161 | 162 | # finetuning specific 163 | WEIGHTS_mtim_1m_base = ++trainer.trainer.beit_pretrained_weights="../unitable_weights/ssp_1m_base.pt" 164 | WEIGHTS_mtim_1m_large = ++trainer.trainer.beit_pretrained_weights="../unitable_weights/ssp_1m_large.pt" 165 | WEIGHTS_mtim_2m_base = ++trainer.trainer.beit_pretrained_weights="../unitable_weights/ssp_2m_base.pt" 166 | WEIGHTS_mtim_2m_large = ++trainer.trainer.beit_pretrained_weights="../unitable_weights/ssp_2m_large.pt" 167 | LOCK_MTIM_4 = ++trainer.trainer.freeze_beit_epoch=4 168 | 169 | # 170 | # Models 171 | # 172 | 173 | # model type 174 | MODEL_VQVAE = model=vqvae 175 | MODEL_VQVAE_L = $(MODEL_VQVAE) ++model.codebook_tokens=16384 ++model.hidden_dim=512 176 | MODEL_BEIT = model=beit 177 | MODEL_ENCODER_DECODER = model=encoderdecoder 178 | 179 | # backbone for input preprocessing: resnet, linear projection, and convstem 180 | IMGCNN = model/model/backbone=imgcnn 181 | IMGLINEAR = model/model/backbone=imglinear 182 | IMGCONVSTEM = model/model/backbone=imgconvstem 183 | 184 | # number of layers 185 | E4 = ++model.model.encoder.nlayer=4 186 | E12 = ++model.model.encoder.nlayer=12 187 | E24 = ++model.model.encoder.nlayer=24 188 | D4 = ++model.model.decoder.nlayer=4 189 | 190 | # transformer layer: attention heads, hidden size, activation, norm 191 | FF4 = ++model.ff_ratio=4 192 | 193 | NHEAD8 = ++model.nhead=8 194 | NHEAD12 = ++model.nhead=12 195 | 196 | NORM_FIRST = ++model.norm_first=true 197 | NORM_LAST = ++model.norm_first=false 198 | 199 | ACT_RELU = ++model.activation="relu" 200 | ACT_GELU = ++model.activation="gelu" 201 | 202 | D_MODEL512 = ++model.d_model=512 203 | D_MODEL768 = ++model.d_model=768 204 | 205 | # regularization 206 | REG_d00 = ++model.dropout=0.0 207 | REG_d02 = ++model.dropout=0.2 208 | 209 | # linear projection patch size 210 | P16 = ++model.backbone_downsampling_factor=16 211 | P28 = ++model.backbone_downsampling_factor=28 212 | P32 = ++model.backbone_downsampling_factor=32 213 | 214 | # cnn backbone 215 | R18 = ++model.model.backbone.backbone._target_=torchvision.models.resnet18 \ 216 | ++model.model.backbone.output_channels=512 217 | 218 | MTIM_BASE = $(MODEL_BEIT) $(IMGLINEAR) $(NHEAD8) $(FF4) $(ACT_GELU) \ 219 | $(NORM_FIRST) $(D_MODEL512) $(REG_d02) $(P16) $(E4) 220 | MTIM_LARGE = $(MODEL_BEIT) $(IMGLINEAR) $(NHEAD12) $(FF4) $(ACT_GELU) \ 221 | $(NORM_FIRST) $(D_MODEL768) $(REG_d02) $(P16) $(E12) 222 | 223 | ARCH_BASE = $(MTIM_BASE) $(MODEL_ENCODER_DECODER) $(D4) 224 | ARCH_LARGE = $(MTIM_LARGE) $(MODEL_ENCODER_DECODER) $(D4) 225 | 226 | 227 | ############################################### 228 | # Experiments # 229 | ############################################### 230 | 231 | TRAIN_vqvae := $(VOCAB_NONE) \ 232 | $(LABEL_IMAGE) $(AUG_VQVAE) $(I224) \ 233 | $(TRAINER_VQVAE) $(OPT_ADAMW) $(LR_1e4) $(EPOCH24) 234 | 235 | TRAIN_mtim := $(VOCAB_NONE) \ 236 | $(LABEL_IMAGE) $(AUG_BEIT) \ 237 | $(TRAINER_BEIT) $(OPT_ADAMW) $(OPT_WD5e2) $(LR_5e4) \ 238 | $(TRANS448_VQVAE224_GRID28_MASK300) 239 | 240 | # 241 | # mini_pubtabnet pretraining example (dataset code: mini) 
242 | # 243 | 244 | # vq-vae 245 | # > make experiments/vqvae_mini/.done_pretrain 246 | EXP_vqvae_mini := $(TRAIN_vqvae) $(MINIPUBTABNET) $(VQVAE_TEMP_2M) $(BATCH80) $(MODEL_VQVAE) $(LR_cosine32k) 247 | 248 | # visual encoder pretraining - masked tabular image modeling (MTIM) 249 | # > make experiments/mtim_mini_base/.done_pretrain 250 | EXP_mtim_mini_base := $(TRAIN_mtim) $(MINIPUBTABNET) $(VQVAE2M_WEIGHTS) $(MTIM_BASE) \ 251 | $(BATCH384) $(LR_cosine8k_warm1k) $(EPOCH24) 252 | 253 | # 254 | # mini_pubtabnet finetuning example 255 | # 256 | 257 | # table structure (task code: html) 258 | # > make experiments/ssp_2m_mini_html_base/.done_finetune 259 | TRAIN_mini_html := $(VOCAB_HTML) \ 260 | $(MINIPUBTABNET) $(LABEL_HTML) $(AUG_RESIZE_NORM) \ 261 | $(TRAINER_TABLE) $(I448) $(SEQ512) \ 262 | $(EPOCH48) $(OPT_ADAMW) $(OPT_WD5e2) $(LR_8e5) 263 | 264 | EXP_ssp_2m_mini_html_base := $(TRAIN_mini_html) $(ARCH_BASE) \ 265 | $(WEIGHTS_mtim_2m_base) $(LOCK_MTIM_4) $(BATCH72) $(LR_cosine93k_warm6k) 266 | 267 | # table cell bbox (task code: bbox) 268 | # > make experiments/ssp_2m_mini_bbox_base/.done_finetune 269 | TRAIN_mini_bbox := $(VOCAB_BBOX) \ 270 | $(MINIPUBTABNET) $(LABEL_BBOX) $(AUG_RESIZE_NORM) \ 271 | $(TRAINER_TABLE) $(I448) $(SEQ1024) \ 272 | $(EPOCH30) $(OPT_ADAMW) $(OPT_WD5e2) $(LR_3e4) $(GRAD_CLIP12) 273 | 274 | EXP_ssp_2m_mini_bbox_base := $(TRAIN_mini_bbox) $(ARCH_BASE) \ 275 | $(WEIGHTS_mtim_2m_base) $(LOCK_MTIM_4) $(BATCH48) $(LR_cosine77k_warm8k) 276 | 277 | # table cell content (task code: cell) 278 | # > make experiments/ssp_2m_mini_cell_base/.done_finetune 279 | TRAIN_mini_cell := $(VOCAB_CELL) \ 280 | $(MINIPUBTABNET) $(LABEL_CELL) $(AUG_RESIZE_NORM) \ 281 | $(TRAINER_TABLE) $(I112_448) $(SEQ200) \ 282 | $(EPOCH24) $(OPT_ADAMW) $(OPT_WD5e2) $(LR_8e5) $(GRAD_CLIP12) 283 | 284 | EXP_ssp_2m_mini_cell_base := $(TRAIN_mini_cell) $(ARCH_BASE) \ 285 | $(WEIGHTS_mtim_2m_base) $(LOCK_MTIM_4) $(BATCH24) $(LR_cosine216k_warm27k) 286 | 287 | # 288 | # cross-dataset pretraining 289 | # 290 | 291 | # vq-vae 292 | EXP_vqvae_1M := $(TRAIN_vqvae) $(DATA_VQVAE_1M) $(VQVAE_TEMP_1M) $(BATCH80) $(MODEL_VQVAE) $(LR_cosine32k) 293 | EXP_vqvae_2M := $(TRAIN_vqvae) $(DATA_VQVAE_2M) $(VQVAE_TEMP_2M) $(BATCH48) $(MODEL_VQVAE_L) $(LR_cosine118k) 294 | 295 | # visual encoder pretraining 296 | EXP_mtim_1M_base := $(TRAIN_mtim) $(PUB_SYN) $(VQVAE1M_WEIGHTS) $(MTIM_BASE) \ 297 | $(BATCH384) $(LR_cosine8k_warm1k) $(EPOCH24) 298 | EXP_mtim_1M_large := $(TRAIN_mtim) $(PUB_SYN) $(VQVAE1M_WEIGHTS) $(MTIM_LARGE) \ 299 | $(BATCH96) $(LR_cosine30k_warm4k) $(EPOCH24) 300 | EXP_mtim_2M_base := $(TRAIN_mtim) $(DATA_VQVAE_2M) $(VQVAE2M_WEIGHTS) $(MTIM_BASE) \ 301 | $(BATCH256) $(LR_cosine44k_warm6k) $(EPOCH48) 302 | EXP_mtim_2M_large := $(TRAIN_mtim) $(DATA_VQVAE_2M) $(VQVAE2M_WEIGHTS) $(MTIM_LARGE) \ 303 | $(BATCH96) $(LR_cosine118k_warm15k) $(EPOCH48) 304 | 305 | # 306 | # cross-dataset finetuning 307 | # 308 | 309 | # table structure 310 | # > make experiments/ssp_2m_syn_pub_html_large/.done_finetune 311 | TRAIN_syn_pub_html := $(VOCAB_HTML) \ 312 | $(PUB_SYN) $(LABEL_HTML) $(AUG_RESIZE_NORM) \ 313 | $(TRAINER_TABLE) $(I448) $(SEQ512) \ 314 | $(EPOCH48) $(OPT_ADAMW) $(OPT_WD5e2) $(LR_8e5) 315 | 316 | EXP_ssp_2m_syn_pub_html_large := $(TRAIN_syn_pub_html) $(ARCH_LARGE) \ 317 | $(WEIGHTS_mtim_2m_large) $(LOCK_MTIM_4) $(BATCH72) $(LR_cosine93k_warm6k) 318 | 319 | # table cell bbox 320 | # > make experiments/ssp_2m_syn_pub_bbox_large/.done_finetune 321 | TRAIN_syn_pub_bbox := $(VOCAB_BBOX) \ 322 | $(PUB_SYN) 
$(LABEL_BBOX) $(AUG_RESIZE_NORM) \ 323 | $(TRAINER_TABLE) $(I448) $(SEQ1024) \ 324 | $(EPOCH30) $(OPT_ADAMW) $(OPT_WD5e2) $(LR_3e4) $(GRAD_CLIP12) 325 | 326 | EXP_ssp_2m_syn_pub_bbox_large := $(TRAIN_syn_pub_bbox) $(ARCH_LARGE) \ 327 | $(WEIGHTS_mtim_2m_large) $(LOCK_MTIM_4) $(BATCH48) $(LR_cosine77k_warm8k) 328 | 329 | # table cell content 330 | # > make experiments/ssp_2m_syn_pub_pub1m_cell_large/.done_finetune 331 | TRAIN_syn_pub_pub1m_cell := $(VOCAB_CELL) \ 332 | $(PUB_SYN_PUB1M) $(LABEL_CELL) $(AUG_RESIZE_NORM) \ 333 | $(TRAINER_TABLE) $(I112_448) $(SEQ200) \ 334 | $(EPOCH24) $(OPT_ADAMW) $(OPT_WD5e2) $(LR_8e5) $(GRAD_CLIP12) 335 | 336 | EXP_ssp_2m_syn_pub_pub1m_cell_large := $(TRAIN_syn_pub_pub1m_cell) $(ARCH_LARGE) \ 337 | $(WEIGHTS_mtim_2m_base) $(LOCK_MTIM_4) $(BATCH24) $(LR_cosine216k_warm27k) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 ShengYun (Anthony) Peng. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SHELL := /bin/bash 2 | VENV_NAME := unitable 3 | CONDA_ACTIVATE := source $$(conda info --base)/etc/profile.d/conda.sh && conda activate $(VENV_NAME) 4 | PYTHON := $(CONDA_ACTIVATE) && python 5 | PIP := $(CONDA_ACTIVATE) && pip3 6 | # Stacked single-node multi-worker: https://pytorch.org/docs/stable/elastic/run.html#stacked-single-node-multi-worker 7 | TORCHRUN = $(CONDA_ACTIVATE) && torchrun --rdzv-backend=c10d --rdzv_endpoint localhost:0 --nnodes=1 --nproc_per_node=$(NGPU) 8 | 9 | # Taken from https://tech.davis-hansson.com/p/make/ 10 | ifeq ($(origin .RECIPEPREFIX), undefined) 11 | $(error This Make does not support .RECIPEPREFIX. Please use GNU Make 4.0 or later) 12 | endif 13 | .RECIPEPREFIX = > 14 | 15 | # 16 | # Virtual Environment Targets 17 | # 18 | clean: 19 | > rm -f .done_venv 20 | 21 | .done_venv: clean 22 | > conda create -n $(VENV_NAME) python=3.9 -y 23 | > $(PIP) install -r requirements.txt 24 | > $(PIP) install -e . 
25 | > touch $@ 26 | 27 | # 28 | # Download pretrained and UniTable model weights 29 | # 30 | WEIGHTS_PATH = experiments/unitable_weights 31 | M_VQVAE_1M = $(WEIGHTS_PATH)/vqvae_1m.pt 32 | M_VQVAE_2M = $(WEIGHTS_PATH)/vqvae_2m.pt 33 | M_SSP_1M_BASE = $(WEIGHTS_PATH)/ssp_1m_base.pt 34 | M_SSP_1M_LARGE = $(WEIGHTS_PATH)/ssp_1m_large.pt 35 | M_SSP_2M_BASE = $(WEIGHTS_PATH)/ssp_2m_base.pt 36 | M_SSP_2M_LARGE = $(WEIGHTS_PATH)/ssp_2m_large.pt 37 | UNITABLE_HTML = $(WEIGHTS_PATH)/unitable_large_structure.pt 38 | UNITABLE_BBOX = $(WEIGHTS_PATH)/unitable_large_bbox.pt 39 | UNITABLE_CELL = $(WEIGHTS_PATH)/unitable_large_content.pt 40 | 41 | .done_download_weights: 42 | ifeq ("$(words $(wildcard $(WEIGHTS_PATH)/*.pt))", "9") 43 | > $(info All 9 model weights have already been downloaded to $(WEIGHTS_PATH).) 44 | else 45 | > $(info There should be 9 weights file under $(WEIGHTS_PATH), but only $(words $(wildcard $(WEIGHTS_PATH)/*.pt)) are found.) 46 | > $(info Begin downloading weights from HuggingFace ...) 47 | > wget -c https://huggingface.co/poloclub/UniTable/resolve/main/vqvae_1m.pt -P $(WEIGHTS_PATH) 48 | > wget -c https://huggingface.co/poloclub/UniTable/resolve/main/vqvae_2m.pt -P $(WEIGHTS_PATH) 49 | > wget -c https://huggingface.co/poloclub/UniTable/resolve/main/ssp_1m_base.pt -P $(WEIGHTS_PATH) 50 | > wget -c https://huggingface.co/poloclub/UniTable/resolve/main/ssp_1m_large.pt -P $(WEIGHTS_PATH) 51 | > wget -c https://huggingface.co/poloclub/UniTable/resolve/main/ssp_2m_base.pt -P $(WEIGHTS_PATH) 52 | > wget -c https://huggingface.co/poloclub/UniTable/resolve/main/ssp_2m_large.pt -P $(WEIGHTS_PATH) 53 | > wget -c https://huggingface.co/poloclub/UniTable/resolve/main/unitable_large_structure.pt -P $(WEIGHTS_PATH) 54 | > wget -c https://huggingface.co/poloclub/UniTable/resolve/main/unitable_large_bbox.pt -P $(WEIGHTS_PATH) 55 | > wget -c https://huggingface.co/poloclub/UniTable/resolve/main/unitable_large_content.pt -P $(WEIGHTS_PATH) 56 | > $(info Completed!) 57 | endif 58 | 59 | # 60 | # Python Targets 61 | # 62 | include CONFIG.mk 63 | SRC := src 64 | BEST_MODEL = "../$(word 1,$(subst -, ,$*))/model/best.pt" 65 | RESULT_JSON := html.json 66 | TEDS_STRUCTURE = -f "../experiments/$*/$(RESULT_JSON)" -s 67 | 68 | ###################### 69 | NGPU := 1 # number of gpus used in the experiments 70 | 71 | .SECONDARY: 72 | 73 | # vq-vae and self-supervised pretraining 74 | experiments/%/.done_pretrain: 75 | > @echo "Using experiment configurations from variable EXP_$*" 76 | > cd $(SRC) && $(TORCHRUN) -m main ++name=$* $(EXP_$*) ++trainer.mode="train" 77 | > touch $@ 78 | 79 | # finetuning from SSP weights for table structure, cell bbox and cell content 80 | experiments/%/.done_finetune: 81 | > @echo "Finetuning phase 1 - using experiment configurations from variable EXP_$*" 82 | > cd $(SRC) && $(TORCHRUN) -m main ++name=$* $(EXP_$*) ++trainer.mode="train" 83 | > @echo "Finetuning phase 2 - starting from epoch 4" 84 | > cd $(SRC) && $(TORCHRUN) -m main ++name=$* $(EXP_$*) ++trainer.mode="train" ++trainer.trainer.snapshot="epoch3_snapshot.pt" ++trainer.trainer.beit_pretrained_weights=null 85 | > touch $@ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UniTable: Towards a Unified Table Foundation Model 2 | 3 |

![Demo](./website/unitable-demo.gif)

4 | 5 | 1. 📈 [High-Performance Transformers for Table Structure Recognition Need Early Convolutions](https://arxiv.org/abs/2311.05565). ShengYun Peng, Seongmin Lee, Xiaojing Wang, Rajarajeswari Balasubramaniyan, Duen Horng Chau. In *NeurIPS Second Table Representation Learning Workshop*, 2023. (Oral) 6 | 2. 🚀 [Self-Supervised Pretraining for Table Structure Recognition Transformer](https://arxiv.org/abs/2402.15578). ShengYun Peng, Seongmin Lee, Xiaojing Wang, Rajarajeswari Balasubramaniyan, Duen Horng Chau. In *AAAI Scientific Document Understanding Workshop*, 2024. (Oral) 7 | 3. 🆕 [UniTable: Towards a Unified Framework for Table Structure Recognition via Self-Supervised Pretraining](https://arxiv.org/abs/2403.04822). ShengYun Peng, Seongmin Lee, Xiaojing Wang, Rajarajeswari Balasubramaniyan, Duen Horng Chau. ArXiv, 2024. 8 | 9 | Tables convey factual and quantitative data with implicit conventions created by humans that are often challenging for machines to parse. Prior work on table recognition (TR) has mainly centered around complex task-specific combinations of available inputs and tools. We present UniTable, a training framework that unifies training paradigm, training objective, and model architecture of TR. 10 | Its training paradigm combines the simplicity of purely pixel-level inputs with the effectiveness and scalability empowered by self-supervised pretraining (SSP) from diverse unannotated tabular images. Our framework unifies the training of all three TR tasks — extracting table structure, cell content, and cell bounding box (bbox) — into a unified task-agnostic training objective: language modeling. Extensive quantitative and qualitative analyses highlight UniTable’s state-of-the-art (SOTA) performance on four of the largest TR datasets. To promote reproducible research, enhance transparency, and SOTA innovations, we have released the first-of-its-kind [Jupyter Notebook](./notebooks/full_pipeline.ipynb) of the entire inference pipeline, fine-tuned across multiple TR datasets, supporting all three TR tasks. 11 | 12 | 13 | > This repo includes code for linear projection Transformers. For convolutional stem (early convolution) Transformers, please check out our [tsr-convstem repo](https://github.com/poloclub/tsr-convstem). 14 | 15 | # News 16 | `Apr. 2024` - You can fully digitalize your own tabular image in our [Jupyter Notebook](./notebooks/full_pipeline.ipynb). 17 | 18 | `Apr. 2024` - UniTable v1.0.0 is now online with model weights available at [HuggingFace](https://huggingface.co/poloclub/UniTable/tree/main). 19 | 20 | `Feb. 2024` - We presented "Self-Supervised Pretraining" paper at AAAI'24. 21 | 22 | `Jan. 2024` - "Self-Supervised Pretraining" paper was selected as [oral](https://sites.google.com/view/sdu-aaai24/schedule?authuser=0). 23 | 24 | `Dec. 2023` - "Self-Supervised Pretraining" paper was accepted by [AAAI'24 Scientific Document Understanding Workshop](https://sites.google.com/view/sdu-aaai24/schedule?authuser=0). 25 | 26 | `Dec. 2023` - We presented "Early Convolutions" paper at [NeurIPS'23](https://x.com/RealAnthonyPeng/status/1735715161476866135?s=20). 27 | 28 | `Oct. 2023` - "Early Convolutions" paper was selected as [oral](https://table-representation-learning.github.io/#accepted-papers). 29 | 30 | `Oct. 2023` - "Early Convolutions" paper was accepted by [NeurIPS'23 Table Representation Learning Workshop](https://table-representation-learning.github.io/). 31 | 32 | # Quick Start 33 | 1. 
Set up the virtual environment (unitable) by running `make .done_venv` in your terminal. 34 | 2. Download all the model weights from [HuggingFace](https://huggingface.co/poloclub/UniTable/tree/main) by running `make .done_download_weights` in your terminal. 35 | 3. Try out our demo [Jupyter Notebook](./notebooks/full_pipeline.ipynb) with your own tabular image! Remember to select "unitable" as your notebook kernel. 36 | 37 | # Training 38 | Our code is driven by [Makefile targets](https://www.gnu.org/software/make/manual/make.html) and configured by [Hydra](https://hydra.cc/docs/intro/). Experiment names are defined as variables prefixed with `EXP_` in [CONFIG.mk Sec. Experiments](CONFIG.mk). The comment above each experiment also shows how to launch its make target. 39 | ## Dataset annotation format 40 | We provide a tiny portion (20 samples) of PubTabNet as an example for a quick walk-through of the training process. The dataset (tabular images and annotations) is available at [dataset/mini_pubtabnet](./dataset/mini_pubtabnet/). The annotations for all images are in [mini_pubtabnet_examples.jsonl](./dataset/mini_pubtabnet/mini_pubtabnet_examples.jsonl). Each line is a `json` object that corresponds to a `png` image with the following structure: 41 | 42 | ```python 43 | "filename": "tabular image filename inside one of the 'train', 'val', or 'test' folders", 44 | "split": "One of 'train', 'val', or 'test'", 45 | "html": "table structure, cell content, and cell bbox", 46 | "cells": "Array with all cell content and bbox", 47 | "tokens": "Array with the content of the cell", 48 | "bbox": "The bounding box of the cell in [x1, y1, x2, y2] format. Only provided when the cell is non-empty.", 49 | "structure": "Dict with table structure", 50 | "tokens": "Array with html tags that describe the table structure. '[]' represents a non-empty cell", 51 | ``` 52 | 53 | If you want to train on your own dataset, please convert it to the provided format (a minimal sketch for loading these annotations is shown below, after the Weights & Biases screenshot). 54 | The five datasets we used in the paper are [PubTabNet](https://github.com/ibm-aur-nlp/PubTabNet), [SynthTabNet](https://github.com/IBM/SynthTabNet), [FinTabNet](https://developer.ibm.com/exchanges/data/all/fintabnet/), [ICDAR 2019 B2 Modern](https://github.com/cndplab-founder/ICDAR2019_cTDaR), and [PubTables-1M](https://huggingface.co/datasets/bsmock/pubtables-1m). 55 | After downloading these datasets, please update the `root_dir` for each dataset under [configs/dataset/](./configs/dataset/). 56 | 57 | ## Tracking your training progress 58 | Please register a [Weights & Biases account](https://wandb.ai/site) if you want to visualize training curves and reconstructed tables (for pretraining the VQ-VAE only). An example of tables reconstructed by the VQ-VAE: 59 | 60 |

![wandb](./website/wandb_screenshot.png)
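
For a quick sanity check of the annotation format described in the Dataset annotation format section above, the snippet below reads one record from the provided mini-PubTabNet annotation file. This is a minimal sketch rather than part of the training code; it only assumes the keys listed in the format description (`filename`, `split`, `html`, `cells`, `structure`).

```python
import json
from pathlib import Path

# The 20-sample annotation file shipped with this repo (one json object per line).
ann_file = Path("dataset/mini_pubtabnet/mini_pubtabnet_examples.jsonl")

with ann_file.open() as f:
    sample = json.loads(next(f))  # parse the first record

print(sample["filename"], sample["split"])
# html tags that describe the table structure ('[]' marks a non-empty cell)
print(sample["html"]["structure"]["tokens"][:12])
# first cell: content tokens, plus a bbox when the cell is non-empty
print(sample["html"]["cells"][0])
```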

61 | 62 | 63 | ## Finetuning 64 | We present finetuning on the provided mini-PubTabNet. For more details on cross-dataset finetuning, please check [CONFIG.mk](CONFIG.mk). 65 | ```bash 66 | # table structure 67 | make experiments/ssp_2m_mini_html_base/.done_finetune 68 | 69 | # cell bbox 70 | make experiments/ssp_2m_mini_bbox_base/.done_finetune 71 | 72 | # cell content 73 | make experiments/ssp_2m_mini_cell_base/.done_finetune 74 | ``` 75 | 76 | ## Pretraining 77 | We present training the VQ-VAE and pretraining the visual encoder on the provided mini-PubTabNet. For more details on cross-dataset pretraining, please check [CONFIG.mk](CONFIG.mk). 78 | 79 | ### VQ-VAE 80 | ```bash 81 | make experiments/vqvae_mini/.done_pretrain 82 | ``` 83 | 84 | ### SSP visual encoder - Masked tabular image modeling (MTIM) 85 | ```bash 86 | make experiments/mtim_mini_base/.done_pretrain 87 | ``` 88 | 89 | ## Multi-GPU 90 | The default setting is a single GPU, i.e., `NGPU := 1` in [Makefile](Makefile). To enable multi-GPU training, please launch the above commands in the following format: `CUDA_VISIBLE_DEVICES=0,1,2,3 make NGPU=4 experiments/<experiment_name>/.done_<pretrain|finetune>`. 91 | 92 | ## Citation 93 | ```bibtex 94 | @article{peng2024unitable, 95 | title={UniTable: Towards a Unified Framework for Table Structure Recognition via Self-Supervised Pretraining}, 96 | author={Peng, ShengYun and Lee, Seongmin and Wang, Xiaojing and Balasubramaniyan, Rajarajeswari and Chau, Duen Horng}, 97 | journal={arXiv preprint}, 98 | year={2024} 99 | } 100 | 101 | @article{peng2024self, 102 | title={Self-Supervised Pretraining for Table Structure Recognition Transformer}, 103 | author={Peng, ShengYun and Lee, Seongmin and Wang, Xiaojing and Balasubramaniyan, Rajarajeswari and Chau, Duen Horng}, 104 | journal={arXiv preprint}, 105 | year={2024} 106 | } 107 | 108 | @inproceedings{peng2023high, 109 | title={High-Performance Transformers for Table Structure Recognition Need Early Convolutions}, 110 | author={Peng, Anthony and Lee, Seongmin and Wang, Xiaojing and Balasubramaniyan, Rajarajeswari Raji and Chau, Duen Horng}, 111 | booktitle={NeurIPS 2023 Second Table Representation Learning Workshop}, 112 | year={2023} 113 | } 114 | ``` 115 | ## Contact 116 | If you have any questions, feel free to contact [Anthony Peng](https://shengyun-peng.github.io/) (CS PhD @Georgia Tech). 
117 | 118 | -------------------------------------------------------------------------------- /configs/dataset/augmentation/beit.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodule.augmentation.AugmentationForMIM 2 | mean: [0.86597056, 0.88463002, 0.87491087] 3 | std: [0.20686628, 0.18201602, 0.18485524] 4 | trans_size: ${trainer.trans_size} 5 | vqvae_size: ${trainer.vqvae_size} 6 | trans_interpolation: bicubic 7 | vqvae_interpolation: lanczos -------------------------------------------------------------------------------- /configs/dataset/augmentation/resize_normalize.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchvision.transforms.Compose 2 | transforms: 3 | - _target_: torchvision.transforms.Resize 4 | size: ${trainer.img_size} 5 | - _target_: torchvision.transforms.ToTensor 6 | - _target_: torchvision.transforms.Normalize 7 | mean: [0.86597056, 0.88463002, 0.87491087] 8 | std: [0.20686628, 0.18201602, 0.18485524] 9 | 10 | -------------------------------------------------------------------------------- /configs/dataset/augmentation/vqvae.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchvision.transforms.Compose 2 | transforms: 3 | - _target_: torchvision.transforms.Resize 4 | size: ${trainer.img_size} 5 | - _target_: torchvision.transforms.ToTensor -------------------------------------------------------------------------------- /configs/dataset/concat_dataset.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - augmentation: beit 4 | # - pubtabnet@train.d1: train_dataset 5 | # - pubtabnet@valid.d1: valid_dataset 6 | # - synthtabnet_marketing@train.d2: train_dataset 7 | # - synthtabnet_marketing@valid.d2: valid_dataset 8 | # - synthtabnet_fintabnet@train.d3: train_dataset 9 | # - synthtabnet_fintabnet@valid.d3: valid_dataset 10 | # - synthtabnet_sparse@train.d4: train_dataset 11 | # - synthtabnet_sparse@valid.d4: valid_dataset 12 | # - synthtabnet_pubtabnet@train.d5: train_dataset 13 | # - synthtabnet_pubtabnet@valid.d5: valid_dataset 14 | 15 | 16 | label_type: ${trainer.label_type} 17 | cell_limit: 10 18 | 19 | train_dataset: 20 | _target_: torch.utils.data.ConcatDataset 21 | datasets: ${oc.dict.values:..train} 22 | 23 | valid_dataset: 24 | _target_: torch.utils.data.ConcatDataset 25 | datasets: ${oc.dict.values:..valid} 26 | 27 | test_dataset: 28 | _target_: torch.utils.data.ConcatDataset 29 | datasets: ${oc.dict.values:..test} -------------------------------------------------------------------------------- /configs/dataset/fintabnet/test_dataset.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - valid_dataset 3 | 4 | jsonl_filename: clean_FinTabNet_1.0.0_cell_test.jsonl -------------------------------------------------------------------------------- /configs/dataset/fintabnet/train_dataset.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodule.FinTabNet 2 | root_dir: ../../../../DATASETS/finTabNet 3 | label_type: ${dataset.label_type} 4 | jsonl_filename: clean_FinTabNet_1.0.0_cell_train.jsonl 5 | transform: ${dataset.augmentation} -------------------------------------------------------------------------------- /configs/dataset/fintabnet/valid_dataset.yaml: 
-------------------------------------------------------------------------------- 1 | defaults: 2 | - train_dataset 3 | 4 | jsonl_filename: clean_FinTabNet_1.0.0_cell_val.jsonl -------------------------------------------------------------------------------- /configs/dataset/icdar/test_dataset.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - valid_dataset 3 | 4 | split: test -------------------------------------------------------------------------------- /configs/dataset/icdar/train_dataset.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodule.ICDAR 2 | root_dir: ../../../../DATASETS/ICDAR-2013 3 | label_type: ${dataset.label_type} 4 | split: train 5 | transform: ${dataset.augmentation} -------------------------------------------------------------------------------- /configs/dataset/icdar/valid_dataset.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - train_dataset 3 | 4 | split: val -------------------------------------------------------------------------------- /configs/dataset/mini_pubtabnet/test_dataset.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - valid_dataset 3 | 4 | cell_limit: 256 -------------------------------------------------------------------------------- /configs/dataset/mini_pubtabnet/train_dataset.yaml: -------------------------------------------------------------------------------- 1 | 2 | _target_: src.datamodule.pubtabnet.PubTabNet 3 | root_dir: ../../dataset/mini_pubtabnet 4 | label_type: ${dataset.label_type} 5 | split: train 6 | json_html: mini_pubtabnet_examples.jsonl 7 | transform: ${dataset.augmentation} 8 | cell_limit: 150 -------------------------------------------------------------------------------- /configs/dataset/mini_pubtabnet/valid_dataset.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - train_dataset -------------------------------------------------------------------------------- /configs/dataset/pubtables1m/test_dataset.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - valid_dataset 3 | 4 | split: test -------------------------------------------------------------------------------- /configs/dataset/pubtables1m/train_dataset.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodule.PubTables 2 | root_dir: ../../../../DATASETS/pubtables1m/PubTables-1M-Structure 3 | label_type: ${dataset.label_type} 4 | split: train 5 | transform: ${dataset.augmentation} 6 | cell_limit: ${dataset.cell_limit} -------------------------------------------------------------------------------- /configs/dataset/pubtables1m/valid_dataset.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - train_dataset 3 | 4 | split: val -------------------------------------------------------------------------------- /configs/dataset/pubtabnet/test_dataset.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - valid_dataset 3 | 4 | cell_limit: 256 -------------------------------------------------------------------------------- /configs/dataset/pubtabnet/train_dataset.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodule.PubTabNet 
2 | root_dir: ../../../../DATASETS/pubtabnet 3 | label_type: ${dataset.label_type} 4 | split: train 5 | json_html: clean_html_PubTabNet_2.0.0.jsonl 6 | transform: ${dataset.augmentation} 7 | cell_limit: ${dataset.cell_limit} -------------------------------------------------------------------------------- /configs/dataset/pubtabnet/valid_dataset.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - train_dataset 3 | 4 | split: val -------------------------------------------------------------------------------- /configs/dataset/single_dataset.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - augmentation: beit 4 | # - pubtabnet@train_dataset: train_dataset 5 | # - pubtabnet@valid_dataset: valid_dataset 6 | # - pubtabnet@test_dataset: test_dataset 7 | 8 | label_type: ${trainer.label_type} 9 | cell_limit: 10 -------------------------------------------------------------------------------- /configs/dataset/synthtabnet_fintabnet/test_dataset.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - train_dataset 3 | 4 | split: test -------------------------------------------------------------------------------- /configs/dataset/synthtabnet_fintabnet/train_dataset.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodule.Synthtabnet 2 | root_dir: ../../../../DATASETS/synthtabnet/fintabnet 3 | label_type: ${dataset.label_type} 4 | split: train 5 | json_html: clean_html_synthetic_data.jsonl 6 | transform: ${dataset.augmentation} 7 | cell_limit: ${dataset.cell_limit} -------------------------------------------------------------------------------- /configs/dataset/synthtabnet_fintabnet/valid_dataset.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - train_dataset 3 | 4 | split: val -------------------------------------------------------------------------------- /configs/dataset/synthtabnet_marketing/test_dataset.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - train_dataset 3 | 4 | split: test -------------------------------------------------------------------------------- /configs/dataset/synthtabnet_marketing/train_dataset.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodule.Synthtabnet 2 | root_dir: ../../../../DATASETS/synthtabnet/marketing 3 | label_type: ${dataset.label_type} 4 | split: train 5 | json_html: clean_html_synthetic_data.jsonl 6 | transform: ${dataset.augmentation} 7 | cell_limit: ${dataset.cell_limit} -------------------------------------------------------------------------------- /configs/dataset/synthtabnet_marketing/valid_dataset.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - train_dataset 3 | 4 | split: val -------------------------------------------------------------------------------- /configs/dataset/synthtabnet_pubtabnet/test_dataset.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - train_dataset 3 | 4 | split: test -------------------------------------------------------------------------------- /configs/dataset/synthtabnet_pubtabnet/train_dataset.yaml: -------------------------------------------------------------------------------- 1 | _target_: 
src.datamodule.Synthtabnet 2 | root_dir: ../../../../DATASETS/synthtabnet/pubtabnet 3 | label_type: ${dataset.label_type} 4 | split: train 5 | json_html: clean_html_synthetic_data.jsonl 6 | transform: ${dataset.augmentation} 7 | cell_limit: ${dataset.cell_limit} -------------------------------------------------------------------------------- /configs/dataset/synthtabnet_pubtabnet/valid_dataset.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - train_dataset 3 | 4 | split: val -------------------------------------------------------------------------------- /configs/dataset/synthtabnet_sparse/test_dataset.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - train_dataset 3 | 4 | split: test -------------------------------------------------------------------------------- /configs/dataset/synthtabnet_sparse/train_dataset.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodule.Synthtabnet 2 | root_dir: ../../../../DATASETS/synthtabnet/sparse 3 | label_type: ${dataset.label_type} 4 | split: train 5 | json_html: clean_html_synthetic_data.jsonl 6 | transform: ${dataset.augmentation} 7 | cell_limit: ${dataset.cell_limit} -------------------------------------------------------------------------------- /configs/dataset/synthtabnet_sparse/valid_dataset.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - train_dataset 3 | 4 | split: val -------------------------------------------------------------------------------- /configs/dataset/tablebank/test_dataset.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - valid_dataset 3 | 4 | split: test -------------------------------------------------------------------------------- /configs/dataset/tablebank/train_dataset.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodule.TableBank 2 | root_dir: ../../../../DATASETS/tablebank/Recognition 3 | label_type: ${dataset.label_type} 4 | split: train 5 | transform: ${dataset.augmentation} 6 | -------------------------------------------------------------------------------- /configs/dataset/tablebank/valid_dataset.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - train_dataset 3 | 4 | split: val -------------------------------------------------------------------------------- /configs/main.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - dataset: mini_pubtabnet 4 | - model: encoderdecoder 5 | - trainer: table 6 | - vocab: html 7 | - override hydra/job_logging: colorlog 8 | - override hydra/hydra_logging: colorlog 9 | 10 | 11 | hydra: 12 | run: 13 | dir: ../experiments/${name} 14 | sweep: 15 | dir: ../experiments/${name} 16 | job: 17 | name: ${name} 18 | chdir: true 19 | 20 | wandb: 21 | project: UniTable 22 | 23 | seed: 1234 -------------------------------------------------------------------------------- /configs/model/beit.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - model/backbone: imglinear 4 | - model/encoder: transformer 5 | 6 | nhead: 12 7 | ff_ratio: 4 8 | activation: gelu 9 | norm_first: true 10 | d_model: 768 11 | dropout: 0.0 12 | backbone_downsampling_factor: 16 13 | 14 | 
codebook_tokens: 8192 15 | hidden_dim: 256 16 | 17 | model: 18 | _target_: src.model.beit.BeitEncoder 19 | d_model: ${model.d_model} 20 | codebook_tokens: ${model.codebook_tokens} 21 | dropout: ${model.dropout} 22 | norm_layer: 23 | _partial_: true 24 | _target_: torch.nn.LayerNorm 25 | eps: 1e-6 26 | 27 | model_vqvae: 28 | _target_: src.model.vqvae.DiscreteVAE 29 | image_size: ${trainer.vqvae_size} 30 | codebook_tokens: ${model.codebook_tokens} 31 | codebook_dim: 512 32 | num_layers: 3 33 | hidden_dim: ${model.hidden_dim} 34 | smooth_l1_loss: false 35 | kl_div_loss_weight: 0.0 -------------------------------------------------------------------------------- /configs/model/encoderdecoder.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - model/backbone: imgcnn 4 | - model/encoder: transformer 5 | - model/decoder: transformer 6 | 7 | 8 | nhead: 4 9 | ff_ratio: 2 10 | activation: relu 11 | norm_first: false 12 | d_model: 512 13 | dropout: 0.5 14 | backbone_downsampling_factor: 16 15 | 16 | 17 | model: 18 | _target_: src.model.EncoderDecoder 19 | vocab_size: -1 20 | d_model: ${model.d_model} 21 | padding_idx: -1 22 | max_seq_len: ${trainer.max_seq_len} 23 | dropout: ${model.dropout} 24 | norm_layer: 25 | _partial_: true 26 | _target_: torch.nn.LayerNorm 27 | eps: 1e-6 28 | 29 | 30 | -------------------------------------------------------------------------------- /configs/model/model/backbone/imgcnn.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.model.components.ImgCnnBackbone 2 | backbone: 3 | _target_: torchvision.models.resnet18 4 | output_channels: 512 5 | d_model: ${model.d_model} 6 | drop_layer: 7 | - 3 8 | - 8 9 | - 9 -------------------------------------------------------------------------------- /configs/model/model/backbone/imgconvstem.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.model.components.ImgConvStemBackbone 2 | d_model: ${model.d_model} 3 | downsample_factor: ${model.backbone_downsampling_factor} 4 | output_channels: 192 5 | kernel_size: 3 -------------------------------------------------------------------------------- /configs/model/model/backbone/imglinear.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.model.components.ImgLinearBackbone 2 | d_model: ${model.d_model} 3 | patch_size: ${model.backbone_downsampling_factor} -------------------------------------------------------------------------------- /configs/model/model/decoder/transformer.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.model.components.Decoder 2 | d_model: ${model.d_model} 3 | nhead: ${model.nhead} 4 | dropout: ${model.dropout} 5 | activation: ${model.activation} 6 | norm_first: ${model.norm_first} 7 | nlayer: 4 8 | ff_ratio: ${model.ff_ratio} -------------------------------------------------------------------------------- /configs/model/model/encoder/transformer.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.model.components.Encoder 2 | d_model: ${model.d_model} 3 | nhead: ${model.nhead} 4 | dropout: ${model.dropout} 5 | activation: ${model.activation} 6 | norm_first: ${model.norm_first} 7 | nlayer: 2 8 | ff_ratio: ${model.ff_ratio} -------------------------------------------------------------------------------- /configs/model/vqvae.yaml: 
-------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | codebook_tokens: 8192 5 | hidden_dim: 256 6 | 7 | model: 8 | _target_: src.model.vqvae.DiscreteVAE 9 | image_size: ${trainer.img_size} 10 | codebook_tokens: ${model.codebook_tokens} 11 | codebook_dim: 512 12 | num_layers: 3 13 | hidden_dim: ${model.hidden_dim} 14 | smooth_l1_loss: false 15 | kl_div_loss_weight: 0.0 -------------------------------------------------------------------------------- /configs/trainer/beit.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - train/lr_scheduler: exponential 4 | - train/optimizer: adam 5 | 6 | mode: train 7 | trans_size: 448 8 | vqvae_size: 224 9 | grid_size: 28 10 | num_mask_patches: 300 11 | min_num_patches: 16 12 | max_seq_len: null 13 | 14 | vqvae_weights: null 15 | 16 | train: 17 | epochs: 20 18 | grad_clip: 5 19 | save_every: 3 20 | dataloader: 21 | _target_: src.datamodule.dataloader.dataloader_beit 22 | batch_size: 48 23 | grid_size: ${trainer.grid_size} 24 | num_mask_patches: ${trainer.num_mask_patches} 25 | min_num_patches: ${trainer.min_num_patches} 26 | valid: 27 | dataloader: 28 | _target_: src.datamodule.dataloader.dataloader_beit 29 | batch_size: 48 30 | grid_size: ${trainer.grid_size} 31 | num_mask_patches: ${trainer.num_mask_patches} 32 | min_num_patches: ${trainer.min_num_patches} 33 | test: 34 | metrics: null 35 | 36 | 37 | trainer: 38 | _target_: src.trainer.BeitTrainer 39 | snapshot: null 40 | model_weights: null -------------------------------------------------------------------------------- /configs/trainer/table.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - train/lr_scheduler: step 4 | - train/optimizer: adam 5 | 6 | 7 | mode: train 8 | img_size: [448,448] 9 | max_seq_len: 512 10 | label_type: html+cell+bbox 11 | 12 | train: 13 | target: ${trainer.label_type} 14 | img_size: ${trainer.img_size} 15 | loss_weights: 16 | table: 0 17 | html: 0 18 | cell: 0 19 | bbox: 0 20 | grad_clip: 5 21 | epochs: 24 22 | save_every: 1 23 | max_seq_len: ${trainer.max_seq_len} 24 | dataloader: 25 | _target_: src.datamodule.dataloader_html 26 | batch_size: 48 27 | label_type: ${trainer.label_type} 28 | valid: 29 | target: ${trainer.label_type} 30 | img_size: ${trainer.img_size} 31 | loss_weights: ${trainer.train.loss_weights} 32 | max_seq_len: ${trainer.max_seq_len} 33 | dataloader: 34 | _target_: src.datamodule.dataloader_html 35 | batch_size: 48 36 | label_type: ${trainer.label_type} 37 | test: 38 | target: ${trainer.train.target} 39 | img_size: ${trainer.img_size} 40 | loss_weights: ${trainer.train.loss_weights} 41 | metrics: teds 42 | max_seq_len: ${trainer.max_seq_len} 43 | sampling: greedy 44 | save_to_prefix: html_table_result 45 | dataloader: 46 | _target_: src.datamodule.dataloader_html 47 | batch_size: 96 48 | label_type: ${trainer.label_type} 49 | 50 | 51 | trainer: 52 | _target_: src.trainer.TableTrainer 53 | snapshot: null 54 | model_weights: null 55 | beit_pretrained_weights: null 56 | freeze_beit_epoch: null -------------------------------------------------------------------------------- /configs/trainer/train/lr_scheduler/cosine.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.lr_scheduler.LambdaLR 2 | lr_lambda: 3 | _partial_: true 4 | _target_: src.utils.cosine_schedule_with_warmup 5 | warmup: 6 6 | min_ratio: 
5e-3 7 | total_step: ${trainer.train.epochs} 8 | -------------------------------------------------------------------------------- /configs/trainer/train/lr_scheduler/exponential.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.lr_scheduler.ExponentialLR 2 | gamma: 0.98 -------------------------------------------------------------------------------- /configs/trainer/train/lr_scheduler/step.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.lr_scheduler.StepLR 2 | step_size: 12 3 | gamma: 0.1 -------------------------------------------------------------------------------- /configs/trainer/train/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.Adam 2 | lr: 1e-4 3 | weight_decay: 1e-4 -------------------------------------------------------------------------------- /configs/trainer/train/optimizer/adamw.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.AdamW 2 | lr: 1e-4 3 | betas: [0.9, 0.999] 4 | weight_decay: 1e-4 -------------------------------------------------------------------------------- /configs/trainer/vqvae.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - train/lr_scheduler: exponential 4 | - train/optimizer: adam 5 | 6 | mode: train 7 | img_size: [256,256] 8 | label_type: image 9 | max_seq_len: null 10 | 11 | train: 12 | epochs: 20 13 | grad_clip: 0.2 14 | starting_temp: 1. 15 | temp_min: 0.06 16 | temp_anneal_rate: 1e-6 17 | save_every: 3 18 | dataloader: 19 | _target_: src.datamodule.dataloader.dataloader_vae 20 | batch_size: 48 21 | valid: 22 | dataloader: 23 | _target_: src.datamodule.dataloader.dataloader_vae 24 | batch_size: 48 25 | test: 26 | metrics: null 27 | 28 | 29 | trainer: 30 | _target_: src.trainer.VqvaeTrainer 31 | snapshot: null 32 | model_weights: null 33 | 34 | -------------------------------------------------------------------------------- /configs/vocab/bbox.yaml: -------------------------------------------------------------------------------- 1 | need_vocab: true 2 | type: html 3 | dir: ${hydra:runtime.cwd}/../vocab/vocab_bbox.json -------------------------------------------------------------------------------- /configs/vocab/cell.yaml: -------------------------------------------------------------------------------- 1 | need_vocab: true 2 | type: cell 3 | dir: ${hydra:runtime.cwd}/../vocab/vocab_cell_6k.json -------------------------------------------------------------------------------- /configs/vocab/empty.yaml: -------------------------------------------------------------------------------- 1 | need_vocab: false -------------------------------------------------------------------------------- /configs/vocab/html.yaml: -------------------------------------------------------------------------------- 1 | need_vocab: true 2 | type: html 3 | dir: ${hydra:runtime.cwd}/../vocab/vocab_html.json -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/train/PMC1626454_002_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC1626454_002_00.png -------------------------------------------------------------------------------- 
/dataset/mini_pubtabnet/train/PMC2753619_002_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC2753619_002_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/train/PMC2759935_007_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC2759935_007_01.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/train/PMC2838834_005_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC2838834_005_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/train/PMC3519711_003_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC3519711_003_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/train/PMC3826085_003_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC3826085_003_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/train/PMC3907710_006_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC3907710_006_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/train/PMC4003957_018_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC4003957_018_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/train/PMC4172848_007_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC4172848_007_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/train/PMC4517499_004_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC4517499_004_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/train/PMC4682394_003_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC4682394_003_00.png 
-------------------------------------------------------------------------------- /dataset/mini_pubtabnet/train/PMC4776821_005_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC4776821_005_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/train/PMC4840965_004_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC4840965_004_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/train/PMC5134617_013_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC5134617_013_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/train/PMC5198506_004_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC5198506_004_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/train/PMC5332562_005_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC5332562_005_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/train/PMC5402779_004_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC5402779_004_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/train/PMC5577841_001_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC5577841_001_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/train/PMC5679144_002_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC5679144_002_01.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/train/PMC5897438_004_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC5897438_004_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/val/PMC1626454_002_00.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC1626454_002_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/val/PMC2753619_002_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC2753619_002_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/val/PMC2759935_007_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC2759935_007_01.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/val/PMC2838834_005_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC2838834_005_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/val/PMC3519711_003_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC3519711_003_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/val/PMC3826085_003_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC3826085_003_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/val/PMC3907710_006_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC3907710_006_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/val/PMC4003957_018_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC4003957_018_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/val/PMC4172848_007_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC4172848_007_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/val/PMC4517499_004_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC4517499_004_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/val/PMC4682394_003_00.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC4682394_003_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/val/PMC4776821_005_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC4776821_005_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/val/PMC4840965_004_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC4840965_004_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/val/PMC5134617_013_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC5134617_013_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/val/PMC5198506_004_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC5198506_004_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/val/PMC5332562_005_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC5332562_005_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/val/PMC5402779_004_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC5402779_004_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/val/PMC5577841_001_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC5577841_001_00.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/val/PMC5679144_002_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC5679144_002_01.png -------------------------------------------------------------------------------- /dataset/mini_pubtabnet/val/PMC5897438_004_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC5897438_004_00.png -------------------------------------------------------------------------------- /experiments/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 
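Note (illustrative, not part of the repository): the Hydra configs listed above describe optimizers, schedulers, models, and dataloaders purely through `_target_` paths, and the training code resolves them at runtime. The following minimal sketch shows how a fragment equivalent to configs/trainer/train/optimizer/adam.yaml and configs/trainer/train/lr_scheduler/exponential.yaml is typically turned into live objects with hydra.utils.instantiate; the toy nn.Linear model and the inline config dict are assumptions added only for demonstration.

import torch
from omegaconf import OmegaConf
from hydra.utils import instantiate

# Inline stand-in for the YAML files shown earlier; values copied from adam.yaml / exponential.yaml.
cfg = OmegaConf.create(
    {
        "optimizer": {
            "_target_": "torch.optim.Adam",
            "lr": 1e-4,
            "weight_decay": 1e-4,
        },
        "lr_scheduler": {
            "_target_": "torch.optim.lr_scheduler.ExponentialLR",
            "gamma": 0.98,
        },
    }
)

# Hypothetical tiny model standing in for the real table-recognition model.
model = torch.nn.Linear(8, 2)

# instantiate() imports the class named by _target_ and passes the remaining keys
# as keyword arguments; extra kwargs (params, optimizer) are supplied at call time.
optimizer = instantiate(cfg.optimizer, params=model.parameters())
lr_scheduler = instantiate(cfg.lr_scheduler, optimizer=optimizer)

print(type(optimizer).__name__, type(lr_scheduler).__name__)  # Adam ExponentialLR

The same mechanism applies to the model and dataloader configs above: a `_partial_: true` entry (as in the norm_layer blocks) yields a functools.partial instead of an instance, so the surrounding code can finish constructing the object later.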
-------------------------------------------------------------------------------- /notebooks/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | !*.ipynb -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | torchaudio 4 | torchtext 5 | jsonlines 6 | beautifulsoup4 7 | matplotlib 8 | hydra-core 9 | hydra_colorlog 10 | apted 11 | Distance 12 | lxml==4.9.3 13 | torchmetrics 14 | wandb 15 | einops 16 | ptflops 17 | tokenizers 18 | pycocotools 19 | torchmetrics 20 | faster-coco-eval 21 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages 2 | from setuptools import setup 3 | 4 | setup(name="unitable", version="1.0.0", packages=find_packages()) -------------------------------------------------------------------------------- /src/datamodule/__init__.py: -------------------------------------------------------------------------------- 1 | from .pubtabnet import PubTabNet 2 | from .synthtabnet import Synthtabnet 3 | from .dataloader import dataloader_vae, dataloader_beit, dataloader_html 4 | from .pubtables1m import PubTables 5 | from .tablebank import TableBank 6 | from .fintabnet import FinTabNet 7 | -------------------------------------------------------------------------------- /src/datamodule/augmentation.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Any, Optional, Union 2 | from torch import Tensor 3 | import random 4 | from PIL import Image 5 | import torchvision.transforms.functional as F 6 | from torchvision import datasets, transforms 7 | 8 | from torchvision.transforms.transforms import _setup_size 9 | 10 | 11 | _PIL_INTERPOLATION = { 12 | "bilinear": Image.BILINEAR, 13 | "bicubic": Image.BICUBIC, 14 | "lanczos": Image.LANCZOS, 15 | "hamming": Image.HAMMING, 16 | } 17 | 18 | get_interpolation = lambda method: _PIL_INTERPOLATION.get(method, Image.BILINEAR) 19 | 20 | 21 | class RandomResizedCropAndInterpolationWithTwoPic(transforms.RandomResizedCrop): 22 | """Ensure both crops of vqvae and visual encoder have the same scale and size.""" 23 | 24 | def __init__( 25 | self, 26 | size: Union[int, Tuple[int, int]], # transformer 27 | second_size: Union[int, Tuple[int, int]], # vqvae 28 | scale: Tuple[float, float] = (0.08, 1.0), 29 | ratio: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0), 30 | interpolation: str = "bilinear", 31 | second_interpolation: str = "lanczos", 32 | ): 33 | self.second_size = _setup_size( 34 | second_size, 35 | error_msg="Please provide only two dimensions (h, w) for second size.", 36 | ) 37 | 38 | if interpolation == "random": 39 | interpolation = random.choice( 40 | [get_interpolation("bilinear"), get_interpolation("bicubic")] 41 | ) 42 | else: 43 | interpolation = get_interpolation(interpolation) 44 | self.second_interpolation = get_interpolation(second_interpolation) 45 | 46 | super().__init__( 47 | size=size, scale=scale, ratio=ratio, interpolation=interpolation 48 | ) 49 | 50 | def forward(self, img: Image): 51 | i, j, h, w = self.get_params(img, self.scale, self.ratio) 52 | out = F.resized_crop(img, i, j, h, w, self.size, self.interpolation) 53 | out_second = F.resized_crop( 54 | img, i, j, h, w, self.second_size, 
self.second_interpolation 55 | ) 56 | 57 | return out, out_second 58 | 59 | 60 | class AugmentationForMIM(object): 61 | def __init__( 62 | self, 63 | mean: float, 64 | std: float, 65 | trans_size: Union[int, Tuple[int, int]], 66 | vqvae_size: Union[int, Tuple[int, int]], 67 | trans_interpolation: str, 68 | vqvae_interpolation: str, 69 | ) -> None: 70 | self.common_transform = transforms.Compose( 71 | [ 72 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 73 | transforms.RandomHorizontalFlip(p=0.5), 74 | RandomResizedCropAndInterpolationWithTwoPic( 75 | size=trans_size, 76 | second_size=vqvae_size, 77 | interpolation=trans_interpolation, 78 | second_interpolation=vqvae_interpolation, 79 | ), 80 | ] 81 | ) 82 | 83 | self.trans_transform = transforms.Compose( 84 | [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)] 85 | ) 86 | 87 | self.vqvae_transform = transforms.ToTensor() 88 | 89 | def __call__(self, img: Image) -> Tuple[Tensor, Tensor]: 90 | trans_img, vqvae_img = self.common_transform(img) 91 | trans_img = self.trans_transform(trans_img) 92 | vqvae_img = self.vqvae_transform(vqvae_img) 93 | 94 | return trans_img, vqvae_img 95 | 96 | 97 | if __name__ == "__main__": 98 | mean = [240.380, 240.390, 240.486] 99 | std = [45.735, 45.785, 45.756] 100 | 101 | T = RandomResizedCropAndInterpolationWithTwoPic( 102 | size=(256, 256), 103 | second_size=(256, 256), 104 | interpolation="bicubic", 105 | second_interpolation="lanczos", 106 | ) 107 | 108 | print(T) 109 | -------------------------------------------------------------------------------- /src/datamodule/dataloader.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from torch.utils.data import DataLoader, Dataset, Sampler 3 | from functools import partial 4 | import tokenizers as tk 5 | import torch 6 | from torch.utils.data import default_collate 7 | from src.utils.mask_generator import MaskGenerator 8 | from src.utils import ( 9 | prepare_html_seq, 10 | prepare_cell_seq, 11 | prepare_bbox_seq, 12 | ) 13 | 14 | 15 | class Collator: 16 | def __init__( 17 | self, 18 | vocab: tk.Tokenizer, 19 | max_seq_len: int, 20 | label_type: str, 21 | ) -> None: 22 | self.vocab = vocab 23 | self.vocab.enable_truncation(max_seq_len) 24 | self.label_type = label_type 25 | 26 | def __call__(self, batch) -> Any: 27 | return self._collate_batch(batch, self.vocab, self.label_type) 28 | 29 | def _collate_batch( 30 | self, 31 | batch: list[dict], 32 | vocab: tk.Tokenizer, 33 | label_type: str, 34 | ): 35 | if "cell" in label_type: 36 | image_list = [j for i in batch for j in i[0]] 37 | else: 38 | image_list = [i["image"] for i in batch] 39 | image_list = default_collate(image_list) 40 | 41 | if "cell" in label_type: 42 | filename = [(j["filename"], j["bbox_id"]) for i in batch for j in i[1]] 43 | else: 44 | filename = [i["filename"] for i in batch] 45 | label = dict(filename=filename) 46 | 47 | if "html" in label_type: 48 | html_list = ["".join(prepare_html_seq(i["html"])) for i in batch] 49 | label["html"] = vocab.encode_batch(html_list) 50 | 51 | if "cell" in label_type: 52 | cell_list = [ 53 | " ".join(prepare_cell_seq(j["cell"])) for i in batch for j in i[1] 54 | ] 55 | label["cell"] = vocab.encode_batch(cell_list) 56 | 57 | if "bbox" in label_type: 58 | bbox_list = [" ".join(prepare_bbox_seq(i["bbox"])) for i in batch] 59 | label["bbox"] = vocab.encode_batch(bbox_list) 60 | 61 | return image_list, label 62 | 63 | 64 | def generate_mask_for_batch_samples( 65 | 
batch, grid_size: int, num_mask_patches: int, min_num_patches: int 66 | ): 67 | N = len(batch) 68 | mg = MaskGenerator( 69 | input_size=grid_size, 70 | num_mask_patches=num_mask_patches, 71 | min_num_patches=min_num_patches, 72 | ) 73 | mask_list = [mg() for _ in range(N)] 74 | return default_collate(batch), default_collate(mask_list) 75 | 76 | 77 | def dataloader_vae( 78 | dataset: Dataset, batch_size: int, sampler: Sampler = None, **kwargs 79 | ) -> DataLoader: 80 | dataloader = DataLoader( 81 | dataset, batch_size, sampler=sampler, num_workers=8, pin_memory=True 82 | ) 83 | 84 | return dataloader 85 | 86 | 87 | def dataloader_beit( 88 | dataset: Dataset, 89 | grid_size: int, 90 | num_mask_patches: int, 91 | min_num_patches: int, 92 | batch_size: int, 93 | sampler: Sampler = None, 94 | **kwargs 95 | ): 96 | dataloader = DataLoader( 97 | dataset, 98 | batch_size, 99 | sampler=sampler, 100 | collate_fn=partial( 101 | generate_mask_for_batch_samples, 102 | grid_size=grid_size, 103 | num_mask_patches=num_mask_patches, 104 | min_num_patches=min_num_patches, 105 | ), 106 | num_workers=8, 107 | pin_memory=True, 108 | ) 109 | 110 | return dataloader 111 | 112 | 113 | def dataloader_html( 114 | dataset: Dataset, 115 | batch_size: int, 116 | vocab: tk.Tokenizer, 117 | max_seq_len: int, 118 | label_type: str, 119 | sampler=None, 120 | ) -> DataLoader: 121 | collate_fn = Collator(vocab, max_seq_len, label_type) 122 | 123 | dataloader = DataLoader( 124 | dataset, 125 | batch_size=batch_size, 126 | shuffle=False, 127 | num_workers=8, 128 | collate_fn=collate_fn, 129 | pin_memory=True, 130 | sampler=sampler, 131 | ) 132 | 133 | return dataloader 134 | -------------------------------------------------------------------------------- /src/datamodule/fintabnet.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Literal, Union 2 | from pathlib import Path 3 | import jsonlines 4 | from PIL import Image 5 | from torch import Tensor 6 | from torch.utils.data import Dataset 7 | import torchvision.transforms as transforms 8 | 9 | 10 | class FinTabNet(Dataset): 11 | """Load PubTabNet for different training purposes.""" 12 | 13 | def __init__( 14 | self, 15 | root_dir: Union[Path, str], 16 | label_type: Literal["image", "html", "cell", "bbox"], 17 | transform: transforms = None, 18 | jsonl_filename: Union[Path, str] = None, 19 | ) -> None: 20 | super().__init__() 21 | 22 | self.root_dir = Path(root_dir) 23 | self.label_type = label_type 24 | self.transform = transform 25 | 26 | if label_type != "image": 27 | jsonl_file = self.root_dir / jsonl_filename 28 | with jsonlines.open(jsonl_file) as f: 29 | self.image_label_pair = list(f) 30 | 31 | def __len__(self): 32 | return len(self.image_label_pair) 33 | 34 | def __getitem__(self, index: int) -> Any: 35 | if self.label_type == "image": 36 | raise ValueError("FinTabNet is not used in pretraining.") 37 | else: 38 | obj = self.image_label_pair[index] 39 | img_name = f"{obj['table_id']}.png" 40 | img = Image.open(self.root_dir / "image" / img_name) 41 | if self.transform: 42 | img = self.transform(img) 43 | 44 | sample = dict(filename=obj["filename"], image=img) 45 | 46 | if self.label_type == "html": 47 | sample["html"] = obj["html"]["structure"]["tokens"] 48 | return sample 49 | else: 50 | raise ValueError("Task not supported in current dataset.") 51 | -------------------------------------------------------------------------------- /src/datamodule/pubtables1m.py: 
-------------------------------------------------------------------------------- 1 | from typing import Any, Literal, Union 2 | from pathlib import Path 3 | import jsonlines 4 | from PIL import Image 5 | from torch import Tensor 6 | from torch.utils.data import Dataset 7 | import torchvision.transforms as transforms 8 | import numpy as np 9 | import os 10 | import json 11 | 12 | from src.utils import bbox_augmentation_resize 13 | 14 | 15 | class PubTables(Dataset): 16 | """PubTables-1M-Structure""" 17 | 18 | def __init__( 19 | self, 20 | root_dir: Union[Path, str], 21 | label_type: Literal["image", "cell", "bbox"], 22 | split: Literal["train", "val", "test"], 23 | transform: transforms = None, 24 | cell_limit: int = 100, 25 | ) -> None: 26 | super().__init__() 27 | 28 | self.root_dir = Path(root_dir) 29 | self.split = split 30 | self.label_type = label_type 31 | self.transform = transform 32 | self.cell_limit = cell_limit 33 | 34 | tmp = os.listdir(self.root_dir / self.split) 35 | 36 | self.image_list = [i.split(".xml")[0] for i in tmp] 37 | 38 | def __len__(self): 39 | return len(self.image_list) 40 | 41 | def __getitem__(self, index: int) -> Any: 42 | name = self.image_list[index] 43 | img = Image.open(os.path.join(self.root_dir, "images", name + ".jpg")) 44 | 45 | if self.label_type == "image": 46 | if self.transform: 47 | img = self.transform(img) 48 | return img 49 | elif "bbox" in self.label_type: 50 | img_size = img.size 51 | if self.transform: 52 | img = self.transform(img) 53 | tgt_size = img.shape[-1] 54 | with open( 55 | os.path.join(self.root_dir, "words", name + "_words.json"), "r" 56 | ) as f: 57 | obj = json.load(f) 58 | 59 | obj[:] = [ 60 | v 61 | for i in obj 62 | if "bbox" in i.keys() 63 | and all([i["bbox"][w + 2] > i["bbox"][w] for w in range(2)]) 64 | for v in bbox_augmentation_resize( 65 | [ 66 | min(max(i["bbox"][0], 0), img_size[0]), 67 | min(max(i["bbox"][1], 0), img_size[1]), 68 | min(max(i["bbox"][2], 0), img_size[0]), 69 | min(max(i["bbox"][3], 0), img_size[1]), 70 | ], 71 | img_size, 72 | tgt_size, 73 | ) 74 | ] 75 | 76 | sample = {"filename": name, "image": img, "bbox": obj} 77 | return sample 78 | 79 | elif "cell" in self.label_type: 80 | img_size = img.size 81 | with open( 82 | os.path.join(self.root_dir, "words", name + "_words.json"), "r" 83 | ) as f: 84 | obj = json.load(f) 85 | 86 | bboxes_texts = [ 87 | (i["bbox"], i["text"]) 88 | for idx, i in enumerate(obj) 89 | if "bbox" in i 90 | and i["bbox"][0] < i["bbox"][2] 91 | and i["bbox"][1] < i["bbox"][3] 92 | and i["bbox"][0] >= 0 93 | and i["bbox"][1] >= 0 94 | and i["bbox"][2] < img_size[0] 95 | and i["bbox"][3] < img_size[1] 96 | and idx < self.cell_limit 97 | ] 98 | 99 | img_bboxes = [self.transform(img.crop(bbox[0])) for bbox in bboxes_texts] 100 | 101 | text_bboxes = [ 102 | {"filename": name, "bbox_id": i, "cell": j[1]} 103 | for i, j in enumerate(bboxes_texts) 104 | ] 105 | return img_bboxes, text_bboxes 106 | -------------------------------------------------------------------------------- /src/datamodule/pubtabnet.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Literal, Union 2 | from pathlib import Path 3 | from PIL import Image 4 | from torch import Tensor 5 | from torch.utils.data import Dataset 6 | import torchvision.transforms as transforms 7 | import numpy as np 8 | import os 9 | 10 | from src.utils import load_json_annotations, bbox_augmentation_resize 11 | 12 | 13 | # average html annotation length: train: 181.327 149.753 14 
| # samples train: 500777, val: 9115 15 | class PubTabNet(Dataset): 16 | """Load PubTabNet for different training purposes.""" 17 | 18 | def __init__( 19 | self, 20 | root_dir: Union[Path, str], 21 | label_type: Literal["image", "html", "cell", "bbox"], 22 | split: Literal["train", "val"], 23 | transform: transforms = None, 24 | json_html: Union[Path, str] = None, 25 | cell_limit: int = 150, 26 | ) -> None: 27 | super().__init__() 28 | 29 | self.root_dir = Path(root_dir) 30 | self.split = split 31 | self.label_type = label_type 32 | self.transform = transform 33 | self.cell_limit = cell_limit 34 | 35 | self.img_list = os.listdir(self.root_dir / self.split) 36 | 37 | if label_type != "image": 38 | self.image_label_pair = load_json_annotations( 39 | json_file_dir=Path(root_dir) / json_html, split=self.split 40 | ) 41 | 42 | def __len__(self): 43 | return len(self.img_list) 44 | 45 | def __getitem__(self, index: int) -> Any: 46 | if self.label_type == "image": 47 | img = Image.open(self.root_dir / self.split / self.img_list[index]) 48 | if self.transform: 49 | sample = self.transform(img) 50 | return sample 51 | else: 52 | obj = self.image_label_pair[index] 53 | img = Image.open(self.root_dir / self.split / obj[0]) 54 | 55 | if self.label_type == "html": 56 | if self.transform: 57 | img = self.transform(img) 58 | sample = dict( 59 | filename=obj[0], image=img, html=obj[1]["structure"]["tokens"] 60 | ) 61 | return sample 62 | elif self.label_type == "cell": 63 | bboxes_texts = [ 64 | (i["bbox"], "".join(i["tokens"])) 65 | for idx, i in enumerate(obj[1]["cells"]) 66 | if "bbox" in i 67 | and i["bbox"][0] < i["bbox"][2] 68 | and i["bbox"][1] < i["bbox"][3] 69 | and idx < self.cell_limit 70 | ] 71 | 72 | img_bboxes = [ 73 | self.transform(img.crop(bbox[0])) for bbox in bboxes_texts 74 | ] 75 | 76 | text_bboxes = [ 77 | {"filename": obj[0], "bbox_id": i, "cell": j[1]} 78 | for i, j in enumerate(bboxes_texts) 79 | ] 80 | return img_bboxes, text_bboxes 81 | else: 82 | img_size = img.size 83 | if self.transform: 84 | img = self.transform(img) 85 | tgt_size = img.shape[-1] 86 | sample = dict(filename=obj[0], image=img) 87 | 88 | bboxes = [ 89 | entry["bbox"] 90 | for entry in obj[1]["cells"] 91 | if "bbox" in entry 92 | and entry["bbox"][0] < entry["bbox"][2] 93 | and entry["bbox"][1] < entry["bbox"][3] 94 | ] 95 | 96 | bboxes[:] = [ 97 | i 98 | for entry in bboxes 99 | for i in bbox_augmentation_resize(entry, img_size, tgt_size) 100 | ] 101 | 102 | sample["bbox"] = bboxes 103 | 104 | return sample 105 | -------------------------------------------------------------------------------- /src/datamodule/synthtabnet.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Literal, Union 2 | from pathlib import Path 3 | import jsonlines 4 | from PIL import Image 5 | from torch import Tensor 6 | from torch.utils.data import Dataset 7 | import torchvision.transforms as transforms 8 | import numpy as np 9 | import os 10 | 11 | from src.utils import load_json_annotations, bbox_augmentation_resize 12 | 13 | # invalid data pairs: image_000000_1634629424.098128.png has 4 channels 14 | INVALID_DATA = [ 15 | { 16 | "dataset": "fintabnet", 17 | "split": "train", 18 | "image": "image_009379_1634631303.201671.png", 19 | }, 20 | { 21 | "dataset": "marketing", 22 | "split": "train", 23 | "image": "image_000000_1634629424.098128.png", 24 | }, 25 | ] 26 | 27 | 28 | class Synthtabnet(Dataset): 29 | def __init__( 30 | self, 31 | root_dir: Union[Path, str], 32 | 
label_type: Literal["image", "html", "all"], 33 | split: Literal["train", "val", "test"], 34 | transform: transforms = None, 35 | json_html: Union[Path, str] = None, 36 | cell_limit: int = 100, 37 | ) -> None: 38 | super().__init__() 39 | 40 | self.root_dir = Path(root_dir) / "images" 41 | self.split = split 42 | self.label_type = label_type 43 | self.transform = transform 44 | self.cell_limit = cell_limit 45 | 46 | # SSP only needs image 47 | self.img_list = os.listdir(self.root_dir / self.split) 48 | if label_type != "image": 49 | self.image_label_pair = load_json_annotations( 50 | json_file_dir=Path(root_dir) / json_html, split=split 51 | ) 52 | 53 | def __len__(self): 54 | return len(self.img_list) 55 | 56 | def __getitem__(self, index: int) -> Any: 57 | if self.label_type == "image": 58 | img = Image.open(self.root_dir / self.split / self.img_list[index]) 59 | if self.transform: 60 | sample = self.transform(img) 61 | return sample 62 | else: 63 | obj = self.image_label_pair[index] 64 | img = Image.open(self.root_dir / self.split / obj[0]) 65 | 66 | if self.label_type == "html": 67 | if self.transform: 68 | img = self.transform(img) 69 | sample = dict( 70 | filename=obj[0], image=img, html=obj[1]["structure"]["tokens"] 71 | ) 72 | return sample 73 | elif self.label_type == "cell": 74 | bboxes_texts = [ 75 | (i["bbox"], "".join(i["tokens"])) 76 | for idx, i in enumerate(obj[1]["cells"]) 77 | if "bbox" in i 78 | and i["bbox"][0] < i["bbox"][2] 79 | and i["bbox"][1] < i["bbox"][3] 80 | and idx < self.cell_limit 81 | ] 82 | 83 | img_bboxes = [ 84 | self.transform(img.crop(bbox[0])) for bbox in bboxes_texts 85 | ] # you can limit the total cropped cells to lower gpu memory 86 | 87 | text_bboxes = [ 88 | {"filename": obj[0], "bbox_id": i, "cell": j[1]} 89 | for i, j in enumerate(bboxes_texts) 90 | ] 91 | return img_bboxes, text_bboxes 92 | else: 93 | img_size = img.size 94 | if self.transform: 95 | img = self.transform(img) 96 | tgt_size = img.shape[-1] 97 | sample = dict(filename=obj[0], image=img) 98 | 99 | bboxes = [ 100 | entry["bbox"] 101 | for entry in obj[1]["cells"] 102 | if "bbox" in entry 103 | and entry["bbox"][0] < entry["bbox"][2] 104 | and entry["bbox"][1] < entry["bbox"][3] 105 | ] 106 | 107 | bboxes[:] = [ 108 | i 109 | for entry in bboxes 110 | for i in bbox_augmentation_resize(entry, img_size, tgt_size) 111 | ] 112 | 113 | sample["bbox"] = bboxes 114 | 115 | return sample 116 | -------------------------------------------------------------------------------- /src/datamodule/tablebank.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Literal, Union 2 | from pathlib import Path 3 | import jsonlines 4 | from PIL import Image 5 | from torch import Tensor 6 | from torch.utils.data import Dataset 7 | import torchvision.transforms as transforms 8 | import numpy as np 9 | import os 10 | import json 11 | 12 | 13 | class TableBank(Dataset): 14 | """tablebank recognition""" 15 | 16 | def __init__( 17 | self, 18 | root_dir: Union[Path, str], 19 | label_type: Literal["image"], 20 | split: Literal["train", "val", "test"], 21 | transform: transforms = None, 22 | ) -> None: 23 | super().__init__() 24 | 25 | assert label_type == "image", "No annotations" 26 | 27 | self.root_dir = Path(root_dir) 28 | self.label_type = label_type 29 | self.transform = transform 30 | self.image_list = os.listdir(self.root_dir / "images") 31 | 32 | if split == "val" or split == "test": 33 | self.image_list = self.image_list[:1000] 34 | 35 | def 
__len__(self): 36 | return len(self.image_list) 37 | 38 | def __getitem__(self, index: int) -> Any: 39 | name = self.image_list[index] 40 | img = Image.open(os.path.join(self.root_dir, "images", name)) 41 | if self.transform: 42 | img = self.transform(img) 43 | 44 | if self.label_type == "image": 45 | return img 46 | else: 47 | raise ValueError("TableBank doesn't have HTML annotations.") 48 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import hydra 3 | import logging 4 | import os 5 | import wandb 6 | import torch 7 | import tokenizers as tk 8 | from omegaconf import DictConfig, OmegaConf 9 | from hydra.utils import get_original_cwd, instantiate 10 | from pathlib import Path 11 | import torch.multiprocessing as mp 12 | from torch.utils.data.distributed import DistributedSampler 13 | from torch.distributed import init_process_group, destroy_process_group 14 | 15 | from src.utils import printer, count_total_parameters 16 | 17 | log = logging.getLogger(__name__) 18 | 19 | 20 | @hydra.main(config_path="../configs", config_name="main", version_base="1.3") 21 | def main(cfg: DictConfig): 22 | torch.manual_seed(cfg.seed) 23 | ddp_setup() 24 | device = int(os.environ["LOCAL_RANK"]) 25 | cwd = Path(get_original_cwd()) 26 | exp_dir = Path(os.getcwd()) # experiment directory 27 | 28 | if cfg.trainer.mode == "train": 29 | (exp_dir / "snapshot").mkdir(parents=True, exist_ok=True) 30 | (exp_dir / "model").mkdir(parents=True, exist_ok=True) 31 | if device == 0: 32 | wandb.init(project=cfg.wandb.project, name=cfg.name, resume=True) 33 | 34 | # vocab is used in finetuning, not in self-supervised pretraining 35 | vocab = None 36 | if cfg.vocab.need_vocab: 37 | log.info( 38 | printer( 39 | device, 40 | f"Loading {cfg.vocab.type} vocab from {(cwd / cfg.vocab.dir).resolve()}", 41 | ) 42 | ) 43 | vocab = tk.Tokenizer.from_file(str(cwd / cfg.vocab.dir)) 44 | 45 | # dataset 46 | if cfg.trainer.mode == "train": 47 | log.info(printer(device, "Loading training dataset")) 48 | train_dataset = instantiate(cfg.dataset.train_dataset) 49 | 50 | log.info(printer(device, "Loading validation dataset")) 51 | valid_dataset = instantiate(cfg.dataset.valid_dataset) 52 | 53 | train_kwargs = { 54 | "dataset": train_dataset, 55 | "sampler": DistributedSampler(train_dataset), 56 | "vocab": vocab, 57 | "max_seq_len": cfg.trainer.max_seq_len, 58 | } 59 | 60 | valid_kwargs = { 61 | "dataset": valid_dataset, 62 | "sampler": DistributedSampler(valid_dataset), 63 | "vocab": vocab, 64 | "max_seq_len": cfg.trainer.max_seq_len, 65 | } 66 | 67 | train_dataloader = instantiate(cfg.trainer.train.dataloader, **train_kwargs) 68 | valid_dataloader = instantiate(cfg.trainer.valid.dataloader, **valid_kwargs) 69 | elif cfg.trainer.mode == "test": 70 | # load testing dataset, same as valid for ssl 71 | log.info(printer(device, "Loading testing dataset")) 72 | test_dataset = instantiate(cfg.dataset.test_dataset) 73 | 74 | test_kwargs = { 75 | "dataset": test_dataset, 76 | "sampler": DistributedSampler(test_dataset), 77 | "vocab": vocab, 78 | "max_seq_len": cfg.trainer.max_seq_len, 79 | } 80 | 81 | test_dataloader = instantiate(cfg.trainer.test.dataloader, **test_kwargs) 82 | 83 | # model 84 | log.info(printer(device, "Loading model ...")) 85 | model_name = str(cfg.model.model._target_).split(".")[-1] 86 | if model_name == "DiscreteVAE": 87 | model = instantiate(cfg.model.model) 88 | elif 
model_name == "BeitEncoder": 89 | max_seq_len = ( 90 | cfg.trainer.trans_size[0] // cfg.model.backbone_downsampling_factor 91 | ) * (cfg.trainer.trans_size[1] // cfg.model.backbone_downsampling_factor) 92 | model = instantiate( 93 | cfg.model.model, 94 | max_seq_len=max_seq_len, 95 | ) 96 | # load pretrained vqvae 97 | model_vqvae = instantiate(cfg.model.model_vqvae) 98 | 99 | log.info(printer(device, "Loading pretrained VQVAE model ...")) 100 | assert Path( 101 | cfg.trainer.vqvae_weights 102 | ).is_file(), f"VQVAE weights doesn't exist: {cfg.trainer.vqvae_weights}" 103 | model_vqvae.load_state_dict( 104 | torch.load(cfg.trainer.vqvae_weights, map_location="cpu") 105 | ) 106 | elif model_name == "EncoderDecoder": 107 | max_seq_len = max( 108 | (cfg.trainer.img_size[0] // cfg.model.backbone_downsampling_factor) 109 | * (cfg.trainer.img_size[1] // cfg.model.backbone_downsampling_factor), 110 | cfg.trainer.max_seq_len, 111 | ) # for positional embedding 112 | model = instantiate( 113 | cfg.model.model, 114 | max_seq_len=max_seq_len, 115 | vocab_size=vocab.get_vocab_size(), 116 | padding_idx=vocab.token_to_id(""), 117 | ) 118 | 119 | log.info( 120 | printer(device, f"Total parameters: {count_total_parameters(model) / 1e6:.2f}M") 121 | ) 122 | 123 | # trainer 124 | log.info(printer(device, "Loading trainer ...")) 125 | trainer_name = str(cfg.trainer.trainer._target_).split(".")[-1] 126 | trainer_kwargs = { 127 | "device": device, 128 | "model": model, 129 | "log": log, 130 | "exp_dir": exp_dir, 131 | "snapshot": ( 132 | exp_dir / "snapshot" / cfg.trainer.trainer.snapshot 133 | if cfg.trainer.trainer.snapshot 134 | else None 135 | ), 136 | } 137 | 138 | if trainer_name == "VqvaeTrainer": 139 | trainer = instantiate(cfg.trainer.trainer, **trainer_kwargs) 140 | elif trainer_name == "BeitTrainer": 141 | trainer_kwargs["model_vqvae"] = model_vqvae 142 | trainer = instantiate(cfg.trainer.trainer, **trainer_kwargs) 143 | elif trainer_name == "TableTrainer": 144 | trainer_kwargs["vocab"] = vocab 145 | trainer = instantiate(cfg.trainer.trainer, **trainer_kwargs) 146 | else: 147 | raise ValueError(f"The provided trainer type {trainer_name} is not supported.") 148 | 149 | if cfg.trainer.mode == "train": 150 | log.info(printer(device, "Training starts ...")) 151 | trainer.train( 152 | train_dataloader, valid_dataloader, cfg.trainer.train, cfg.trainer.valid 153 | ) 154 | elif cfg.trainer.mode == "test": 155 | log.info(printer(device, "Evaluation starts ...")) 156 | save_to = exp_dir / cfg.name 157 | save_to.mkdir(parents=True, exist_ok=True) 158 | trainer.test(test_dataloader, cfg.trainer.test, save_to=save_to) 159 | else: 160 | raise NotImplementedError 161 | 162 | destroy_process_group() 163 | 164 | 165 | def ddp_setup(): 166 | init_process_group(backend="nccl") 167 | 168 | 169 | if __name__ == "__main__": 170 | main() 171 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .beit import BeitEncoder 2 | from .vqvae import DiscreteVAE 3 | from .encoderdecoder import EncoderDecoder 4 | from .components import * -------------------------------------------------------------------------------- /src/model/beit.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn, Tensor 4 | from functools import partial 5 | 6 | from src.model.components import ImgLinearBackbone, PositionEmbedding, 
Encoder 7 | 8 | 9 | class BeitEncoder(nn.Module): 10 | def __init__( 11 | self, 12 | d_model: int, # embed_dim 13 | backbone: nn.Module, 14 | max_seq_len: int, # for positional embedding 15 | codebook_tokens: int, 16 | dropout: float, 17 | encoder: Encoder, 18 | norm_layer: nn.Module, 19 | init_std: float = 0.02, 20 | ) -> None: 21 | super().__init__() 22 | 23 | self.d_model = d_model 24 | self.init_std = init_std 25 | 26 | self.backbone = backbone 27 | self.pos_embed = PositionEmbedding( 28 | max_seq_len=max_seq_len, d_model=d_model, dropout=dropout 29 | ) 30 | 31 | self.encoder = encoder 32 | self.norm = norm_layer(d_model) 33 | self.generator = nn.Linear(d_model, codebook_tokens) 34 | 35 | self.trunc_normal = partial( 36 | nn.init.trunc_normal_, std=init_std, a=-init_std, b=init_std 37 | ) 38 | self.apply(self._init_weights) 39 | 40 | self.mask_token = nn.Parameter(torch.zeros(1, 1, d_model)) 41 | 42 | def _init_weights(self, m: nn.Module): 43 | if isinstance(m, nn.Linear): 44 | self.trunc_normal(m.weight) 45 | if m.bias is not None: 46 | nn.init.constant_(m.bias, 0.0) 47 | elif isinstance(m, nn.LayerNorm): 48 | nn.init.constant_(m.weight, 1.0) 49 | nn.init.constant_(m.bias, 0.0) 50 | elif isinstance(m, nn.Conv2d): 51 | self.trunc_normal(m.weight) 52 | if m.bias is not None: 53 | nn.init.constant_(m.bias, 0.0) 54 | elif isinstance(m, PositionEmbedding): 55 | self.trunc_normal(m.embedding.weight) 56 | 57 | @torch.jit.ignore 58 | def no_weight_decay(self): 59 | return {"pos_embed"} 60 | 61 | def forward( 62 | self, x: Tensor, bool_masked_pos: Tensor, return_all_tokens: bool = False 63 | ): 64 | x = self.backbone(x) 65 | B, S, E = x.shape 66 | assert E == self.d_model 67 | 68 | mask_token = self.mask_token.expand(B, S, -1) 69 | 70 | w = bool_masked_pos.unsqueeze(-1).type_as(mask_token) 71 | x = x * (1 - w) + mask_token * w 72 | 73 | x = self.pos_embed(x) 74 | 75 | x = self.encoder(x) 76 | x = self.norm(x) 77 | 78 | if return_all_tokens: 79 | return self.generator(x) 80 | else: 81 | return self.generator(x[bool_masked_pos]) 82 | 83 | 84 | if __name__ == "__main__": 85 | d_model = 512 86 | patch_size = 16 87 | nhead = 8 88 | dropout = 0.0 89 | acitvation = "gelu" 90 | norm_first = True 91 | nlayer = 12 92 | ff_ratio = 4 93 | norm_layer = partial(nn.LayerNorm, eps=1e-6) 94 | codebook_tokens = 8192 95 | 96 | img_size = 448 97 | 98 | max_seq_len = (img_size // patch_size) ** 2 99 | 100 | backbone = ImgLinearBackbone(d_model=d_model, patch_size=patch_size) 101 | encoder = Encoder( 102 | d_model=d_model, 103 | nhead=nhead, 104 | dropout=dropout, 105 | activation=acitvation, 106 | norm_first=norm_first, 107 | nlayer=nlayer, 108 | ff_ratio=ff_ratio, 109 | ) 110 | 111 | model = BeitEncoder( 112 | d_model=d_model, 113 | backbone=backbone, 114 | max_seq_len=max_seq_len, 115 | codebook_tokens=codebook_tokens, 116 | dropout=dropout, 117 | encoder=encoder, 118 | norm_layer=norm_layer, 119 | ) 120 | 121 | print(model) 122 | 123 | x = torch.rand((1, 3, img_size, img_size)) 124 | bool_masked_pos = torch.rand((1, (img_size // patch_size) ** 2)) < 0.5 125 | y = model(x, bool_masked_pos) 126 | print(torch.sum(bool_masked_pos)) 127 | print(y.shape) 128 | -------------------------------------------------------------------------------- /src/model/components.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import torch 3 | from torch import nn, Tensor 4 | from torchvision.ops.misc import Conv2dNormActivation 5 | 6 | 7 | __all__ = [ 8 | 
"ImgCnnBackbone", 9 | "ImgLinearBackbone", 10 | "ImgConvStemBackbone", 11 | "PositionEmbedding", 12 | "Encoder", 13 | "Decoder", 14 | "TokenEmbedding", 15 | ] 16 | 17 | 18 | class ImgCnnBackbone(nn.Module): 19 | def __init__( 20 | self, 21 | backbone: nn.Module, 22 | output_channels: int, 23 | d_model: int, 24 | drop_layer: Tuple = None, 25 | ) -> None: 26 | super().__init__() 27 | 28 | # drop layers for classification & maxpooling for higher feature resolution 29 | layers = list(backbone.children()) 30 | nlayer = len(layers) 31 | keep_layer = set([i for i in range(nlayer)]) - set(drop_layer) 32 | backbone = [layers[i] for i in keep_layer] 33 | self.backbone = nn.Sequential(*backbone) 34 | self.proj = nn.Linear(output_channels, d_model) 35 | self.channels = output_channels 36 | 37 | def forward(self, x: Tensor) -> Tensor: 38 | x = self.backbone(x) 39 | x = x.flatten(start_dim=-2).transpose(1, 2) 40 | assert x.shape[-1] == self.channels, "Image channels size mismatch." 41 | x = self.proj(x) 42 | return x 43 | 44 | 45 | class ImgLinearBackbone(nn.Module): 46 | def __init__( 47 | self, 48 | d_model: int, 49 | patch_size: int, 50 | in_chan: int = 3, 51 | ) -> None: 52 | super().__init__() 53 | 54 | self.conv_proj = nn.Conv2d( 55 | in_chan, out_channels=d_model, kernel_size=patch_size, stride=patch_size 56 | ) 57 | self.d_model = d_model 58 | 59 | def forward(self, x: Tensor) -> Tensor: 60 | x = self.conv_proj(x) 61 | x = x.flatten(start_dim=-2).transpose(1, 2) 62 | return x 63 | 64 | 65 | class ImgConvStemBackbone(nn.Module): 66 | def __init__( 67 | self, 68 | d_model: int, 69 | downsample_factor: int, 70 | output_channels: int, 71 | kernel_size: int, 72 | ) -> None: 73 | super().__init__() 74 | 75 | assert downsample_factor % 2 == 0 76 | assert output_channels % (downsample_factor // 2) == 0 77 | input_channels = output_channels // (downsample_factor // 2) 78 | 79 | layers = [ 80 | Conv2dNormActivation( 81 | 3, input_channels, kernel_size=kernel_size, stride=2, padding=1 82 | ) 83 | ] 84 | 85 | while input_channels != output_channels: 86 | layers.append( 87 | Conv2dNormActivation( 88 | input_channels, 89 | input_channels * 2, 90 | kernel_size=kernel_size, 91 | stride=2, 92 | padding=1, 93 | ) 94 | ) 95 | input_channels = input_channels * 2 96 | 97 | layers.append(nn.Conv2d(output_channels, d_model, kernel_size=1)) 98 | 99 | self.conv_stem = nn.Sequential(*layers) 100 | 101 | def forward(self, x: Tensor) -> Tensor: 102 | x = self.conv_stem(x) 103 | x = x.flatten(start_dim=-2).transpose(1, 2) 104 | return x 105 | 106 | 107 | class Encoder(nn.Module): 108 | def __init__( 109 | self, 110 | d_model: int, 111 | nhead: int, 112 | dropout: float, 113 | activation: str, 114 | norm_first: bool, 115 | nlayer: int, 116 | ff_ratio: int = 4, 117 | ) -> None: 118 | super().__init__() 119 | 120 | encoder_layer = nn.TransformerEncoderLayer( 121 | d_model, 122 | nhead=nhead, 123 | dim_feedforward=ff_ratio * d_model, 124 | dropout=dropout, 125 | activation=activation, 126 | batch_first=True, 127 | norm_first=norm_first, 128 | ) 129 | 130 | self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=nlayer) 131 | 132 | def forward(self, x: Tensor) -> Tensor: 133 | x = self.encoder(x) 134 | return x 135 | 136 | 137 | class Decoder(nn.Module): 138 | def __init__( 139 | self, 140 | d_model: int, 141 | nhead: int, 142 | dropout: float, 143 | activation: str, 144 | norm_first: bool, 145 | nlayer: int, 146 | ff_ratio: int = 4, 147 | ) -> None: 148 | super().__init__() 149 | decoder_layer = 
nn.TransformerDecoderLayer( 150 | d_model, 151 | nhead, 152 | dim_feedforward=ff_ratio * d_model, 153 | dropout=dropout, 154 | activation=activation, 155 | batch_first=True, 156 | norm_first=norm_first, 157 | ) 158 | 159 | self.decoder = nn.TransformerDecoder(decoder_layer, nlayer) 160 | 161 | def forward( 162 | self, x: Tensor, memory: Tensor, tgt_mask: Tensor, tgt_padding_mask: Tensor 163 | ) -> Tensor: 164 | x = self.decoder( 165 | x, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_padding_mask 166 | ) 167 | return x 168 | 169 | 170 | class PositionEmbedding(nn.Module): 171 | def __init__(self, max_seq_len: int, d_model: int, dropout: float) -> None: 172 | super().__init__() 173 | self.embedding = nn.Embedding(max_seq_len, d_model) 174 | self.dropout = nn.Dropout(dropout) 175 | 176 | def forward(self, x: Tensor) -> Tensor: 177 | # assume x is batch first 178 | out = self.embedding(torch.arange(x.shape[1], device=x.device)) 179 | return self.dropout(out + x) 180 | 181 | 182 | class TokenEmbedding(nn.Module): 183 | def __init__( 184 | self, 185 | vocab_size: int, 186 | d_model: int, 187 | padding_idx: int, 188 | ) -> None: 189 | super().__init__() 190 | assert vocab_size > 0 191 | self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx) 192 | 193 | def forward(self, x: Tensor) -> Tensor: 194 | return self.embedding(x) 195 | 196 | 197 | class PrintLayer(nn.Module): 198 | """Only for debugging when loss is nan.""" 199 | 200 | def __init__(self): 201 | super().__init__() 202 | 203 | def forward(self, x): 204 | print( 205 | "torch.isfinite(x).all(): {}, min. {:.5f}, max. {:.5f}".format( 206 | torch.isfinite(x).all(), x.min(), x.max() 207 | ) 208 | ) 209 | return x 210 | 211 | 212 | if __name__ == "__main__": 213 | from torchvision import models 214 | 215 | x = torch.rand(1, 3, 392, 392) 216 | model = ImgConvStemBackbone( 217 | d_model=512, downsample_factor=16, output_channels=64, kernel_size=5 218 | ) 219 | y = model(x) 220 | print(model) 221 | print(y.shape) 222 | 223 | model = ImgCnnBackbone( 224 | backbone=models.resnet34(), 225 | output_channels=512, 226 | d_model=512, 227 | drop_layer=(3, 8, 9), 228 | ) 229 | 230 | # print(model) 231 | y = model(x) 232 | print(y.shape) 233 | -------------------------------------------------------------------------------- /src/model/encoderdecoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor, nn 3 | from functools import partial 4 | 5 | from src.model.components import ( 6 | ImgCnnBackbone, 7 | ImgLinearBackbone, 8 | ImgConvStemBackbone, 9 | Encoder, 10 | Decoder, 11 | PositionEmbedding, 12 | TokenEmbedding, 13 | ) 14 | 15 | 16 | class EncoderDecoder(nn.Module): 17 | """Encoder decoder architecture that takes in a tabular image and generates the text output. 18 | Backbone serves as the image processor. There are three types of backbones: CNN, linear projection, and ConvStem. 
19 | 20 | Args: 21 | ---- 22 | backbone: tabular image processor 23 | encoder: transformer encoder 24 | decoder: transformer decoder 25 | vocab_size: size of the vocabulary 26 | d_model: feature size 27 | padding_idx: index of in the vocabulary 28 | max_seq_len: max sequence length of generated text 29 | dropout: dropout rate 30 | norm_layer: layernorm 31 | init_std: std in weights initialization 32 | """ 33 | 34 | def __init__( 35 | self, 36 | backbone: nn.Module, 37 | encoder: nn.Module, 38 | decoder: nn.Module, 39 | vocab_size: int, 40 | d_model: int, 41 | padding_idx: int, 42 | max_seq_len: int, 43 | dropout: float, 44 | norm_layer: nn.Module, 45 | init_std: float = 0.02, 46 | ): 47 | super().__init__() 48 | 49 | self.backbone = backbone 50 | self.encoder = encoder 51 | self.decoder = decoder 52 | self.norm = norm_layer(d_model) 53 | self.token_embed = TokenEmbedding( 54 | vocab_size=vocab_size, d_model=d_model, padding_idx=padding_idx 55 | ) 56 | self.pos_embed = PositionEmbedding( 57 | max_seq_len=max_seq_len, d_model=d_model, dropout=dropout 58 | ) 59 | self.generator = nn.Linear(d_model, vocab_size) 60 | 61 | self.trunc_normal = partial( 62 | nn.init.trunc_normal_, std=init_std, a=-init_std, b=init_std 63 | ) 64 | self.apply(self._init_weights) 65 | 66 | def _init_weights(self, m: nn.Module): 67 | if isinstance(m, nn.Linear): 68 | self.trunc_normal(m.weight) 69 | if m.bias is not None: 70 | nn.init.constant_(m.bias, 0.0) 71 | elif isinstance(m, nn.LayerNorm): 72 | nn.init.constant_(m.weight, 1.0) 73 | nn.init.constant_(m.bias, 0.0) 74 | elif isinstance(m, nn.Conv2d): 75 | self.trunc_normal(m.weight) 76 | if m.bias is not None: 77 | nn.init.constant_(m.bias, 0.0) 78 | elif isinstance(m, PositionEmbedding): 79 | self.trunc_normal(m.embedding.weight) 80 | elif isinstance(m, TokenEmbedding): 81 | self.trunc_normal(m.embedding.weight) 82 | 83 | @torch.jit.ignore 84 | def no_weight_decay(self): 85 | return {"token_embed", "pos_embed"} 86 | 87 | def encode(self, src: Tensor) -> Tensor: 88 | src_feature = self.backbone(src) 89 | src_feature = self.pos_embed(src_feature) 90 | memory = self.encoder(src_feature) 91 | memory = self.norm(memory) 92 | return memory 93 | 94 | def decode( 95 | self, memory: Tensor, tgt: Tensor, tgt_mask: Tensor, tgt_padding_mask: Tensor 96 | ) -> Tensor: 97 | tgt_feature = self.pos_embed(self.token_embed(tgt)) 98 | tgt = self.decoder(tgt_feature, memory, tgt_mask, tgt_padding_mask) 99 | 100 | return tgt 101 | 102 | def forward( 103 | self, src: Tensor, tgt: Tensor, tgt_mask: Tensor, tgt_padding_mask: Tensor 104 | ) -> Tensor: 105 | memory = self.encode(src) 106 | tgt = self.decode(memory, tgt, tgt_mask, tgt_padding_mask) 107 | tgt = self.generator(tgt) 108 | 109 | return tgt 110 | -------------------------------------------------------------------------------- /src/model/vqvae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor, einsum 3 | from typing import Optional, Tuple 4 | import math 5 | from functools import partial 6 | from collections import OrderedDict 7 | import torch.nn.functional as F 8 | from einops import rearrange 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def default(val, d): 16 | return val if exists(val) else d 17 | 18 | 19 | def eval_decorator(fn): 20 | def inner(model, *args, **kwargs): 21 | was_training = model.training 22 | model.eval() 23 | out = fn(model, *args, **kwargs) 24 | model.train(was_training) 25 | return out 26 | 27 | 
return inner 28 | 29 | 30 | class ResBlock(nn.Module): 31 | def __init__(self, chan_in, hidden_size, chan_out): 32 | super().__init__() 33 | self.net = nn.Sequential( 34 | nn.Conv2d(chan_in, hidden_size, 3, padding=1), 35 | nn.ReLU(), 36 | nn.Conv2d(hidden_size, hidden_size, 3, padding=1), 37 | nn.ReLU(), 38 | nn.Conv2d(hidden_size, chan_out, 1), 39 | ) 40 | 41 | def forward(self, x): 42 | return self.net(x) + x 43 | 44 | 45 | class BasicVAE(nn.Module): 46 | def get_codebook_indices(self, images): 47 | raise NotImplementedError() 48 | 49 | def decode(self, img_seq): 50 | raise NotImplementedError() 51 | 52 | def get_codebook_probs(self, img_seq): 53 | raise NotImplementedError() 54 | 55 | def get_image_tokens_size(self): 56 | pass 57 | 58 | def get_image_size(self): 59 | pass 60 | 61 | 62 | class DiscreteVAE(BasicVAE): 63 | def __init__( 64 | self, 65 | image_size: Tuple[int, int] = [256, 256], # input image size 66 | codebook_tokens: int = 512, # codebook vocab size 67 | codebook_dim: int = 512, # codebook embedding dimension 68 | num_layers: int = 3, # layers of resnet blocks in encoder/decoder 69 | hidden_dim: int = 64, # dimension in resnet blocks 70 | channels: int = 3, # input channels 71 | smooth_l1_loss: bool = False, # prevents exploding gradients 72 | temperature: float = 0.9, # tau in gumbel softmax 73 | straight_through: bool = False, # if True, the returned samples will be discretized as one-hot vectors, but will be differentiated as if it is the soft sample in autograd 74 | kl_div_loss_weight: float = 0.0, 75 | ): 76 | super().__init__() 77 | assert num_layers >= 1, "number of layers must be greater than or equal to 1" 78 | 79 | self.image_size = image_size 80 | self.codebook_tokens = codebook_tokens 81 | self.num_layers = num_layers 82 | self.temperature = temperature 83 | self.straight_through = straight_through 84 | self.codebook = nn.Embedding(codebook_tokens, codebook_dim) 85 | 86 | encoder_layers = list() 87 | decoder_layers = list() 88 | 89 | encoder_in = channels 90 | decoder_in = codebook_dim 91 | 92 | for _ in range(num_layers): 93 | encoder_layers.append( 94 | nn.Sequential( 95 | nn.Conv2d(encoder_in, hidden_dim, 4, stride=2, padding=1), nn.ReLU() 96 | ) 97 | ) 98 | encoder_layers.append( 99 | ResBlock( 100 | chan_in=hidden_dim, hidden_size=hidden_dim, chan_out=hidden_dim 101 | ) 102 | ) 103 | encoder_in = hidden_dim 104 | 105 | decoder_layers.append( 106 | nn.Sequential( 107 | nn.ConvTranspose2d(decoder_in, hidden_dim, 4, stride=2, padding=1), 108 | nn.ReLU(), 109 | ) 110 | ) 111 | decoder_layers.append( 112 | ResBlock( 113 | chan_in=hidden_dim, hidden_size=hidden_dim, chan_out=hidden_dim 114 | ) 115 | ) 116 | decoder_in = hidden_dim 117 | 118 | encoder_layers.append(nn.Conv2d(hidden_dim, codebook_tokens, 1)) 119 | decoder_layers.append(nn.Conv2d(hidden_dim, channels, 1)) 120 | 121 | self.encoder = nn.Sequential(*encoder_layers) 122 | self.decoder = nn.Sequential(*decoder_layers) 123 | 124 | self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss 125 | self.kl_div_loss_weight = kl_div_loss_weight 126 | 127 | def get_image_size(self): 128 | return self.image_size 129 | 130 | def get_image_tokens_size(self) -> int: 131 | ds_ratio = math.pow(2, self.num_layers) 132 | return int((self.image_size[0] // ds_ratio) * (self.image_size[1] // ds_ratio)) 133 | 134 | @torch.no_grad() 135 | @eval_decorator 136 | def get_codebook_indices(self, images: Tensor): 137 | logits = self.forward(images, return_logits=True) 138 | codebook_indices = logits.argmax(dim=1) 139 | 
return codebook_indices 140 | 141 | @torch.no_grad() 142 | @eval_decorator 143 | def get_codebook_probs(self, images: Tensor): 144 | logits = self.forward(images, return_logits=True) 145 | return nn.Softmax(dim=1)(logits) 146 | 147 | def decode(self, img_seq: Tensor): 148 | image_embeds = self.codebook(img_seq) 149 | image_embeds = image_embeds.permute((0, 3, 1, 2)).contiguous() 150 | 151 | # image_embeds = rearrange(image_embeds, "b h w d -> b d h w", h=h, w=w) 152 | images = self.decoder(image_embeds) 153 | return images 154 | 155 | def forward( 156 | self, 157 | img: Tensor, 158 | return_loss: bool = False, 159 | return_recons: bool = False, 160 | return_logits: bool = False, 161 | temp=None, 162 | ) -> Tuple[Tensor, Optional[Tensor]]: 163 | assert ( 164 | img.shape[-1] == self.image_size[0] and img.shape[-2] == self.image_size[1] 165 | ), f"input must have the correct image size {self.image_size}" 166 | 167 | logits = self.encoder(img) 168 | 169 | if return_logits: 170 | return logits # return logits for getting hard image indices for DALL-E training 171 | 172 | temp = default(temp, self.temperature) 173 | soft_one_hot = F.gumbel_softmax( 174 | logits, tau=temp, dim=1, hard=self.straight_through 175 | ) 176 | sampled = einsum( 177 | "b n h w, n d -> b d h w", soft_one_hot, self.codebook.weight 178 | ).contiguous() 179 | out = self.decoder(sampled) 180 | 181 | if not return_loss: 182 | return out 183 | 184 | # reconstruction loss 185 | recon_loss = self.loss_fn(img, out) 186 | 187 | # kl divergence 188 | logits = rearrange(logits, "b n h w -> b (h w) n").contiguous() 189 | qy = F.softmax(logits, dim=-1) 190 | 191 | log_qy = torch.log(qy + 1e-10) 192 | log_uniform = torch.log( 193 | torch.tensor([1.0 / self.codebook_tokens], device=img.device) 194 | ) 195 | kl_div = F.kl_div(log_uniform, log_qy, None, None, "batchmean", log_target=True) 196 | 197 | loss = recon_loss + (kl_div * self.kl_div_loss_weight) 198 | 199 | if not return_recons: 200 | return loss 201 | 202 | return loss, out 203 | 204 | 205 | if __name__ == "__main__": 206 | input = torch.rand(1, 3, 256, 256) 207 | model = DiscreteVAE() 208 | loss, output = model(input, return_loss=True, return_recons=True) 209 | 210 | print(model) 211 | print(model.get_image_tokens_size()) 212 | print(model.get_codebook_indices(input).shape) 213 | print(loss, output.shape, output.max(), output.min()) 214 | -------------------------------------------------------------------------------- /src/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .train_beit import BeitTrainer 2 | from .train_vqvae import VqvaeTrainer 3 | from .train_table import TableTrainer -------------------------------------------------------------------------------- /src/trainer/train_beit.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | from pathlib import Path 3 | from typing import Tuple, List, Union, Dict 4 | from omegaconf import DictConfig 5 | from hydra.utils import instantiate 6 | import logging 7 | import torch 8 | import time 9 | from torch import nn, Tensor, autograd 10 | from torch.utils.data import DataLoader 11 | from torch.nn.parallel import DistributedDataParallel as DDP 12 | 13 | from src.utils import printer, compute_grad_norm 14 | from src.trainer.utils import configure_optimizer_weight_decay 15 | 16 | SNAPSHOT_KEYS = set(["EPOCH", "STEP", "OPTIMIZER", "LR_SCHEDULER", "MODEL", "LOSS"]) 17 | 18 | 19 | class BeitTrainer: 20 | def __init__( 21 | 
self, 22 | device: int, 23 | model: nn.Module, 24 | model_vqvae: nn.Module, 25 | log: logging.Logger, 26 | exp_dir: Path, 27 | snapshot: Path = None, 28 | model_weights: Path = None, # only for testing 29 | ) -> None: 30 | self.device = device 31 | self.log = log 32 | self.exp_dir = exp_dir 33 | self.criterion = nn.CrossEntropyLoss() 34 | assert ( 35 | snapshot is None or model_weights is None 36 | ), "Snapshot and model weights cannot be set at the same time." 37 | 38 | self.model = model 39 | if snapshot is not None and snapshot.is_file(): 40 | self.snapshot = self.load_snapshot(snapshot) 41 | self.model.load_state_dict(self.snapshot["MODEL"]) 42 | self.start_epoch = self.snapshot["EPOCH"] 43 | self.global_step = self.snapshot["STEP"] 44 | elif model_weights is not None and model_weights.is_file(): 45 | self.load_model(model_weights) 46 | else: 47 | self.snapshot = None 48 | self.start_epoch = 0 49 | self.global_step = 0 50 | 51 | self.model = self.model.to(device) 52 | self.model = DDP(self.model, device_ids=[device]) 53 | self.model_vqvae = model_vqvae.to(device) 54 | 55 | # https://discuss.pytorch.org/t/extra-10gb-memory-on-gpu-0-in-ddp-tutorial/118113 56 | torch.cuda.set_device(device) # master gpu takes up extra memory 57 | torch.cuda.empty_cache() 58 | 59 | def train_epoch(self, epoch: int, grad_clip: float = None): 60 | start = time.time() 61 | total_loss = 0.0 62 | total_samples = 0 63 | 64 | for i, obj in enumerate(self.train_dataloader): 65 | (trans_image, vqvae_image), bool_mask_pos = obj 66 | trans_image, vqvae_image, bool_mask_pos = ( 67 | trans_image.to(self.device), 68 | vqvae_image.to(self.device), 69 | bool_mask_pos.to(self.device), 70 | ) 71 | 72 | with torch.no_grad(): 73 | input_ids = self.model_vqvae.get_codebook_indices(vqvae_image).flatten( 74 | 1 75 | ) 76 | bool_mask_pos = bool_mask_pos.flatten(1).to(torch.bool) 77 | labels = input_ids[bool_mask_pos] 78 | 79 | with autograd.detect_anomaly(): 80 | outputs = self.model( 81 | trans_image, bool_mask_pos, return_all_tokens=False 82 | ) 83 | loss = self.criterion(outputs, labels) 84 | 85 | self.optimizer.zero_grad() 86 | loss.backward() 87 | if grad_clip: 88 | nn.utils.clip_grad_norm_( 89 | self.model.parameters(), max_norm=grad_clip 90 | ) 91 | self.optimizer.step() 92 | 93 | loss = loss.detach().cpu().data 94 | total_loss += loss * trans_image.shape[0] 95 | total_samples += trans_image.shape[0] 96 | 97 | self.lr_scheduler.step() 98 | self.global_step += 1 99 | 100 | if i % 10 == 0: 101 | grad_norm = compute_grad_norm(self.model) 102 | lr = self.optimizer.param_groups[0]["lr"] 103 | elapsed = time.time() - start 104 | self.log.info( 105 | printer( 106 | self.device, 107 | f"Epoch {epoch} Step {i + 1}/{len(self.train_dataloader)} | Loss {loss:.4f} ({total_loss / total_samples:.4f}) | Grad norm {grad_norm:.3f} | {total_samples / elapsed:4.1f} images/s | lr {lr:5.1e}", 108 | ) 109 | ) 110 | 111 | if i % 100 == 0 and self.device == 0: 112 | lr = self.optimizer.param_groups[0]["lr"] 113 | log_info = { 114 | "epoch": epoch, 115 | "train_loss": loss, 116 | "learning rate": lr, 117 | "grad_norm": grad_norm, 118 | } 119 | 120 | wandb.log( 121 | log_info, 122 | step=self.global_step, 123 | ) 124 | 125 | return total_loss / total_samples 126 | 127 | def train( 128 | self, 129 | train_dataloader: DataLoader, 130 | valid_dataloader: DataLoader, 131 | train_cfg: DictConfig, 132 | valid_cfg: DictConfig, 133 | ): 134 | self.train_dataloader = train_dataloader 135 | self.valid_dataloader = valid_dataloader 136 | 137 | # ensure correct 
weight decay: https://github.com/karpathy/minGPT/blob/37baab71b9abea1b76ab957409a1cc2fbfba8a26/mingpt/model.py#L215 138 | optim_params = configure_optimizer_weight_decay( 139 | self.model.module, weight_decay=train_cfg.optimizer.weight_decay 140 | ) 141 | self.optimizer = instantiate(train_cfg.optimizer, optim_params) 142 | 143 | self.lr_scheduler = instantiate( 144 | train_cfg.lr_scheduler, optimizer=self.optimizer 145 | ) 146 | 147 | if self.snapshot is not None: 148 | self.optimizer.load_state_dict(self.snapshot["OPTIMIZER"]) 149 | self.lr_scheduler.load_state_dict(self.snapshot["LR_SCHEDULER"]) 150 | 151 | best_loss = float("inf") 152 | self.model.train() 153 | for epoch in range(self.start_epoch, train_cfg.epochs): 154 | train_dataloader.sampler.set_epoch(epoch) 155 | train_loss = self.train_epoch(epoch, grad_clip=train_cfg.grad_clip) 156 | 157 | torch.cuda.empty_cache() 158 | 159 | valid_loss = self.valid(valid_cfg) 160 | 161 | if self.device == 0: 162 | wandb.log( 163 | { 164 | "train loss (epoch)": train_loss, 165 | "valid loss (epoch)": valid_loss, 166 | }, 167 | step=self.global_step, 168 | ) 169 | 170 | if epoch % train_cfg.save_every == 0: 171 | self.save_snapshot(epoch, best_loss) 172 | if valid_loss < best_loss: 173 | self.save_model(epoch) 174 | best_loss = valid_loss 175 | 176 | def valid(self, cfg: DictConfig): 177 | total_samples = 0 178 | total_loss = 0.0 179 | 180 | self.model.eval() 181 | for i, obj in enumerate(self.valid_dataloader): 182 | (trans_image, vqvae_image), bool_mask_pos = obj 183 | trans_image, vqvae_image, bool_mask_pos = ( 184 | trans_image.to(self.device), 185 | vqvae_image.to(self.device), 186 | bool_mask_pos.to(self.device), 187 | ) 188 | 189 | with torch.no_grad(): 190 | input_ids = self.model_vqvae.get_codebook_indices(vqvae_image).flatten( 191 | 1 192 | ) 193 | bool_mask_pos = bool_mask_pos.flatten(1).to(torch.bool) 194 | labels = input_ids[bool_mask_pos] 195 | 196 | outputs = self.model( 197 | trans_image, bool_mask_pos, return_all_tokens=False 198 | ) 199 | loss = self.criterion(outputs, labels) 200 | 201 | loss = loss.detach().cpu().data 202 | total_loss += loss * trans_image.shape[0] 203 | total_samples += trans_image.shape[0] 204 | 205 | if i % 10 == 0: 206 | self.log.info( 207 | printer( 208 | self.device, 209 | f"Valid: Step {i + 1}/{len(self.valid_dataloader)} | Loss {loss:.4f} ({total_loss / total_samples:.4f})", 210 | ) 211 | ) 212 | 213 | return total_loss / total_samples 214 | 215 | def save_model(self, epoch: int): 216 | filename = Path(self.exp_dir) / "model" / f"epoch{epoch}_model.pt" 217 | torch.save(self.model.module.state_dict(), filename) 218 | self.log.info(printer(self.device, f"Saving model to {filename}")) 219 | filename = Path(self.exp_dir) / "model" / f"best.pt" 220 | torch.save(self.model.module.state_dict(), filename) 221 | 222 | def load_model(self, path: Union[str, Path]): 223 | self.model.load_state_dict(torch.load(path, map_location="cpu")) 224 | self.log.info(printer(self.device, f"Loading model from {path}")) 225 | 226 | def save_snapshot(self, epoch: int, best_loss: float): 227 | state_info = { 228 | "EPOCH": epoch + 1, 229 | "STEP": self.global_step, 230 | "OPTIMIZER": self.optimizer.state_dict(), 231 | "LR_SCHEDULER": self.lr_scheduler.state_dict(), 232 | "MODEL": self.model.module.state_dict(), 233 | "LOSS": best_loss, 234 | } 235 | 236 | snapshot_path = Path(self.exp_dir) / "snapshot" / f"epoch{epoch}_snapshot.pt" 237 | torch.save(state_info, snapshot_path) 238 | 239 | self.log.info(printer(self.device, 
f"Saving snapshot to {snapshot_path}")) 240 | 241 | def load_snapshot(self, path: Path): 242 | self.log.info(printer(self.device, f"Loading snapshot from {path}")) 243 | snapshot = torch.load(path, map_location="cpu") 244 | assert SNAPSHOT_KEYS.issubset(snapshot.keys()) 245 | return snapshot 246 | -------------------------------------------------------------------------------- /src/trainer/train_table.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List, Union, Dict, Optional 2 | import torch 3 | import wandb 4 | import json 5 | import os 6 | from torch import nn, Tensor, autograd 7 | from torch.utils.data import DataLoader 8 | from omegaconf import DictConfig 9 | from hydra.utils import instantiate 10 | import logging 11 | from pathlib import Path 12 | from torch.nn.parallel import DistributedDataParallel as DDP 13 | import tokenizers as tk 14 | import torch.nn.functional as F 15 | 16 | from src.trainer.utils import ( 17 | Batch, 18 | configure_optimizer_weight_decay, 19 | turn_off_beit_grad, 20 | VALID_HTML_TOKEN, 21 | INVALID_CELL_TOKEN, 22 | VALID_BBOX_TOKEN, 23 | ) 24 | from src.utils import ( 25 | printer, 26 | compute_grad_norm, 27 | count_total_parameters, 28 | batch_autoregressive_decode, 29 | combine_filename_pred_gt, 30 | ) 31 | 32 | SNAPSHOT_KEYS = set(["EPOCH", "STEP", "OPTIMIZER", "LR_SCHEDULER", "MODEL", "LOSS"]) 33 | 34 | 35 | class TableTrainer: 36 | """A trainer for table recognition. The supported tasks are: 37 | 1) table structure extraction 38 | 2) table cell bbox detection 39 | 3) table cell content recognition 40 | 41 | Args: 42 | ---- 43 | device: gpu id 44 | vocab: a vocab shared among all tasks 45 | model: model architecture 46 | log: logger 47 | exp_dir: the experiment directory that saves logs, wandb files, model weights, and checkpoints (snapshots) 48 | snapshot: specify which snapshot to use, only used in training 49 | model_weights: specify which model weight to use, only used in testing 50 | beit_pretrained_weights: load SSL pretrained visual encoder 51 | freeze_beit_epoch: freeze beit weights for the first {freeze_beit_epoch} epochs 52 | """ 53 | 54 | def __init__( 55 | self, 56 | device: int, 57 | vocab: tk.Tokenizer, 58 | model: nn.Module, 59 | log: logging.Logger, 60 | exp_dir: Path, 61 | snapshot: Path = None, 62 | model_weights: str = None, 63 | beit_pretrained_weights: str = None, 64 | freeze_beit_epoch: int = None, 65 | ) -> None: 66 | self.device = device 67 | self.log = log 68 | self.exp_dir = exp_dir 69 | self.vocab = vocab 70 | self.padding_idx = vocab.token_to_id("") 71 | self.freeze_beit_epoch = freeze_beit_epoch 72 | 73 | # loss for training html, cell 74 | self.criterion = nn.CrossEntropyLoss(ignore_index=self.padding_idx) 75 | 76 | self.model = model 77 | 78 | if ( 79 | beit_pretrained_weights is not None 80 | and Path(beit_pretrained_weights).is_file() 81 | ): 82 | self.load_pretrained_beit(Path(beit_pretrained_weights)) 83 | 84 | assert ( 85 | snapshot is None or model_weights is None 86 | ), "Cannot set snapshot and model_weights at the same time!" 
87 | 88 | if snapshot is not None and snapshot.is_file(): 89 | self.snapshot = self.load_snapshot(snapshot) 90 | self.model.load_state_dict(self.snapshot["MODEL"]) 91 | self.start_epoch = self.snapshot["EPOCH"] 92 | self.global_step = self.snapshot["STEP"] 93 | elif model_weights is not None and Path(model_weights).is_file(): 94 | self.load_model(Path(model_weights)) 95 | else: 96 | self.snapshot = None 97 | self.start_epoch = 0 98 | self.global_step = 0 99 | 100 | if freeze_beit_epoch and freeze_beit_epoch > 0: 101 | self._freeze_beit() 102 | 103 | self.model = self.model.to(device) 104 | self.model = DDP(self.model, device_ids=[device]) 105 | 106 | # https://discuss.pytorch.org/t/extra-10gb-memory-on-gpu-0-in-ddp-tutorial/118113 107 | torch.cuda.set_device(device) # master gpu takes up extra memory 108 | torch.cuda.empty_cache() 109 | 110 | def _freeze_beit(self): 111 | if self.start_epoch < self.freeze_beit_epoch: 112 | turn_off_beit_grad(self.model) 113 | self.log.info( 114 | printer( 115 | self.device, 116 | f"Lock SSL params for {self.freeze_beit_epoch} epochs (params: {count_total_parameters(self.model) / 1e6:.2f}M) - Current epoch {self.start_epoch + 1}", 117 | ) 118 | ) 119 | else: 120 | self.log.info( 121 | printer( 122 | self.device, 123 | f"Unlock all weights (params: {count_total_parameters(self.model) / 1e6:.2f}M) - Current epoch {self.start_epoch + 1}", 124 | ) 125 | ) 126 | 127 | def train_epoch( 128 | self, 129 | epoch: int, 130 | target: str, 131 | loss_weights: List[float], 132 | grad_clip: float = None, 133 | ): 134 | avg_loss = 0.0 135 | 136 | # load data from dataloader 137 | for i, obj in enumerate(self.train_dataloader): 138 | batch = Batch(device=self.device, target=target, vocab=self.vocab, obj=obj) 139 | 140 | with autograd.detect_anomaly(): 141 | loss, _ = batch.inference( 142 | self.model, 143 | criterion=self.criterion, 144 | criterion_bbox=self.criterion_bbox, 145 | loss_weights=loss_weights, 146 | ) 147 | 148 | total_loss = loss["total"] 149 | 150 | self.optimizer.zero_grad() 151 | total_loss.backward() 152 | if grad_clip: 153 | nn.utils.clip_grad_norm_( 154 | self.model.parameters(), max_norm=grad_clip 155 | ) 156 | self.optimizer.step() 157 | 158 | total_loss = total_loss.detach().cpu().data 159 | avg_loss += total_loss 160 | self.lr_scheduler.step() 161 | self.global_step += 1 162 | 163 | if i % 10 == 0: 164 | grad_norm = compute_grad_norm(self.model) 165 | lr = self.optimizer.param_groups[0]["lr"] 166 | # elapsed = time.time() - start 167 | 168 | loss_info = f"Loss {total_loss:.3f} ({avg_loss / (i + 1):.3f})" 169 | if not isinstance(loss["html"], int): 170 | loss_info += f" Html {loss['html'].detach().cpu().data:.3f}" 171 | if not isinstance(loss["cell"], int): 172 | loss_info += f" Cell {loss['cell'].detach().cpu().data:.3f}" 173 | if not isinstance(loss["bbox"], int): 174 | loss_info += f" Bbox {loss['bbox'].detach().cpu().data:.3f}" 175 | self.log.info( 176 | printer( 177 | self.device, 178 | f"Epoch {epoch} Step {i + 1}/{len(self.train_dataloader)} | {loss_info} | Grad norm {grad_norm:.3f} | lr {lr:5.1e}", 179 | ) 180 | ) 181 | 182 | if i % 100 == 0 and self.device == 0: 183 | log_info = { 184 | "epoch": epoch, 185 | "train_total_loss": total_loss, 186 | "learning rate": lr, 187 | "grad_norm": grad_norm, 188 | } 189 | 190 | wandb.log( 191 | log_info, 192 | step=self.global_step, 193 | ) 194 | 195 | def train( 196 | self, 197 | train_dataloader: DataLoader, 198 | valid_dataloader: DataLoader, 199 | train_cfg: DictConfig, 200 | valid_cfg: DictConfig, 
201 | ): 202 | self.train_dataloader = train_dataloader 203 | self.valid_dataloader = valid_dataloader 204 | 205 | # ensure correct weight decay: https://github.com/karpathy/minGPT/blob/37baab71b9abea1b76ab957409a1cc2fbfba8a26/mingpt/model.py#L215 206 | optim_params = configure_optimizer_weight_decay( 207 | self.model.module, weight_decay=train_cfg.optimizer.weight_decay 208 | ) 209 | 210 | self.optimizer = instantiate(train_cfg.optimizer, optim_params) 211 | 212 | self.lr_scheduler = instantiate( 213 | train_cfg.lr_scheduler, optimizer=self.optimizer 214 | ) 215 | 216 | if self.snapshot is not None: 217 | self.optimizer.load_state_dict(self.snapshot["OPTIMIZER"]) 218 | self.lr_scheduler.load_state_dict(self.snapshot["LR_SCHEDULER"]) 219 | 220 | self.criterion_bbox = None 221 | if "bbox" in train_cfg.target: 222 | tmp = [ 223 | self.vocab.token_to_id(i) 224 | for i in VALID_BBOX_TOKEN[ 225 | : train_cfg.img_size[0] + 2 226 | ] # +1 for +1 for bbox == img_size 227 | ] 228 | tmp = [1.0 if i in tmp else 0.0 for i in range(self.vocab.get_vocab_size())] 229 | self.criterion_bbox = nn.CrossEntropyLoss( 230 | weight=torch.tensor(tmp, device=self.device), 231 | ignore_index=self.padding_idx, 232 | ) 233 | 234 | best_loss = float("inf") 235 | self.model.train() 236 | 237 | if self.freeze_beit_epoch and self.start_epoch < self.freeze_beit_epoch: 238 | max_epoch = self.freeze_beit_epoch 239 | else: 240 | max_epoch = train_cfg.epochs 241 | for epoch in range(self.start_epoch, max_epoch): 242 | train_dataloader.sampler.set_epoch(epoch) 243 | 244 | self.train_epoch( 245 | epoch, 246 | grad_clip=train_cfg.grad_clip, 247 | target=train_cfg.target, 248 | loss_weights=train_cfg.loss_weights, 249 | ) 250 | 251 | torch.cuda.empty_cache() 252 | 253 | valid_loss = self.valid(valid_cfg) 254 | 255 | if self.device == 0: 256 | wandb.log( 257 | {"valid loss (epoch)": valid_loss}, 258 | step=self.global_step, 259 | ) 260 | 261 | if epoch % train_cfg.save_every == 0: 262 | self.save_snapshot(epoch, best_loss) 263 | if valid_loss < best_loss: 264 | self.save_model(epoch) 265 | best_loss = valid_loss 266 | 267 | def valid(self, cfg: DictConfig): 268 | total_loss = 0.0 269 | avg_loss = 0.0 270 | total_samples = 0 271 | 272 | self.model.eval() 273 | for i, obj in enumerate(self.valid_dataloader): 274 | batch = Batch( 275 | device=self.device, target=cfg.target, vocab=self.vocab, obj=obj 276 | ) 277 | with torch.no_grad(): 278 | loss, _ = batch.inference( 279 | self.model, 280 | criterion=self.criterion, 281 | criterion_bbox=self.criterion_bbox, 282 | loss_weights=cfg.loss_weights, 283 | ) 284 | 285 | total_loss = loss["total"] 286 | total_loss = total_loss.detach().cpu().data 287 | avg_loss += total_loss * batch.image.shape[0] 288 | total_samples += batch.image.shape[0] 289 | 290 | if i % 10 == 0: 291 | loss_info = f"Loss {total_loss:.3f} ({avg_loss / total_samples:.3f})" 292 | if not isinstance(loss["html"], int): 293 | loss_info += f" Html {loss['html'].detach().cpu().data:.3f}" 294 | if not isinstance(loss["cell"], int): 295 | loss_info += f" Cell {loss['cell'].detach().cpu().data:.3f}" 296 | if not isinstance(loss["bbox"], int): 297 | loss_info += f" Bbox {loss['bbox'].detach().cpu().data:.3f}" 298 | self.log.info( 299 | printer( 300 | self.device, 301 | f"Valid: Step {i + 1}/{len(self.valid_dataloader)} | {loss_info}", 302 | ) 303 | ) 304 | 305 | return avg_loss / total_samples 306 | 307 | def test(self, test_dataloader: DataLoader, cfg: DictConfig, save_to: str): 308 | total_result = dict() 309 | for i, obj in 
enumerate(test_dataloader): 310 | batch = Batch( 311 | device=self.device, target=cfg.target, vocab=self.vocab, obj=obj 312 | ) 313 | 314 | if cfg.target == "html": 315 | prefix = [self.vocab.token_to_id("[html]")] 316 | valid_token_whitelist = [ 317 | self.vocab.token_to_id(i) for i in VALID_HTML_TOKEN 318 | ] 319 | valid_token_blacklist = None 320 | elif cfg.target == "cell": 321 | prefix = [self.vocab.token_to_id("[cell]")] 322 | valid_token_whitelist = None 323 | valid_token_blacklist = [ 324 | self.vocab.token_to_id(i) for i in INVALID_CELL_TOKEN 325 | ] 326 | elif cfg.target == "bbox": 327 | prefix = [self.vocab.token_to_id("[bbox]")] 328 | valid_token_whitelist = [ 329 | self.vocab.token_to_id(i) 330 | for i in VALID_BBOX_TOKEN[: cfg.img_size[0]] 331 | ] 332 | valid_token_blacklist = None 333 | else: 334 | raise NotImplementedError 335 | 336 | pred_id = batch_autoregressive_decode( 337 | device=self.device, 338 | model=self.model, 339 | batch_data=batch, 340 | prefix=prefix, 341 | max_decode_len=cfg.max_seq_len, 342 | eos_id=self.vocab.token_to_id(""), 343 | valid_token_whitelist=valid_token_whitelist, 344 | valid_token_blacklist=valid_token_blacklist, 345 | sampling=cfg.sampling, 346 | ) 347 | 348 | if cfg.target == "html": 349 | result = combine_filename_pred_gt( 350 | filename=batch.name, 351 | pred_id=pred_id, 352 | gt_id=batch.html_tgt, 353 | vocab=self.vocab, 354 | type="html", 355 | ) 356 | elif cfg.target == "cell": 357 | result = combine_filename_pred_gt( 358 | filename=batch.name, 359 | pred_id=pred_id, 360 | gt_id=batch.cell_tgt, 361 | vocab=self.vocab, 362 | type="cell", 363 | ) 364 | elif cfg.target == "bbox": 365 | result = combine_filename_pred_gt( 366 | filename=batch.name, 367 | pred_id=pred_id, 368 | gt_id=batch.bbox_tgt, 369 | vocab=self.vocab, 370 | type="bbox", 371 | ) 372 | else: 373 | raise NotImplementedError 374 | 375 | total_result.update(result) 376 | 377 | if i % 10 == 0: 378 | self.log.info( 379 | printer( 380 | self.device, 381 | f"Test: Step {i + 1}/{len(test_dataloader)}", 382 | ) 383 | ) 384 | 385 | self.log.info( 386 | printer( 387 | self.device, 388 | f"Converting {len(total_result)} samples to html tables ...", 389 | ) 390 | ) 391 | 392 | with open( 393 | os.path.join(save_to, cfg.save_to_prefix + f"_{self.device}.json"), 394 | "w", 395 | encoding="utf-8", 396 | ) as f: 397 | json.dump(total_result, f, indent=4) 398 | 399 | return total_result 400 | 401 | def save_model(self, epoch: int): 402 | filename = Path(self.exp_dir) / "model" / f"epoch{epoch}_model.pt" 403 | torch.save(self.model.module.state_dict(), filename) 404 | self.log.info(printer(self.device, f"Saving model to {filename}")) 405 | filename = Path(self.exp_dir) / "model" / "best.pt" 406 | torch.save(self.model.module.state_dict(), filename) 407 | 408 | def load_model(self, path: Union[str, Path]): 409 | self.model.load_state_dict(torch.load(path, map_location="cpu")) 410 | self.log.info(printer(self.device, f"Loading model from {path}")) 411 | 412 | def save_snapshot(self, epoch: int, best_loss: float): 413 | state_info = { 414 | "EPOCH": epoch + 1, 415 | "STEP": self.global_step, 416 | "OPTIMIZER": self.optimizer.state_dict(), 417 | "LR_SCHEDULER": self.lr_scheduler.state_dict(), 418 | "MODEL": self.model.module.state_dict(), 419 | "LOSS": best_loss, 420 | } 421 | 422 | snapshot_path = Path(self.exp_dir) / "snapshot" / f"epoch{epoch}_snapshot.pt" 423 | torch.save(state_info, snapshot_path) 424 | 425 | self.log.info(printer(self.device, f"Saving snapshot to {snapshot_path}")) 426 | 
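# A snapshot bundles everything needed to resume training: EPOCH is saved as epoch + 1
# so a resumed run continues with the next epoch, STEP preserves the wandb logging step,
# and OPTIMIZER, LR_SCHEDULER, MODEL, and LOSS hold the remaining state
# (see SNAPSHOT_KEYS at the top of this file).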
427 | def load_snapshot(self, path: Path): 428 | self.log.info(printer(self.device, f"Loading snapshot from {path}")) 429 | snapshot = torch.load(path, map_location="cpu") 430 | assert SNAPSHOT_KEYS.issubset(snapshot.keys()) 431 | return snapshot 432 | 433 | def load_pretrained_beit(self, path: Path): 434 | self.log.info(printer(self.device, f"Loading pretrained BEiT from {path}")) 435 | beit = torch.load(path, map_location="cpu") 436 | redundant_keys_in_beit = [ 437 | "cls_token", 438 | "mask_token", 439 | "generator.weight", 440 | "generator.bias", 441 | ] 442 | for key in redundant_keys_in_beit: 443 | if key in beit: 444 | del beit[key] 445 | 446 | # max_seq_len in finetuning may go beyond the length in pretraining 447 | if ( 448 | self.model.pos_embed.embedding.weight.shape[0] 449 | != beit["pos_embed.embedding.weight"].shape[0] 450 | ): 451 | emb_shape = self.model.pos_embed.embedding.weight.shape 452 | ckpt_emb = beit["pos_embed.embedding.weight"].clone() 453 | assert emb_shape[1] == ckpt_emb.shape[1] 454 | 455 | ckpt_emb = ckpt_emb.unsqueeze(0).permute(0, 2, 1) 456 | ckpt_emb = F.interpolate(ckpt_emb, emb_shape[0], mode="nearest") 457 | beit["pos_embed.embedding.weight"] = ckpt_emb.permute(0, 2, 1).squeeze() 458 | 459 | out = self.model.load_state_dict(beit, strict=False) 460 | 461 | # ensure missing keys are just token_embed, decoder, and generator 462 | missing_keys_prefix = ("token_embed", "decoder", "generator") 463 | for key in out[0]: 464 | assert key.startswith( 465 | missing_keys_prefix 466 | ), f"Key {key} should be loaded from BEiT, but missing in current state dict." 467 | assert len(out[1]) == 0, f"Unexpected keys from BEiT: {out[1]}" 468 | -------------------------------------------------------------------------------- /src/trainer/train_vqvae.py: -------------------------------------------------------------------------------- 1 | import math 2 | import wandb 3 | from pathlib import Path 4 | from typing import Tuple, List, Union, Dict 5 | from omegaconf import DictConfig 6 | from hydra.utils import instantiate 7 | import logging 8 | import torch 9 | import time 10 | from functools import partial 11 | from torch import nn, Tensor, autograd 12 | from torch.utils.data import DataLoader 13 | from torch.optim import Adam 14 | from torch.nn.parallel import DistributedDataParallel as DDP 15 | import torch.distributed as dist 16 | from torchvision.utils import make_grid 17 | 18 | from src.utils import printer, compute_grad_norm 19 | 20 | SNAPSHOT_KEYS = set(["EPOCH", "STEP", "OPTIMIZER", "LR_SCHEDULER", "MODEL", "LOSS"]) 21 | 22 | 23 | class VqvaeTrainer: 24 | def __init__( 25 | self, 26 | device: int, 27 | model: nn.Module, 28 | log: logging.Logger, 29 | exp_dir: Path, 30 | snapshot: Path = None, 31 | model_weights: Path = None, # only for testing 32 | ) -> None: 33 | self.device = device 34 | self.log = log 35 | self.exp_dir = exp_dir 36 | assert ( 37 | snapshot is None or model_weights is None 38 | ), "Snapshot and model weights cannot be set at the same time." 
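# Same resume logic as the other trainers: continue from a snapshot, load bare weights
# for evaluation, or start fresh. Note that train() below resets global_step to 0, so
# the wandb step counter restarts even when resuming from a snapshot.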
39 | 40 | self.model = model 41 | if snapshot is not None and snapshot.is_file(): 42 | self.snapshot = self.load_snapshot(snapshot) 43 | self.model.load_state_dict(self.snapshot["MODEL"]) 44 | self.start_epoch = self.snapshot["EPOCH"] 45 | self.global_step = self.snapshot["STEP"] 46 | elif model_weights is not None and model_weights.is_file(): 47 | self.load_model(model_weights) 48 | else: 49 | self.snapshot = None 50 | self.start_epoch = 0 51 | 52 | self.model = self.model.to(device) 53 | self.model = DDP(self.model, device_ids=[device]) 54 | 55 | # https://discuss.pytorch.org/t/extra-10gb-memory-on-gpu-0-in-ddp-tutorial/118113 56 | torch.cuda.set_device(device) # master gpu takes up extra memory 57 | torch.cuda.empty_cache() 58 | 59 | def train_epoch( 60 | self, 61 | epoch: int, 62 | starting_temp: float, 63 | anneal_rate: float, 64 | temp_min: float, 65 | grad_clip: float = None, 66 | ): 67 | start = time.time() 68 | total_loss = 0.0 69 | total_samples = 0 70 | 71 | # load data from dataloader 72 | for i, obj in enumerate(self.train_dataloader): 73 | if isinstance(obj, Tensor): 74 | img = obj.to(self.device) 75 | elif isinstance(obj, (list, tuple)): 76 | img = obj[0].to(self.device) 77 | else: 78 | raise ValueError(f"Unrecognized object type {type(obj)}") 79 | 80 | # temperature annealing 81 | self.temp = max( 82 | starting_temp * math.exp(-anneal_rate * self.global_step), temp_min 83 | ) 84 | 85 | with autograd.detect_anomaly(): 86 | loss, soft_recons = self.model( 87 | img, return_loss=True, return_recons=True, temp=self.temp 88 | ) 89 | 90 | self.optimizer.zero_grad() 91 | loss.backward() 92 | if grad_clip: 93 | nn.utils.clip_grad_norm_( 94 | self.model.parameters(), max_norm=grad_clip 95 | ) 96 | self.optimizer.step() 97 | 98 | loss = loss.detach().cpu().data 99 | total_loss += loss * img.shape[0] 100 | total_samples += img.shape[0] 101 | 102 | self.lr_scheduler.step() 103 | self.global_step += 1 104 | 105 | if i % 10 == 0: 106 | grad_norm = compute_grad_norm(self.model) 107 | lr = self.optimizer.param_groups[0]["lr"] 108 | elapsed = time.time() - start 109 | self.log.info( 110 | printer( 111 | self.device, 112 | f"Epoch {epoch} Step {i + 1}/{len(self.train_dataloader)} | Loss {loss:.4f} ({total_loss / total_samples:.4f}) | Grad norm {grad_norm:.3f} | {total_samples / elapsed:4.1f} images/s | lr {lr:5.1e} | Temp {self.temp:.2e}", 113 | ) 114 | ) 115 | 116 | # visualize reconstruction images 117 | if i % 100 == 0 and self.device == 0: 118 | lr = self.optimizer.param_groups[0]["lr"] 119 | k = 4 # num of images saved for visualization 120 | codes = self.model.module.get_codebook_indices(img[:k]) 121 | hard_recons = self.model.module.decode(codes) 122 | 123 | img = img[:k].detach().cpu() 124 | soft_recons = soft_recons[:k].detach().cpu() 125 | codes = codes.flatten(start_dim=1).detach().cpu() 126 | hard_recons = hard_recons.detach().cpu() 127 | 128 | make_vis = partial(make_grid, nrow=int(math.sqrt(k)), normalize=True) 129 | img, soft_recons, hard_recons = map( 130 | make_vis, (img, soft_recons, hard_recons) 131 | ) 132 | 133 | log_info = { 134 | "epoch": epoch, 135 | "train_loss": loss, 136 | "temperature": self.temp, 137 | "learning rate": lr, 138 | "original images": wandb.Image( 139 | img, caption=f"step: {self.global_step}" 140 | ), 141 | "soft reconstruction": wandb.Image( 142 | soft_recons, caption=f"step: {self.global_step}" 143 | ), 144 | "hard reconstruction": wandb.Image( 145 | hard_recons, caption=f"step: {self.global_step}" 146 | ), 147 | "codebook_indices": 
wandb.Histogram(codes), 148 | } 149 | 150 | wandb.log( 151 | log_info, 152 | step=self.global_step, 153 | ) 154 | 155 | return total_loss, total_samples 156 | 157 | def train( 158 | self, 159 | train_dataloader: DataLoader, 160 | valid_dataloader: DataLoader, 161 | train_cfg: DictConfig, 162 | valid_cfg: DictConfig, 163 | ): 164 | self.train_dataloader = train_dataloader 165 | self.valid_dataloader = valid_dataloader 166 | self.optimizer = instantiate( 167 | train_cfg.optimizer, params=self.model.parameters() 168 | ) 169 | 170 | self.lr_scheduler = instantiate( 171 | train_cfg.lr_scheduler, optimizer=self.optimizer 172 | ) 173 | 174 | if self.snapshot is not None: 175 | self.optimizer.load_state_dict(self.snapshot["OPTIMIZER"]) 176 | self.lr_scheduler.load_state_dict(self.snapshot["LR_SCHEDULER"]) 177 | 178 | best_loss = float("inf") 179 | self.model.train() 180 | self.global_step = 0 181 | # self.temp = train_cfg.starting_temp 182 | for epoch in range(self.start_epoch, train_cfg.epochs): 183 | train_dataloader.sampler.set_epoch(epoch) 184 | epoch_loss, epoch_samples = self.train_epoch( 185 | epoch, 186 | starting_temp=train_cfg.starting_temp, 187 | anneal_rate=train_cfg.temp_anneal_rate, 188 | temp_min=train_cfg.temp_min, 189 | grad_clip=train_cfg.grad_clip, 190 | ) 191 | 192 | torch.cuda.empty_cache() 193 | 194 | valid_loss, valid_samples = self.valid(valid_cfg) 195 | 196 | # reduce loss to gpu 0 197 | training_info = torch.tensor( 198 | [epoch_loss, epoch_samples, valid_loss, valid_samples], 199 | device=self.device, 200 | ) 201 | 202 | dist.reduce( 203 | training_info, 204 | dst=0, 205 | op=dist.ReduceOp.SUM, 206 | ) 207 | 208 | if self.device == 0: 209 | grad_norm = compute_grad_norm(self.model) 210 | epoch_loss, epoch_samples, valid_loss, valid_samples = training_info 211 | epoch_loss, valid_loss = ( 212 | float(epoch_loss) / epoch_samples, 213 | float(valid_loss) / valid_samples, 214 | ) 215 | 216 | log_info = { 217 | "train loss (epoch)": epoch_loss, 218 | "valid loss (epoch)": valid_loss, 219 | "train_samples": epoch_samples, 220 | "valid_samples": valid_samples, 221 | "grad_norm": grad_norm, 222 | } 223 | 224 | wandb.log( 225 | log_info, 226 | step=self.global_step, 227 | ) 228 | 229 | if epoch % train_cfg.save_every == 0: 230 | self.save_snapshot(epoch, best_loss) 231 | if valid_loss < best_loss: 232 | self.save_model(epoch) 233 | best_loss = valid_loss 234 | 235 | def valid(self, cfg: DictConfig): 236 | total_samples = 0 237 | total_loss = 0.0 238 | 239 | self.model.eval() 240 | for i, obj in enumerate(self.valid_dataloader): 241 | if isinstance(obj, Tensor): 242 | img = obj.to(self.device) 243 | elif isinstance(obj, (list, tuple)): 244 | img = obj[0].to(self.device) 245 | else: 246 | raise ValueError(f"Unrecognized object type {type(obj)}") 247 | 248 | with torch.no_grad(): 249 | loss = self.model( 250 | img, return_loss=True, return_recons=False, temp=self.temp 251 | ) 252 | 253 | loss = loss.detach().cpu().data 254 | total_loss += loss * img.shape[0] 255 | total_samples += img.shape[0] 256 | 257 | if i % 10 == 0: 258 | self.log.info( 259 | printer( 260 | self.device, 261 | f"Valid: Step {i + 1}/{len(self.valid_dataloader)} | Loss {loss:.4f} ({total_loss / total_samples:.4f})", 262 | ) 263 | ) 264 | 265 | return total_loss, total_samples 266 | 267 | def save_model(self, epoch: int): 268 | filename = Path(self.exp_dir) / "model" / f"epoch{epoch}_model.pt" 269 | torch.save(self.model.module.state_dict(), filename) 270 | self.log.info(printer(self.device, f"Saving model to 
{filename}")) 271 | filename = Path(self.exp_dir) / "model" / f"best.pt" 272 | torch.save(self.model.module.state_dict(), filename) 273 | 274 | def load_model(self, path: Union[str, Path]): 275 | self.model.load_state_dict(torch.load(path, map_location="cpu")) 276 | self.log.info(printer(self.device, f"Loading model from {path}")) 277 | 278 | def save_snapshot(self, epoch: int, best_loss: float): 279 | state_info = { 280 | "EPOCH": epoch + 1, 281 | "STEP": self.global_step, 282 | "OPTIMIZER": self.optimizer.state_dict(), 283 | "LR_SCHEDULER": self.lr_scheduler.state_dict(), 284 | "MODEL": self.model.module.state_dict(), 285 | "LOSS": best_loss, 286 | } 287 | 288 | snapshot_path = Path(self.exp_dir) / "snapshot" / f"epoch{epoch}_snapshot.pt" 289 | torch.save(state_info, snapshot_path) 290 | 291 | self.log.info(printer(self.device, f"Saving snapshot to {snapshot_path}")) 292 | 293 | def load_snapshot(self, path: Path): 294 | self.log.info(printer(self.device, f"Loading snapshot from {path}")) 295 | snapshot = torch.load(path, map_location="cpu") 296 | assert SNAPSHOT_KEYS.issubset(snapshot.keys()) 297 | return snapshot 298 | -------------------------------------------------------------------------------- /src/trainer/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Dict 2 | import torch 3 | from torch import Tensor, nn 4 | from torchtext.vocab import Vocab 5 | import tokenizers as tk 6 | 7 | from src.utils import pred_token_within_range, subsequent_mask 8 | from src.vocab import ( 9 | HTML_TOKENS, 10 | TASK_TOKENS, 11 | RESERVED_TOKENS, 12 | BBOX_TOKENS, 13 | ) 14 | 15 | 16 | VALID_HTML_TOKEN = [""] + HTML_TOKENS 17 | INVALID_CELL_TOKEN = ( 18 | ["", "", "", ""] + TASK_TOKENS + RESERVED_TOKENS 19 | ) 20 | VALID_BBOX_TOKEN = [ 21 | "" 22 | ] + BBOX_TOKENS # image size will be addressed after instantiation 23 | 24 | 25 | class Batch: 26 | """Wrap up a batch of training samples with different training targets. 
27 | The input is not torch tensor 28 | Shape of the image (src): B, S, E 29 | Shape of the text (tgt): B, N, S, E (M includes 1 table detection, 1 structure, 1 cell, and multiple bbox) 30 | Reshape text to (B * N, S, E) and inflate the image to match the shape of the text 31 | 32 | Args: 33 | ---- 34 | device: gpu id 35 | """ 36 | 37 | def __init__( 38 | self, 39 | device: torch.device, 40 | target: str, 41 | vocab: Vocab, 42 | obj: List, 43 | ) -> None: 44 | self.device = device 45 | self.image = obj[0].to(device) 46 | self.name = obj[1]["filename"] 47 | self.target = target 48 | self.vocab = vocab 49 | self.image_size = self.image.shape[-1] 50 | 51 | if "table" in target: 52 | raise NotImplementedError 53 | 54 | if "html" in target: 55 | self.valid_html_token = [vocab.token_to_id(i) for i in VALID_HTML_TOKEN] 56 | ( 57 | self.html_src, 58 | self.html_tgt, 59 | self.html_casual_mask, 60 | self.html_padding_mask, 61 | ) = self._prepare_transformer_input(obj[1]["html"]) 62 | 63 | if "cell" in target: 64 | self.invalid_cell_token = [vocab.token_to_id(i) for i in INVALID_CELL_TOKEN] 65 | ( 66 | self.cell_src, 67 | self.cell_tgt, 68 | self.cell_casual_mask, 69 | self.cell_padding_mask, 70 | ) = self._prepare_transformer_input(obj[1]["cell"]) 71 | 72 | if "bbox" in target: 73 | ( 74 | self.bbox_src, 75 | self.bbox_tgt, 76 | self.bbox_casual_mask, 77 | self.bbox_padding_mask, 78 | ) = self._prepare_transformer_input(obj[1]["bbox"]) 79 | 80 | def _prepare_transformer_input( 81 | self, seq: List[tk.Encoding] 82 | ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: 83 | tmp = [i.ids for i in seq] 84 | tmp = torch.tensor(tmp, dtype=torch.int32) 85 | src = tmp[:, :-1].to(self.device) 86 | tgt = tmp[:, 1:].type(torch.LongTensor).to(self.device) 87 | casual_mask = subsequent_mask(src.shape[-1]).to(self.device) 88 | tmp = [i.attention_mask[:-1] for i in seq] # padding mask 89 | tmp = torch.tensor(tmp, dtype=torch.bool) 90 | padding_mask = (~tmp).to(self.device) 91 | 92 | return src, tgt, casual_mask, padding_mask 93 | 94 | def _inference_one_task( 95 | self, model, memory, src, casual_mask, padding_mask, use_ddp 96 | ): 97 | if use_ddp: 98 | out = model.module.decode(memory, src, casual_mask, padding_mask) 99 | out = model.module.generator(out) 100 | else: 101 | out = model.decode(memory, src, casual_mask, padding_mask) 102 | out = model.generator(out) 103 | 104 | return out 105 | 106 | def inference( 107 | self, 108 | model: nn.Module, 109 | criterion: nn.Module, 110 | criterion_bbox: nn.Module = None, 111 | loss_weights: dict = None, 112 | use_ddp: bool = True, 113 | ) -> Tuple[Dict, Dict]: 114 | pred = dict() 115 | loss = dict(table=0, html=0, cell=0, bbox=0) 116 | 117 | if use_ddp: 118 | memory = model.module.encode(self.image) 119 | else: 120 | memory = model.encode(self.image) 121 | 122 | # inference + suppress invalid logits + compute loss 123 | if "html" in self.target: 124 | out_html = self._inference_one_task( 125 | model, 126 | memory, 127 | self.html_src, 128 | self.html_casual_mask, 129 | self.html_padding_mask, 130 | use_ddp, 131 | ) 132 | 133 | pred["html"] = pred_token_within_range( 134 | out_html, white_list=self.valid_html_token 135 | ).permute(0, 2, 1) 136 | loss["html"] = criterion(pred["html"], self.html_tgt) 137 | 138 | if "cell" in self.target: 139 | out_cell = self._inference_one_task( 140 | model, 141 | memory, 142 | self.cell_src, 143 | self.cell_casual_mask, 144 | self.cell_padding_mask, 145 | use_ddp, 146 | ) 147 | 148 | pred["cell"] = pred_token_within_range( 149 | out_cell, 
black_list=self.invalid_cell_token 150 | ).permute(0, 2, 1) 151 | loss["cell"] = criterion(pred["cell"], self.cell_tgt) 152 | 153 | if "bbox" in self.target: 154 | assert criterion_bbox is not None 155 | 156 | out_bbox = self._inference_one_task( 157 | model, 158 | memory, 159 | self.bbox_src, 160 | self.bbox_casual_mask, 161 | self.bbox_padding_mask, 162 | use_ddp, 163 | ) 164 | pred["bbox"] = out_bbox.permute(0, 2, 1) 165 | loss["bbox"] = criterion_bbox(pred["bbox"], self.bbox_tgt) 166 | 167 | total = 0.0 168 | for k, v in loss_weights.items(): 169 | total += loss[k] * v 170 | loss["total"] = total 171 | 172 | return loss, pred 173 | 174 | 175 | def configure_optimizer_weight_decay( 176 | model: nn.Module, weight_decay: float 177 | ) -> List[Dict]: 178 | weight_decay_blacklist = (nn.LayerNorm, nn.BatchNorm2d, nn.Embedding) 179 | 180 | if hasattr(model, "no_weight_decay"): 181 | skip_list = model.no_weight_decay() 182 | decay = set() 183 | no_decay = set() 184 | for mn, m in model.named_modules(): 185 | for pn, p in m.named_parameters(): 186 | fpn = "%s.%s" % (mn, pn) if mn else pn # full param name 187 | if pn.endswith("bias"): 188 | no_decay.add(fpn) 189 | elif pn.endswith("weight") and isinstance(m, weight_decay_blacklist): 190 | no_decay.add(fpn) 191 | elif pn in skip_list: 192 | no_decay.add(fpn) 193 | 194 | param_dict = {pn: p for pn, p in model.named_parameters()} 195 | decay = param_dict.keys() - no_decay 196 | 197 | optim_groups = [ 198 | { 199 | "params": [param_dict[pn] for pn in sorted(list(decay))], 200 | "weight_decay": weight_decay, 201 | }, 202 | { 203 | "params": [param_dict[pn] for pn in sorted(list(no_decay))], 204 | "weight_decay": 0.0, 205 | }, 206 | ] 207 | 208 | return optim_groups 209 | 210 | 211 | def turn_off_beit_grad(model: nn.Module): 212 | "Freeze BEiT pretrained weights." 
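# Freezing only covers the SSL-pretrained components: the transformer encoder, the
# patch-embedding backbone, and the positional embedding. The decoder, token embedding,
# and output generator remain trainable.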
213 | for param in model.encoder.parameters(): 214 | param.requires_grad = False 215 | 216 | for param in model.backbone.parameters(): 217 | param.requires_grad = False 218 | 219 | for param in model.pos_embed.parameters(): 220 | param.requires_grad = False 221 | 222 | 223 | def turn_on_beit_grad(model: nn.Module): 224 | for param in model.parameters(): 225 | param.requires_grad = True 226 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .visualization import * 2 | from .data import * 3 | from .mask_generator import * 4 | from .misc import * 5 | -------------------------------------------------------------------------------- /src/utils/coco_map.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchmetrics.detection import MeanAveragePrecision 3 | from pprint import pprint 4 | 5 | 6 | def compute_coco_map(file): 7 | coco_pred = list() 8 | coco_gt = list() 9 | for _, obj in file.items(): 10 | tmp_pred = { 11 | "boxes": torch.tensor(obj["pred"], device=0), 12 | "labels": torch.tensor([0] * len(obj["pred"]), device=0), 13 | "scores": torch.tensor([0.999] * len(obj["pred"]), device=0), 14 | } 15 | 16 | tmp_gt = { 17 | "boxes": torch.tensor(obj["gt"], device=0), 18 | "labels": torch.tensor([0] * len(obj["gt"]), device=0), 19 | } 20 | 21 | coco_pred.append(tmp_pred) 22 | coco_gt.append(tmp_gt) 23 | 24 | metric = MeanAveragePrecision( 25 | iou_type="bbox", 26 | max_detection_thresholds=[1, 10, 1000], 27 | backend="faster_coco_eval", 28 | ) 29 | metric.update(coco_pred, coco_gt) 30 | pprint(metric.compute()) 31 | 32 | 33 | if __name__ == "__main__": 34 | import json 35 | import argparse 36 | 37 | parser = argparse.ArgumentParser(description="mAP Computation") 38 | 39 | parser.add_argument("-f", "--file", help="path to html table results in json file") 40 | args = parser.parse_args() 41 | 42 | 43 | results_file = args.file 44 | with open(results_file, "r") as f: 45 | results_json = json.load(f) 46 | 47 | compute_coco_map(results_json) 48 | -------------------------------------------------------------------------------- /src/utils/data.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | import random 3 | import tokenizers as tk 4 | import torch 5 | from torch import Tensor, nn 6 | import torch.nn.functional as F 7 | 8 | from src.vocab import TASK_TOKENS, CELL_SPECIAL 9 | from src.model.encoderdecoder import EncoderDecoder 10 | from .misc import html_table_template 11 | 12 | __all__ = [ 13 | "subsequent_mask", 14 | "combine_cell_char_seq", 15 | "random_continuous_sequence", 16 | "prepare_html_seq", 17 | "prepare_cell_seq", 18 | "prepare_bbox_seq", 19 | "html_str_to_token_list", 20 | "cell_str_to_token_list", 21 | "bbox_str_to_token_list", 22 | "build_table_from_html_and_cell", 23 | "pred_token_within_range", 24 | "batch_autoregressive_decode", 25 | "greedy_sampling", 26 | "combine_filename_pred_gt", 27 | ] 28 | 29 | 30 | def subsequent_mask(size: int, pad: int = 0): 31 | attn_shape = (size, size) 32 | output = torch.triu(torch.ones(attn_shape), diagonal=1).to(torch.bool) 33 | if pad and pad > 0: 34 | output[:pad] = False 35 | return output 36 | 37 | 38 | def combine_cell_char_seq(seq: List[str]) -> str: 39 | """Replace empty token with in vocab. 
combine characters into a str""" 40 | if seq: 41 | out = "".join(seq) 42 | else: 43 | out = "" 44 | return out 45 | 46 | 47 | def prepare_html_seq(seq: List[str]) -> List[str]: 48 | """Convert html annotations to html training template.""" 49 | out = ["[html]", *seq, ""] 50 | return out 51 | 52 | 53 | def prepare_cell_seq(seq: str) -> List[str]: 54 | """Convert cell sequence to training template.""" 55 | for black in CELL_SPECIAL: 56 | seq = seq.replace(black, "") 57 | out = ["[cell]", seq, ""] 58 | 59 | return out 60 | 61 | 62 | def prepare_bbox_seq(seq: List[dict]): 63 | tmp = [f"bbox-{round(i)}" for i in seq] 64 | out = ["[bbox]"] + tmp + [""] 65 | 66 | return out 67 | 68 | 69 | def random_continuous_sequence(seq: List, N: int, length: int = 10) -> List: 70 | """Randomly sample a continuous sub-sequence from a sequence for N times.""" 71 | start_idx = [random.randrange(len(seq)) for _ in range(N)] 72 | subseq_len = [random.randrange(1, length) for _ in range(N)] 73 | output = [(i, min(i + j, len(seq))) for i, j in zip(start_idx, subseq_len)] 74 | 75 | return output 76 | 77 | 78 | # def prepare_bbox_seq( 79 | # seq: List[dict], 80 | # N: int, 81 | # delimiter: str = "", 82 | # ) -> List[List[str]]: 83 | # """Convert the annotation to bbox input/output sequence.""" 84 | # out = list() 85 | # # bbox_loss_start_idx = list() 86 | 87 | # subseq_idx = random_continuous_sequence(seq, N) 88 | 89 | # for idx in subseq_idx: 90 | # entry = seq[idx[0] : idx[1]] 91 | # tmp = list() 92 | # bbox_seq = list() 93 | # for i in entry: 94 | # if "tokens" in i.keys(): 95 | # # pubtabnet and synthtabnet 96 | # tmp.append(combine_cell_char_seq(i["tokens"])) 97 | # if "bbox" in i.keys(): 98 | # bbox_seq.extend([f"bbox-{round(j)}" for j in i["bbox"]]) 99 | # elif "text" in i.keys(): 100 | # # pubtables and icdar 101 | # tmp.append(i["text"]) 102 | # if "bbox" in i.keys(): 103 | # bbox_seq.extend([f"bbox-{round(j)}" for j in i["bbox"]]) 104 | 105 | # cell_seq = [delimiter] * len(tmp) 106 | # cell_seq = [q for pair in zip(tmp, cell_seq) for q in pair] 107 | # cell_seq = ["[bbox]", f"{len(entry)}-cell(s)", delimiter] + cell_seq 108 | 109 | # bbox_seq.append("") 110 | # # bbox_loss_start_idx.append(len(cell_seq)) 111 | # out.append(cell_seq + bbox_seq) 112 | 113 | # return out 114 | 115 | 116 | def html_str_to_token_list( 117 | seq: str, splitter: tk.pre_tokenizers.PreTokenizer = None 118 | ) -> List[str]: 119 | """Convert decode output (str) to a list of tokens for constructing html table code""" 120 | 121 | # works for no 122 | seq = seq.split("")[0] 123 | 124 | token_black_list = ["", "", *TASK_TOKENS] 125 | for i in token_black_list: 126 | seq = seq.replace(i, "") 127 | 128 | if not splitter: 129 | splitter = tk.pre_tokenizers.Split(pattern=" ", behavior="contiguous") 130 | 131 | seq = splitter.pre_tokenize_str(seq) 132 | # only preserve the space for spanning cell tokens 133 | seq = [i[0] for i in seq if len(i[0].strip()) != 0 or i[1][1] - i[1][0] != 1] 134 | 135 | return seq 136 | 137 | 138 | def cell_str_to_token_list(seq: str) -> List[str]: 139 | seq = seq.split("")[0] 140 | 141 | token_black_list = ["", "", *TASK_TOKENS] 142 | for i in token_black_list: 143 | seq = seq.replace(i, "") 144 | 145 | seq = seq.strip() 146 | 147 | return seq 148 | 149 | 150 | def build_table_from_html_and_cell( 151 | structure: List[str], content: List[str] = None 152 | ) -> List[str]: 153 | """Build table from html and cell token list""" 154 | assert structure is not None 155 | html_code = list() 156 | 157 | # deal with empty 
table 158 | if content is None: 159 | content = ["placeholder"] * len(structure) 160 | 161 | for tag in structure: 162 | if tag in ("[]", ">[]"): 163 | if len(content) == 0: 164 | continue 165 | cell = content.pop(0) 166 | html_code.append(tag.replace("[]", cell)) 167 | else: 168 | html_code.append(tag) 169 | 170 | return html_code 171 | 172 | 173 | def bbox_str_to_token_list( 174 | seq: str, splitter: tk.pre_tokenizers.PreTokenizer = None 175 | ) -> List[List[int]]: 176 | """ 177 | Note the out could be an empty list 178 | 179 | return 180 | [[ymin, xmin, ymax, xmax], 181 | [ymin, xmin, ymax, xmax], 182 | ... 183 | ] 184 | """ 185 | 186 | seq = seq.split("")[0] 187 | 188 | token_black_list = ["", "", *TASK_TOKENS] 189 | for i in token_black_list: 190 | seq = seq.replace(i, "") 191 | 192 | if not splitter: 193 | splitter = tk.pre_tokenizers.Split(pattern=" ", behavior="removed") 194 | 195 | seq = splitter.pre_tokenize_str(seq) 196 | seq = [int(i[0].split("-")[1]) for i in seq] 197 | 198 | rounded_seq_len = len(seq) // 4 * 4 199 | out = [seq[i : i + 4] for i in range(0, rounded_seq_len, 4)] 200 | return out 201 | 202 | 203 | def pred_token_within_range( 204 | pred: Tensor, 205 | white_list: List[int] = None, 206 | black_list: List[int] = None, 207 | ) -> Tensor: 208 | assert white_list is None or black_list is None 209 | if white_list: 210 | total = set([i for i in range(pred.shape[-1])]) 211 | black_list = list(total.difference(set(white_list))) 212 | 213 | pred[..., black_list] = -float("inf") 214 | 215 | return pred 216 | 217 | 218 | def greedy_sampling(logits: Tensor): 219 | """logits should have shape [B, |V|].""" 220 | probs = F.softmax(logits, dim=-1) 221 | next_probs, next_tokens = probs.topk(1) 222 | 223 | return next_probs, next_tokens 224 | 225 | 226 | def batch_autoregressive_decode( 227 | device: int, 228 | model: EncoderDecoder, 229 | batch_data, 230 | prefix: List[int], 231 | max_decode_len: int, 232 | eos_id: int, 233 | valid_token_whitelist: List[int] = None, 234 | valid_token_blacklist: List[int] = None, 235 | sampling: str = "greedy", 236 | use_ddp: bool = True, 237 | ) -> Tensor: 238 | """Auto-regressively generate the output.""" 239 | 240 | model.eval() 241 | with torch.no_grad(): 242 | if use_ddp: 243 | memory = model.module.encode(batch_data.image) 244 | else: 245 | memory = model.encode(batch_data.image) 246 | 247 | B = batch_data.image.shape[0] 248 | 249 | context = torch.tensor(prefix, dtype=torch.int32).repeat(B, 1).to(device) 250 | 251 | for _ in range(max_decode_len): 252 | eos_flag = [eos_id in k for k in context] 253 | if all(eos_flag): 254 | break 255 | 256 | # as long as one sample hasn't reached , continue decoding until the max seq len 257 | causal_mask = subsequent_mask(context.shape[1]).to(device) 258 | 259 | with torch.no_grad(): 260 | if use_ddp: 261 | logits = model.module.decode( 262 | memory, context, tgt_mask=causal_mask, tgt_padding_mask=None 263 | ) 264 | logits = model.module.generator(logits)[:, -1, :] 265 | else: 266 | logits = model.decode( 267 | memory, context, tgt_mask=causal_mask, tgt_padding_mask=None 268 | ) 269 | logits = model.generator(logits)[:, -1, :] 270 | 271 | logits = pred_token_within_range( 272 | logits.detach(), 273 | white_list=valid_token_whitelist if valid_token_whitelist else None, 274 | black_list=valid_token_blacklist if valid_token_blacklist else None, 275 | ) 276 | 277 | if sampling == "greedy": 278 | next_probs, next_tokens = greedy_sampling(logits) 279 | else: 280 | raise NotImplementedError 281 | 282 | context = 
torch.cat([context, next_tokens], dim=1) 283 | 284 | return context 285 | 286 | 287 | def combine_filename_pred_gt( 288 | filename: List[str], pred_id: Tensor, gt_id: Tensor, vocab: tk.Tokenizer, type: str 289 | ) -> dict: 290 | out = dict() 291 | 292 | assert len(filename) == len(pred_id) 293 | 294 | pred_id = pred_id.detach().cpu().numpy() 295 | gt_id = gt_id.detach().cpu().numpy() 296 | 297 | pred_token = vocab.decode_batch(pred_id, skip_special_tokens=False) 298 | gt_token = vocab.decode_batch(gt_id, skip_special_tokens=False) 299 | 300 | for idx, name in enumerate(filename): 301 | if type == "html": 302 | pred_token_list = html_str_to_token_list(pred_token[idx]) 303 | gt_token_list = html_str_to_token_list(gt_token[idx]) 304 | elif type == "cell": 305 | pred_token_list = cell_str_to_token_list(pred_token[idx]) 306 | gt_token_list = cell_str_to_token_list(gt_token[idx]) 307 | elif type == "bbox": 308 | pred_token_list = bbox_str_to_token_list(pred_token[idx]) 309 | gt_token_list = bbox_str_to_token_list(gt_token[idx]) 310 | else: 311 | raise ValueError( 312 | f"The supported tasks are html, cell and bbox, while {type} is provided." 313 | ) 314 | 315 | out[name] = dict(pred=pred_token_list, gt=gt_token_list) 316 | 317 | return out 318 | -------------------------------------------------------------------------------- /src/utils/engine.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | from pathlib import Path 5 | import glob 6 | 7 | from src.utils import build_table_from_html_and_cell, html_table_template 8 | 9 | 10 | def combine_all_json(file_dir: str) -> dict: 11 | total_result = dict() 12 | files = os.listdir(file_dir) 13 | try: 14 | files.remove("final.json") 15 | except ValueError: 16 | pass 17 | for file in files: 18 | with open(os.path.join(file_dir, file), "r") as f: 19 | result = json.load(f) 20 | total_result.update(result) 21 | 22 | print(f"Combined to a json with {len(total_result)} entries.") 23 | 24 | return total_result 25 | 26 | 27 | def json_to_final(file_dir: str, type: str): 28 | if type == "html" or type == "bbox": 29 | result = combine_all_json(file_dir) 30 | elif type == "html+cell": 31 | result_cell = combine_all_json(file_dir) 32 | result_html_file = os.path.join( 33 | Path(file_dir).parent, 34 | Path(file_dir).name.split("-")[0].replace("cell", "html") + "-html", 35 | ) 36 | assert Path(result_html_file).is_dir(), f"{result_html_file} does not exist." 
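# For "html+cell", the per-cell predictions in file_dir are paired with the table-structure
# predictions read from the sibling "-html" results directory resolved above, then merged
# file by file below.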
37 | result = combine_all_json(result_html_file) 38 | assert len(result) == len(result_cell) 39 | else: 40 | # assert html and cell json files have the same length 41 | raise NotImplementedError 42 | 43 | out = dict() 44 | 45 | if type == "bbox": 46 | out = result 47 | else: 48 | for filename, obj in result.items(): 49 | if type == "html": 50 | pred_html = "".join(obj["pred"]) 51 | gt_html = "".join(obj["gt"]) 52 | 53 | out[filename] = dict( 54 | pred=html_table_template(pred_html), gt=html_table_template(gt_html) 55 | ) 56 | elif type == "html+cell": 57 | pred_html_cell = build_table_from_html_and_cell( 58 | obj["pred"], result_cell[filename]["pred"] 59 | ) 60 | gt_html_cell = build_table_from_html_and_cell( 61 | obj["gt"], result_cell[filename]["gt"] 62 | ) 63 | out[filename] = dict( 64 | pred=html_table_template(pred_html_cell), 65 | gt=html_table_template(gt_html_cell), 66 | ) 67 | else: 68 | raise NotImplementedError 69 | 70 | # write to file 71 | with open(os.path.join(file_dir, "final.json"), "w", encoding="utf-8") as f: 72 | json.dump(out, f, indent=4) 73 | 74 | 75 | if __name__ == "__main__": 76 | parser = argparse.ArgumentParser(description="postprocess") 77 | 78 | parser.add_argument( 79 | "-f", "--file", help="path to all json files from different devices" 80 | ) 81 | parser.add_argument("-t", "--type", help="html, html+cell") 82 | args = parser.parse_args() 83 | 84 | json_to_final(args.file, args.type) 85 | -------------------------------------------------------------------------------- /src/utils/mask_generator.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | from typing import Any 4 | import numpy as np 5 | 6 | """ 7 | Code adapted from beit mask generator: https://github.com/microsoft/unilm/blob/ecff36188001e9b12a90b01bbbaf9058d2b8bda6/beit/masking_generator.py .
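MaskGenerator produces a block-wise patch mask for BEiT-style masked image modeling: rectangular regions with random area and aspect ratio are sampled on the (height, width) patch grid until roughly num_mask_patches patches are covered, or no further block can be placed.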
8 | """ 9 | 10 | __all__ = ["MaskGenerator"] 11 | 12 | 13 | class MaskGenerator: 14 | def __init__( 15 | self, 16 | input_size: int, 17 | num_mask_patches: int, 18 | min_num_patches: int = 4, 19 | max_num_patches: int = None, 20 | min_aspect: float = 0.3, 21 | max_aspect: float = None, 22 | ) -> None: 23 | if not isinstance(input_size, tuple): 24 | input_size = (input_size,) * 2 25 | self.height, self.width = input_size 26 | 27 | self.num_patches = self.height * self.width 28 | 29 | self.num_mask_patches = num_mask_patches 30 | 31 | self.min_num_patches = min_num_patches 32 | self.max_num_patches = ( 33 | num_mask_patches if max_num_patches is None else max_num_patches 34 | ) 35 | 36 | max_aspect = max_aspect or 1 / min_aspect 37 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 38 | 39 | def __repr__(self): 40 | repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % ( 41 | self.height, 42 | self.width, 43 | self.min_num_patches, 44 | self.max_num_patches, 45 | self.num_mask_patches, 46 | self.log_aspect_ratio[0], 47 | self.log_aspect_ratio[1], 48 | ) 49 | return repr_str 50 | 51 | def get_shape(self): 52 | return self.height, self.width 53 | 54 | def _mask(self, mask: np.array, max_mask_patches: int) -> int: 55 | delta = 0 56 | for _ in range(10): 57 | target_area = random.uniform(self.min_num_patches, max_mask_patches) 58 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 59 | h = int(round(math.sqrt(target_area * aspect_ratio))) 60 | w = int(round(math.sqrt(target_area / aspect_ratio))) 61 | if w < self.width and h < self.height: 62 | top = random.randint(0, self.height - h) 63 | left = random.randint(0, self.width - w) 64 | 65 | num_masked = mask[top : top + h, left : left + w].sum() 66 | if 0 < h * w - num_masked <= max_mask_patches: 67 | for i in range(top, top + h): 68 | for j in range(left, left + w): 69 | if mask[i, j] == 0: 70 | mask[i, j] = 1 71 | delta += 1 72 | if delta > 0: 73 | break 74 | return delta 75 | 76 | def __call__(self) -> Any: 77 | mask = np.zeros((self.height, self.width), dtype=np.int32) 78 | mask_count = 0 79 | while mask_count < self.num_mask_patches: 80 | max_mask_patches = self.num_mask_patches - mask_count 81 | max_mask_patches = min(max_mask_patches, self.max_num_patches) 82 | 83 | delta = self._mask(mask, max_mask_patches) 84 | if delta == 0: 85 | break 86 | else: 87 | mask_count += delta 88 | 89 | return mask 90 | 91 | 92 | if __name__ == "__main__": 93 | mg = MaskGenerator(input_size=14, num_mask_patches=75) 94 | mask = mg() 95 | print(mask) 96 | print(mg, mask.sum()) 97 | -------------------------------------------------------------------------------- /src/utils/misc.py: -------------------------------------------------------------------------------- 1 | import math 2 | import jsonlines 3 | from pathlib import Path 4 | from typing import Dict, Tuple, List, Union 5 | from torch import Tensor, nn 6 | 7 | __all__ = [ 8 | "cosine_schedule_with_warmup", 9 | "load_json_annotations", 10 | "bbox_augmentation_resize", 11 | "count_total_parameters", 12 | "compute_grad_norm", 13 | "printer", 14 | "html_table_template", 15 | ] 16 | 17 | printer = lambda device, output: f"[GPU {device}] " + output 18 | 19 | html_table_template = ( 20 | lambda table: f""" 21 | 22 | 28 | 29 | 30 | {table} 31 |
""" 32 | ) 33 | 34 | 35 | # adpated from https://github.com/huggingface/transformers/blob/v4.33.0/src/transformers/optimization.py 36 | def cosine_schedule_with_warmup( 37 | step: int, 38 | *, 39 | warmup: int, 40 | min_ratio: float, 41 | total_step: int, 42 | cycle: float = 0.5, 43 | ): 44 | if step < warmup: 45 | if step == 0: 46 | step = 1 47 | return float(step) / float(max(1, warmup)) 48 | 49 | if step >= total_step: 50 | step = total_step 51 | progress = float(step - warmup) / float(max(1, total_step - warmup)) 52 | return max( 53 | min_ratio, 0.5 * (1.0 + math.cos(math.pi * float(cycle) * 2.0 * progress)) 54 | ) 55 | 56 | 57 | def load_json_annotations(json_file_dir: Path, split: str): 58 | """Preprocess jsonl in dataset.""" 59 | image_label_pair = list() 60 | with jsonlines.open(json_file_dir) as f: 61 | for obj in f: 62 | if obj["split"] == split: 63 | image_label_pair.append((obj["filename"], obj["html"])) 64 | 65 | return image_label_pair 66 | 67 | 68 | def bbox_augmentation_resize( 69 | bbox: List[int], image_size: List[int], target_size: int 70 | ) -> List[int]: 71 | """Modify the bbox coordinates according to the image resizing.""" 72 | # Assuming the bbox is [xmin, ymin, xmax, ymax] 73 | assert len(image_size) == 2 74 | ratio = [target_size / i for i in image_size] 75 | ratio = ratio * 2 76 | bbox = [int(round(i * j)) for i, j in zip(bbox, ratio)] 77 | return bbox 78 | 79 | 80 | def count_total_parameters(model: nn.Module) -> int: 81 | """Count total parameters that need training.""" 82 | total_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 83 | return total_parameters 84 | 85 | 86 | def compute_grad_norm(model: nn.Module) -> float: 87 | total_norm = 0.0 88 | for p in model.parameters(): 89 | if p.grad is not None and p.requires_grad: 90 | param_norm = p.grad.detach().data.norm(2) 91 | total_norm += param_norm.item() ** 2 92 | total_norm = total_norm**0.5 93 | return total_norm 94 | -------------------------------------------------------------------------------- /src/utils/teds.py: -------------------------------------------------------------------------------- 1 | # code adapted from https://github.com/ibm-aur-nlp/PubTabNet/blob/master/src/metric.py 2 | # tree edit distance video explanation: https://www.youtube.com/watch?v=6Ur8B35xCj8 3 | import apted 4 | import distance 5 | from collections import deque 6 | from lxml import etree, html 7 | from tqdm import tqdm 8 | from concurrent.futures import ProcessPoolExecutor, as_completed 9 | from typing import Tuple 10 | 11 | 12 | class TableTree(apted.helpers.Tree): 13 | def __init__(self, tag, colspan=None, rowspan=None, content=None, *children): 14 | self.tag = tag 15 | self.colspan = colspan 16 | self.rowspan = rowspan 17 | self.content = content 18 | self.children = list(children) 19 | 20 | def bracket(self): 21 | """Show tree using brackets notation.""" 22 | if self.tag == "td": 23 | result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % ( 24 | self.tag, 25 | self.colspan, 26 | self.rowspan, 27 | self.content, 28 | ) 29 | else: 30 | result = '"tag": %s' % self.tag 31 | for child in self.children: 32 | result += child.bracket() 33 | return "{{{}}}".format(result) 34 | 35 | 36 | class CustomConfig(apted.Config): 37 | @staticmethod 38 | def maximum(*sequences): 39 | """Get maximum possible value.""" 40 | return max(map(len, sequences)) 41 | 42 | def normalized_distance(self, *sequences): 43 | """Get distance from 0 to 1.""" 44 | return float(distance.levenshtein(*sequences)) / 
self.maximum(*sequences) 45 | 46 | def rename(self, node1, node2): 47 | """Compares attributes of trees""" 48 | if ( 49 | (node1.tag != node2.tag) 50 | or (node1.colspan != node2.colspan) 51 | or (node1.rowspan != node2.rowspan) 52 | ): 53 | return 1.0 54 | if node1.tag == "td": 55 | if node1.content or node2.content: 56 | return self.normalized_distance(node1.content, node2.content) 57 | return 0.0 58 | 59 | 60 | class TEDS(object): 61 | """Tree Edit Distance based Similarity""" 62 | 63 | def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None): 64 | assert isinstance(n_jobs, int) and ( 65 | n_jobs >= 1 66 | ), "n_jobs must be an integer greater than or equal to 1" 67 | self.structure_only = structure_only 68 | self.n_jobs = n_jobs 69 | self.ignore_nodes = ignore_nodes 70 | self.__tokens__ = [] 71 | 72 | def tokenize(self, node): 73 | """Tokenizes table cells""" 74 | self.__tokens__.append("<%s>" % node.tag) 75 | if node.text is not None: 76 | self.__tokens__ += list(node.text) 77 | for n in node.getchildren(): 78 | self.tokenize(n) 79 | if node.tag != "unk": 80 | self.__tokens__.append("</%s>" % node.tag) 81 | if node.tag != "td" and node.tail is not None: 82 | self.__tokens__ += list(node.tail) 83 | 84 | def load_html_tree(self, node, parent=None): 85 | """Converts HTML tree to the format required by apted""" 86 | global __tokens__ 87 | if node.tag == "td": 88 | if self.structure_only: 89 | cell = [] 90 | else: 91 | self.__tokens__ = [] 92 | self.tokenize(node) 93 | cell = self.__tokens__[1:-1].copy() 94 | new_node = TableTree( 95 | node.tag, 96 | int(node.attrib.get("colspan", "1")), 97 | int(node.attrib.get("rowspan", "1")), 98 | cell, 99 | *deque(), 100 | ) 101 | else: 102 | new_node = TableTree(node.tag, None, None, None, *deque()) 103 | if parent is not None: 104 | parent.children.append(new_node) 105 | if node.tag != "td": 106 | for n in node.getchildren(): 107 | self.load_html_tree(n, new_node) 108 | if parent is None: 109 | return new_node 110 | 111 | def evaluate(self, pred, true): 112 | """Computes TEDS score between the prediction and the ground truth of a 113 | given sample 114 | """ 115 | if (not pred) or (not true): 116 | return 0.0 117 | parser = html.HTMLParser(remove_comments=True, encoding="utf-8") 118 | pred = html.fromstring(pred, parser=parser) 119 | true = html.fromstring(true, parser=parser) 120 | if pred.xpath("body/table") and true.xpath("body/table"): 121 | pred = pred.xpath("body/table")[0] 122 | true = true.xpath("body/table")[0] 123 | if self.ignore_nodes: 124 | etree.strip_tags(pred, *self.ignore_nodes) 125 | etree.strip_tags(true, *self.ignore_nodes) 126 | n_nodes_pred = len(pred.xpath(".//*")) 127 | n_nodes_true = len(true.xpath(".//*")) 128 | n_nodes = max(n_nodes_pred, n_nodes_true) 129 | tree_pred = self.load_html_tree(pred) 130 | tree_true = self.load_html_tree(true) 131 | distance = apted.APTED( 132 | tree_pred, tree_true, CustomConfig() 133 | ).compute_edit_distance() 134 | return 1.0 - (float(distance) / n_nodes) 135 | else: 136 | return 0.0 137 | 138 | def batch_evaluate(self, results_json): 139 | """Computes TEDS score between the prediction and the ground truth of 140 | a batch of samples 141 | @params results_json: {'FILENAME': {'pred': 'HTML CODE', 'gt': 'HTML CODE'}, ...} 142 | (each entry holds the predicted and ground-truth table HTML) 143 | @output: {'FILENAME': {'scores': TEDS SCORE, 'type': 'simple' or 'complex'}, ...} 144 | """ 145 | samples = results_json.keys() 146 | print(f"Total samples: {len(samples)}") 147 | if self.n_jobs == 1: 148 | scores = [ 149 | self.evaluate( 150 | results_json[filename]["pred"],
151 | results_json[filename]["gt"], 152 | ) 153 | for filename in tqdm(samples) 154 | ] 155 | else: 156 | inputs = [ 157 | { 158 | "pred": results_json[filename]["pred"], 159 | "true": results_json[filename]["gt"], 160 | } 161 | for filename in samples 162 | ] 163 | scores = parallel_process( 164 | inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1 165 | ) 166 | output = dict() 167 | for i, j in zip(samples, scores): 168 | if "span" in results_json[i]["gt"]: 169 | output[i] = dict(scores=j, type="complex") 170 | else: 171 | output[i] = dict(scores=j, type="simple") 172 | # scores = dict(zip(samples, scores)) 173 | return output 174 | 175 | 176 | def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=0): 177 | """ 178 | A parallel version of the map function with a progress bar. 179 | 180 | Args: 181 | array (array-like): An array to iterate over. 182 | function (function): A python function to apply to the elements of array 183 | n_jobs (int, default=16): The number of cores to use 184 | use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of 185 | keyword arguments to function 186 | front_num (int, default=3): The number of iterations to run serially before kicking off the parallel job. 187 | Useful for catching bugs 188 | Returns: 189 | [function(array[0]), function(array[1]), ...] 190 | """ 191 | # We run the first few iterations serially to catch bugs 192 | if front_num > 0: 193 | front = [ 194 | function(**a) if use_kwargs else function(a) for a in array[:front_num] 195 | ] 196 | else: 197 | front = [] 198 | # If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging. 199 | if n_jobs == 1: 200 | return front + [ 201 | function(**a) if use_kwargs else function(a) 202 | for a in tqdm(array[front_num:]) 203 | ] 204 | # Assemble the workers 205 | with ProcessPoolExecutor(max_workers=n_jobs) as pool: 206 | # Pass the elements of array into function 207 | if use_kwargs: 208 | futures = [pool.submit(function, **a) for a in array[front_num:]] 209 | else: 210 | futures = [pool.submit(function, a) for a in array[front_num:]] 211 | kwargs = { 212 | "total": len(futures), 213 | "unit": "it", 214 | "unit_scale": True, 215 | "leave": True, 216 | } 217 | # Print out the progress as tasks complete 218 | for f in tqdm(as_completed(futures), **kwargs): 219 | pass 220 | out = [] 221 | # Get the results from the futures. 
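# Worker exceptions are appended to the output instead of being re-raised, so a single
# failing sample does not abort the whole evaluation; callers should check the returned
# list for Exception instances.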
222 | for i, future in tqdm(enumerate(futures)): 223 | try: 224 | out.append(future.result()) 225 | except Exception as e: 226 | out.append(e) 227 | return front + out 228 | 229 | 230 | if __name__ == "__main__": 231 | import json 232 | import pprint 233 | import numpy as np 234 | import argparse 235 | 236 | parser = argparse.ArgumentParser(description="TEDS Computation") 237 | 238 | parser.add_argument("-f", "--file", help="path to html table results in json file") 239 | parser.add_argument("-t", "--type", help="html, html+cell") 240 | parser.add_argument("-n", "--njob", default=200, help="number of jobs in parallel") 241 | args = parser.parse_args() 242 | 243 | results_file = args.file 244 | with open(results_file, "r") as f: 245 | results_json = json.load(f) 246 | 247 | if args.type == "html": 248 | s_only = True 249 | else: 250 | s_only = False 251 | teds = TEDS(structure_only=s_only, n_jobs=args.njob) 252 | scores = teds.batch_evaluate(results_json) 253 | pp = pprint.PrettyPrinter() 254 | pp.pprint(scores) 255 | 256 | # compute teds for simple and complex tables 257 | total, simple, complex = list(), list(), list() 258 | for _, obj in scores.items(): 259 | if obj["type"] == "simple": 260 | simple.append(obj["scores"]) 261 | elif obj["type"] == "complex": 262 | complex.append(obj["scores"]) 263 | total.append(obj["scores"]) 264 | 265 | total, simple, complex = np.array(total), np.array(simple), np.array(complex) 266 | print( 267 | f"Simple: {np.mean(simple)} \nComplex: {np.mean(complex)} \nTotal: {np.mean(total)}" 268 | ) 269 | -------------------------------------------------------------------------------- /src/utils/visualization.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | import numpy as np 3 | 4 | 5 | def normalize_image_for_visualization(mean: float, std: float): 6 | invNormalization = transforms.Compose( 7 | [ 8 | transforms.Normalize(mean=[0.0] * 3, std=1.0 / np.array(std)), 9 | transforms.Normalize(mean=-1.0 * np.array(mean), std=[1.0] * 3), 10 | ] 11 | ) 12 | 13 | return invNormalization 14 | -------------------------------------------------------------------------------- /src/vocab/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | !constant.py -------------------------------------------------------------------------------- /src/vocab/__init__.py: -------------------------------------------------------------------------------- 1 | from .constant import * 2 | -------------------------------------------------------------------------------- /src/vocab/constant.py: -------------------------------------------------------------------------------- 1 | SPECIAL_TOKENS = ["", "", "", "", "", ""] 2 | TASK_TOKENS = ["[table]", "[html]", "[cell]", "[bbox]", "[cell+bbox]"] 3 | RESERVED_TOKENS = [ 4 | f"reserved {i+1}" for i in range(20 - len(SPECIAL_TOKENS) - len(TASK_TOKENS)) 5 | ] 6 | CELL_NUM_TOKENS = [f"{i+1}-cell(s)" for i in range(100)] 7 | BBOX_TOKENS = [f"bbox-{i}" for i in range(880)] 8 | 9 | HTML_TOKENS = [ 10 | "", 11 | "[]", 12 | "", 14 | ">[]", 15 | "", 16 | "", 17 | "", 18 | "", 19 | "", 20 | "", 21 | ' rowspan="2"', 22 | ' rowspan="3"', 23 | ' rowspan="4"', 24 | ' rowspan="5"', 25 | ' rowspan="6"', 26 | ' rowspan="7"', 27 | ' rowspan="8"', 28 | ' rowspan="9"', 29 | ' rowspan="10"', 30 | ' rowspan="11"', 31 | ' rowspan="12"', 32 | ' rowspan="13"', 33 | ' rowspan="14"', 34 | ' rowspan="15"', 35 | ' rowspan="16"', 36 | ' 
rowspan="17"', 37 | ' rowspan="18"', 38 | ' rowspan="19"', 39 | ' colspan="2"', 40 | ' colspan="3"', 41 | ' colspan="4"', 42 | ' colspan="5"', 43 | ' colspan="6"', 44 | ' colspan="7"', 45 | ' colspan="8"', 46 | ' colspan="9"', 47 | ' colspan="10"', 48 | ' colspan="11"', 49 | ' colspan="12"', 50 | ' colspan="13"', 51 | ' colspan="14"', 52 | ' colspan="15"', 53 | ' colspan="16"', 54 | ' colspan="17"', 55 | ' colspan="18"', 56 | ' colspan="19"', 57 | ' colspan="25"', 58 | ] 59 | 60 | CELL_SPECIAL = ["", "", "", "", "", "", "", ""] 61 | -------------------------------------------------------------------------------- /vocab/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | !vocab_html.json 4 | !vocab_bbox.json 5 | !vocab_cell_6k.json -------------------------------------------------------------------------------- /vocab/vocab_html.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "1.0", 3 | "truncation": null, 4 | "padding": { 5 | "strategy": "BatchLongest", 6 | "direction": "Right", 7 | "pad_to_multiple_of": null, 8 | "pad_id": 2, 9 | "pad_type_id": 0, 10 | "pad_token": "" 11 | }, 12 | "added_tokens": [ 13 | { 14 | "id": 0, 15 | "content": "", 16 | "single_word": false, 17 | "lstrip": false, 18 | "rstrip": false, 19 | "normalized": false, 20 | "special": true 21 | }, 22 | { 23 | "id": 1, 24 | "content": "", 25 | "single_word": false, 26 | "lstrip": false, 27 | "rstrip": false, 28 | "normalized": false, 29 | "special": true 30 | }, 31 | { 32 | "id": 2, 33 | "content": "", 34 | "single_word": false, 35 | "lstrip": false, 36 | "rstrip": false, 37 | "normalized": false, 38 | "special": true 39 | }, 40 | { 41 | "id": 3, 42 | "content": "", 43 | "single_word": false, 44 | "lstrip": false, 45 | "rstrip": false, 46 | "normalized": false, 47 | "special": true 48 | }, 49 | { 50 | "id": 4, 51 | "content": "", 52 | "single_word": false, 53 | "lstrip": false, 54 | "rstrip": false, 55 | "normalized": false, 56 | "special": true 57 | }, 58 | { 59 | "id": 5, 60 | "content": "", 61 | "single_word": false, 62 | "lstrip": false, 63 | "rstrip": false, 64 | "normalized": false, 65 | "special": true 66 | }, 67 | { 68 | "id": 6, 69 | "content": "[table]", 70 | "single_word": false, 71 | "lstrip": false, 72 | "rstrip": false, 73 | "normalized": false, 74 | "special": true 75 | }, 76 | { 77 | "id": 7, 78 | "content": "[html]", 79 | "single_word": false, 80 | "lstrip": false, 81 | "rstrip": false, 82 | "normalized": false, 83 | "special": true 84 | }, 85 | { 86 | "id": 8, 87 | "content": "[cell]", 88 | "single_word": false, 89 | "lstrip": false, 90 | "rstrip": false, 91 | "normalized": false, 92 | "special": true 93 | }, 94 | { 95 | "id": 9, 96 | "content": "[bbox]", 97 | "single_word": false, 98 | "lstrip": false, 99 | "rstrip": false, 100 | "normalized": false, 101 | "special": true 102 | }, 103 | { 104 | "id": 10, 105 | "content": "[cell+bbox]", 106 | "single_word": false, 107 | "lstrip": false, 108 | "rstrip": false, 109 | "normalized": false, 110 | "special": true 111 | }, 112 | { 113 | "id": 11, 114 | "content": "", 115 | "single_word": false, 116 | "lstrip": false, 117 | "rstrip": false, 118 | "normalized": false, 119 | "special": true 120 | }, 121 | { 122 | "id": 12, 123 | "content": "[]", 124 | "single_word": false, 125 | "lstrip": false, 126 | "rstrip": false, 127 | "normalized": false, 128 | "special": true 129 | }, 130 | { 131 | "id": 13, 132 | "content": "", 142 | "single_word": 
false, 143 | "lstrip": false, 144 | "rstrip": false, 145 | "normalized": false, 146 | "special": true 147 | }, 148 | { 149 | "id": 15, 150 | "content": ">[]", 151 | "single_word": false, 152 | "lstrip": false, 153 | "rstrip": false, 154 | "normalized": false, 155 | "special": true 156 | }, 157 | { 158 | "id": 16, 159 | "content": "", 160 | "single_word": false, 161 | "lstrip": false, 162 | "rstrip": false, 163 | "normalized": false, 164 | "special": true 165 | }, 166 | { 167 | "id": 17, 168 | "content": "", 169 | "single_word": false, 170 | "lstrip": false, 171 | "rstrip": false, 172 | "normalized": false, 173 | "special": true 174 | }, 175 | { 176 | "id": 18, 177 | "content": "", 178 | "single_word": false, 179 | "lstrip": false, 180 | "rstrip": false, 181 | "normalized": false, 182 | "special": true 183 | }, 184 | { 185 | "id": 19, 186 | "content": "", 187 | "single_word": false, 188 | "lstrip": false, 189 | "rstrip": false, 190 | "normalized": false, 191 | "special": true 192 | }, 193 | { 194 | "id": 20, 195 | "content": "", 196 | "single_word": false, 197 | "lstrip": false, 198 | "rstrip": false, 199 | "normalized": false, 200 | "special": true 201 | }, 202 | { 203 | "id": 21, 204 | "content": "", 205 | "single_word": false, 206 | "lstrip": false, 207 | "rstrip": false, 208 | "normalized": false, 209 | "special": true 210 | }, 211 | { 212 | "id": 22, 213 | "content": " rowspan=\"2\"", 214 | "single_word": false, 215 | "lstrip": false, 216 | "rstrip": false, 217 | "normalized": false, 218 | "special": true 219 | }, 220 | { 221 | "id": 23, 222 | "content": " rowspan=\"3\"", 223 | "single_word": false, 224 | "lstrip": false, 225 | "rstrip": false, 226 | "normalized": false, 227 | "special": true 228 | }, 229 | { 230 | "id": 24, 231 | "content": " rowspan=\"4\"", 232 | "single_word": false, 233 | "lstrip": false, 234 | "rstrip": false, 235 | "normalized": false, 236 | "special": true 237 | }, 238 | { 239 | "id": 25, 240 | "content": " rowspan=\"5\"", 241 | "single_word": false, 242 | "lstrip": false, 243 | "rstrip": false, 244 | "normalized": false, 245 | "special": true 246 | }, 247 | { 248 | "id": 26, 249 | "content": " rowspan=\"6\"", 250 | "single_word": false, 251 | "lstrip": false, 252 | "rstrip": false, 253 | "normalized": false, 254 | "special": true 255 | }, 256 | { 257 | "id": 27, 258 | "content": " rowspan=\"7\"", 259 | "single_word": false, 260 | "lstrip": false, 261 | "rstrip": false, 262 | "normalized": false, 263 | "special": true 264 | }, 265 | { 266 | "id": 28, 267 | "content": " rowspan=\"8\"", 268 | "single_word": false, 269 | "lstrip": false, 270 | "rstrip": false, 271 | "normalized": false, 272 | "special": true 273 | }, 274 | { 275 | "id": 29, 276 | "content": " rowspan=\"9\"", 277 | "single_word": false, 278 | "lstrip": false, 279 | "rstrip": false, 280 | "normalized": false, 281 | "special": true 282 | }, 283 | { 284 | "id": 30, 285 | "content": " rowspan=\"10\"", 286 | "single_word": false, 287 | "lstrip": false, 288 | "rstrip": false, 289 | "normalized": false, 290 | "special": true 291 | }, 292 | { 293 | "id": 31, 294 | "content": " rowspan=\"11\"", 295 | "single_word": false, 296 | "lstrip": false, 297 | "rstrip": false, 298 | "normalized": false, 299 | "special": true 300 | }, 301 | { 302 | "id": 32, 303 | "content": " rowspan=\"12\"", 304 | "single_word": false, 305 | "lstrip": false, 306 | "rstrip": false, 307 | "normalized": false, 308 | "special": true 309 | }, 310 | { 311 | "id": 33, 312 | "content": " rowspan=\"13\"", 313 | "single_word": false, 314 | 
"lstrip": false, 315 | "rstrip": false, 316 | "normalized": false, 317 | "special": true 318 | }, 319 | { 320 | "id": 34, 321 | "content": " rowspan=\"14\"", 322 | "single_word": false, 323 | "lstrip": false, 324 | "rstrip": false, 325 | "normalized": false, 326 | "special": true 327 | }, 328 | { 329 | "id": 35, 330 | "content": " rowspan=\"15\"", 331 | "single_word": false, 332 | "lstrip": false, 333 | "rstrip": false, 334 | "normalized": false, 335 | "special": true 336 | }, 337 | { 338 | "id": 36, 339 | "content": " rowspan=\"16\"", 340 | "single_word": false, 341 | "lstrip": false, 342 | "rstrip": false, 343 | "normalized": false, 344 | "special": true 345 | }, 346 | { 347 | "id": 37, 348 | "content": " rowspan=\"17\"", 349 | "single_word": false, 350 | "lstrip": false, 351 | "rstrip": false, 352 | "normalized": false, 353 | "special": true 354 | }, 355 | { 356 | "id": 38, 357 | "content": " rowspan=\"18\"", 358 | "single_word": false, 359 | "lstrip": false, 360 | "rstrip": false, 361 | "normalized": false, 362 | "special": true 363 | }, 364 | { 365 | "id": 39, 366 | "content": " rowspan=\"19\"", 367 | "single_word": false, 368 | "lstrip": false, 369 | "rstrip": false, 370 | "normalized": false, 371 | "special": true 372 | }, 373 | { 374 | "id": 40, 375 | "content": " colspan=\"2\"", 376 | "single_word": false, 377 | "lstrip": false, 378 | "rstrip": false, 379 | "normalized": false, 380 | "special": true 381 | }, 382 | { 383 | "id": 41, 384 | "content": " colspan=\"3\"", 385 | "single_word": false, 386 | "lstrip": false, 387 | "rstrip": false, 388 | "normalized": false, 389 | "special": true 390 | }, 391 | { 392 | "id": 42, 393 | "content": " colspan=\"4\"", 394 | "single_word": false, 395 | "lstrip": false, 396 | "rstrip": false, 397 | "normalized": false, 398 | "special": true 399 | }, 400 | { 401 | "id": 43, 402 | "content": " colspan=\"5\"", 403 | "single_word": false, 404 | "lstrip": false, 405 | "rstrip": false, 406 | "normalized": false, 407 | "special": true 408 | }, 409 | { 410 | "id": 44, 411 | "content": " colspan=\"6\"", 412 | "single_word": false, 413 | "lstrip": false, 414 | "rstrip": false, 415 | "normalized": false, 416 | "special": true 417 | }, 418 | { 419 | "id": 45, 420 | "content": " colspan=\"7\"", 421 | "single_word": false, 422 | "lstrip": false, 423 | "rstrip": false, 424 | "normalized": false, 425 | "special": true 426 | }, 427 | { 428 | "id": 46, 429 | "content": " colspan=\"8\"", 430 | "single_word": false, 431 | "lstrip": false, 432 | "rstrip": false, 433 | "normalized": false, 434 | "special": true 435 | }, 436 | { 437 | "id": 47, 438 | "content": " colspan=\"9\"", 439 | "single_word": false, 440 | "lstrip": false, 441 | "rstrip": false, 442 | "normalized": false, 443 | "special": true 444 | }, 445 | { 446 | "id": 48, 447 | "content": " colspan=\"10\"", 448 | "single_word": false, 449 | "lstrip": false, 450 | "rstrip": false, 451 | "normalized": false, 452 | "special": true 453 | }, 454 | { 455 | "id": 49, 456 | "content": " colspan=\"11\"", 457 | "single_word": false, 458 | "lstrip": false, 459 | "rstrip": false, 460 | "normalized": false, 461 | "special": true 462 | }, 463 | { 464 | "id": 50, 465 | "content": " colspan=\"12\"", 466 | "single_word": false, 467 | "lstrip": false, 468 | "rstrip": false, 469 | "normalized": false, 470 | "special": true 471 | }, 472 | { 473 | "id": 51, 474 | "content": " colspan=\"13\"", 475 | "single_word": false, 476 | "lstrip": false, 477 | "rstrip": false, 478 | "normalized": false, 479 | "special": true 480 | }, 481 | { 
482 | "id": 52, 483 | "content": " colspan=\"14\"", 484 | "single_word": false, 485 | "lstrip": false, 486 | "rstrip": false, 487 | "normalized": false, 488 | "special": true 489 | }, 490 | { 491 | "id": 53, 492 | "content": " colspan=\"15\"", 493 | "single_word": false, 494 | "lstrip": false, 495 | "rstrip": false, 496 | "normalized": false, 497 | "special": true 498 | }, 499 | { 500 | "id": 54, 501 | "content": " colspan=\"16\"", 502 | "single_word": false, 503 | "lstrip": false, 504 | "rstrip": false, 505 | "normalized": false, 506 | "special": true 507 | }, 508 | { 509 | "id": 55, 510 | "content": " colspan=\"17\"", 511 | "single_word": false, 512 | "lstrip": false, 513 | "rstrip": false, 514 | "normalized": false, 515 | "special": true 516 | }, 517 | { 518 | "id": 56, 519 | "content": " colspan=\"18\"", 520 | "single_word": false, 521 | "lstrip": false, 522 | "rstrip": false, 523 | "normalized": false, 524 | "special": true 525 | }, 526 | { 527 | "id": 57, 528 | "content": " colspan=\"19\"", 529 | "single_word": false, 530 | "lstrip": false, 531 | "rstrip": false, 532 | "normalized": false, 533 | "special": true 534 | }, 535 | { 536 | "id": 58, 537 | "content": " colspan=\"25\"", 538 | "single_word": false, 539 | "lstrip": false, 540 | "rstrip": false, 541 | "normalized": false, 542 | "special": true 543 | } 544 | ], 545 | "normalizer": { 546 | "type": "Sequence", 547 | "normalizers": [ 548 | { 549 | "type": "NFD" 550 | }, 551 | { 552 | "type": "Lowercase" 553 | }, 554 | { 555 | "type": "StripAccents" 556 | }, 557 | { 558 | "type": "Strip", 559 | "strip_left": true, 560 | "strip_right": true 561 | } 562 | ] 563 | }, 564 | "pre_tokenizer": { 565 | "type": "Whitespace" 566 | }, 567 | "post_processor": null, 568 | "decoder": { 569 | "type": "WordPiece", 570 | "prefix": "##", 571 | "cleanup": true 572 | }, 573 | "model": { 574 | "type": "WordPiece", 575 | "unk_token": "", 576 | "continuing_subword_prefix": "##", 577 | "max_input_chars_per_word": 100, 578 | "vocab": { 579 | "": 0, 580 | "": 1, 581 | "": 2, 582 | "": 3, 583 | "": 4, 584 | "": 5, 585 | "[table]": 6, 586 | "[html]": 7, 587 | "[cell]": 8, 588 | "[bbox]": 9, 589 | "[cell+bbox]": 10, 590 | "": 11, 591 | "[]": 12, 592 | "": 14, 594 | ">[]": 15, 595 | "": 16, 596 | "": 17, 597 | "": 18, 598 | "": 19, 599 | "": 20, 600 | "": 21, 601 | " rowspan=\"2\"": 22, 602 | " rowspan=\"3\"": 23, 603 | " rowspan=\"4\"": 24, 604 | " rowspan=\"5\"": 25, 605 | " rowspan=\"6\"": 26, 606 | " rowspan=\"7\"": 27, 607 | " rowspan=\"8\"": 28, 608 | " rowspan=\"9\"": 29, 609 | " rowspan=\"10\"": 30, 610 | " rowspan=\"11\"": 31, 611 | " rowspan=\"12\"": 32, 612 | " rowspan=\"13\"": 33, 613 | " rowspan=\"14\"": 34, 614 | " rowspan=\"15\"": 35, 615 | " rowspan=\"16\"": 36, 616 | " rowspan=\"17\"": 37, 617 | " rowspan=\"18\"": 38, 618 | " rowspan=\"19\"": 39, 619 | " colspan=\"2\"": 40, 620 | " colspan=\"3\"": 41, 621 | " colspan=\"4\"": 42, 622 | " colspan=\"5\"": 43, 623 | " colspan=\"6\"": 44, 624 | " colspan=\"7\"": 45, 625 | " colspan=\"8\"": 46, 626 | " colspan=\"9\"": 47, 627 | " colspan=\"10\"": 48, 628 | " colspan=\"11\"": 49, 629 | " colspan=\"12\"": 50, 630 | " colspan=\"13\"": 51, 631 | " colspan=\"14\"": 52, 632 | " colspan=\"15\"": 53, 633 | " colspan=\"16\"": 54, 634 | " colspan=\"17\"": 55, 635 | " colspan=\"18\"": 56, 636 | " colspan=\"19\"": 57, 637 | " colspan=\"25\"": 58 638 | } 639 | } 640 | } -------------------------------------------------------------------------------- /website/unitable-demo.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/website/unitable-demo.gif -------------------------------------------------------------------------------- /website/unitable-demo.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/website/unitable-demo.mp4 -------------------------------------------------------------------------------- /website/wandb_screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/website/wandb_screenshot.png --------------------------------------------------------------------------------
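A minimal usage sketch (not a file in this repo; the HTML strings below are made up for illustration): besides the CLI in src/utils/teds.py, the TEDS metric can be called directly from Python once the requirements are installed.

from src.utils.teds import TEDS

# two identical single-cell tables -> perfect score
pred = "<html><body><table><tr><td>cell</td></tr></table></body></html>"
gt = "<html><body><table><tr><td>cell</td></tr></table></body></html>"

teds = TEDS(structure_only=False, n_jobs=1)
print(teds.evaluate(pred, gt))  # expected: 1.0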