├── .gitignore
├── CONFIG.mk
├── LICENSE
├── Makefile
├── README.md
├── configs
├── dataset
│ ├── augmentation
│ │ ├── beit.yaml
│ │ ├── resize_normalize.yaml
│ │ └── vqvae.yaml
│ ├── concat_dataset.yaml
│ ├── fintabnet
│ │ ├── test_dataset.yaml
│ │ ├── train_dataset.yaml
│ │ └── valid_dataset.yaml
│ ├── icdar
│ │ ├── test_dataset.yaml
│ │ ├── train_dataset.yaml
│ │ └── valid_dataset.yaml
│ ├── mini_pubtabnet
│ │ ├── test_dataset.yaml
│ │ ├── train_dataset.yaml
│ │ └── valid_dataset.yaml
│ ├── pubtables1m
│ │ ├── test_dataset.yaml
│ │ ├── train_dataset.yaml
│ │ └── valid_dataset.yaml
│ ├── pubtabnet
│ │ ├── test_dataset.yaml
│ │ ├── train_dataset.yaml
│ │ └── valid_dataset.yaml
│ ├── single_dataset.yaml
│ ├── synthtabnet_fintabnet
│ │ ├── test_dataset.yaml
│ │ ├── train_dataset.yaml
│ │ └── valid_dataset.yaml
│ ├── synthtabnet_marketing
│ │ ├── test_dataset.yaml
│ │ ├── train_dataset.yaml
│ │ └── valid_dataset.yaml
│ ├── synthtabnet_pubtabnet
│ │ ├── test_dataset.yaml
│ │ ├── train_dataset.yaml
│ │ └── valid_dataset.yaml
│ ├── synthtabnet_sparse
│ │ ├── test_dataset.yaml
│ │ ├── train_dataset.yaml
│ │ └── valid_dataset.yaml
│ └── tablebank
│ │ ├── test_dataset.yaml
│ │ ├── train_dataset.yaml
│ │ └── valid_dataset.yaml
├── main.yaml
├── model
│ ├── beit.yaml
│ ├── encoderdecoder.yaml
│ ├── model
│ │ ├── backbone
│ │ │ ├── imgcnn.yaml
│ │ │ ├── imgconvstem.yaml
│ │ │ └── imglinear.yaml
│ │ ├── decoder
│ │ │ └── transformer.yaml
│ │ └── encoder
│ │ │ └── transformer.yaml
│ └── vqvae.yaml
├── trainer
│ ├── beit.yaml
│ ├── table.yaml
│ ├── train
│ │ ├── lr_scheduler
│ │ │ ├── cosine.yaml
│ │ │ ├── exponential.yaml
│ │ │ └── step.yaml
│ │ └── optimizer
│ │ │ ├── adam.yaml
│ │ │ └── adamw.yaml
│ └── vqvae.yaml
└── vocab
│ ├── bbox.yaml
│ ├── cell.yaml
│ ├── empty.yaml
│ └── html.yaml
├── dataset
└── mini_pubtabnet
│ ├── mini_pubtabnet_examples.jsonl
│ ├── train
│ ├── PMC1626454_002_00.png
│ ├── PMC2753619_002_00.png
│ ├── PMC2759935_007_01.png
│ ├── PMC2838834_005_00.png
│ ├── PMC3519711_003_00.png
│ ├── PMC3826085_003_00.png
│ ├── PMC3907710_006_00.png
│ ├── PMC4003957_018_00.png
│ ├── PMC4172848_007_00.png
│ ├── PMC4517499_004_00.png
│ ├── PMC4682394_003_00.png
│ ├── PMC4776821_005_00.png
│ ├── PMC4840965_004_00.png
│ ├── PMC5134617_013_00.png
│ ├── PMC5198506_004_00.png
│ ├── PMC5332562_005_00.png
│ ├── PMC5402779_004_00.png
│ ├── PMC5577841_001_00.png
│ ├── PMC5679144_002_01.png
│ └── PMC5897438_004_00.png
│ └── val
│ ├── PMC1626454_002_00.png
│ ├── PMC2753619_002_00.png
│ ├── PMC2759935_007_01.png
│ ├── PMC2838834_005_00.png
│ ├── PMC3519711_003_00.png
│ ├── PMC3826085_003_00.png
│ ├── PMC3907710_006_00.png
│ ├── PMC4003957_018_00.png
│ ├── PMC4172848_007_00.png
│ ├── PMC4517499_004_00.png
│ ├── PMC4682394_003_00.png
│ ├── PMC4776821_005_00.png
│ ├── PMC4840965_004_00.png
│ ├── PMC5134617_013_00.png
│ ├── PMC5198506_004_00.png
│ ├── PMC5332562_005_00.png
│ ├── PMC5402779_004_00.png
│ ├── PMC5577841_001_00.png
│ ├── PMC5679144_002_01.png
│ └── PMC5897438_004_00.png
├── experiments
└── .gitignore
├── notebooks
├── .gitignore
└── full_pipeline.ipynb
├── requirements.txt
├── setup.py
├── src
├── datamodule
│ ├── __init__.py
│ ├── augmentation.py
│ ├── dataloader.py
│ ├── fintabnet.py
│ ├── pubtables1m.py
│ ├── pubtabnet.py
│ ├── synthtabnet.py
│ └── tablebank.py
├── main.py
├── model
│ ├── __init__.py
│ ├── beit.py
│ ├── components.py
│ ├── encoderdecoder.py
│ └── vqvae.py
├── trainer
│ ├── __init__.py
│ ├── train_beit.py
│ ├── train_table.py
│ ├── train_vqvae.py
│ └── utils.py
├── utils
│ ├── __init__.py
│ ├── coco_map.py
│ ├── data.py
│ ├── engine.py
│ ├── mask_generator.py
│ ├── misc.py
│ ├── teds.py
│ └── visualization.py
└── vocab
│ ├── .gitignore
│ ├── __init__.py
│ └── constant.py
├── vocab
├── .gitignore
├── vocab_bbox.json
├── vocab_cell_6k.json
└── vocab_html.json
└── website
├── unitable-demo.gif
├── unitable-demo.mp4
└── wandb_screenshot.png
/.gitignore:
--------------------------------------------------------------------------------
1 | # VS Code
2 | .history
3 |
4 | # make targets & code outputs
5 | .done*
6 | outputs
7 |
8 | # Byte-compiled / optimized / DLL files
9 | __pycache__/
10 | *.py[cod]
11 | *$py.class
12 |
13 | # C extensions
14 | *.so
15 |
16 | # Distribution / packaging
17 | .Python
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | share/python-wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 | MANIFEST
35 |
36 | # PyInstaller
37 | # Usually these files are written by a python script from a template
38 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
39 | *.manifest
40 | *.spec
41 |
42 | # Installer logs
43 | pip-log.txt
44 | pip-delete-this-directory.txt
45 |
46 | # Unit test / coverage reports
47 | htmlcov/
48 | .tox/
49 | .nox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | *.py,cover
57 | .hypothesis/
58 | .pytest_cache/
59 | cover/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 | db.sqlite3
69 | db.sqlite3-journal
70 |
71 | # Flask stuff:
72 | instance/
73 | .webassets-cache
74 |
75 | # Scrapy stuff:
76 | .scrapy
77 |
78 | # Sphinx documentation
79 | docs/_build/
80 |
81 | # PyBuilder
82 | .pybuilder/
83 | target/
84 |
85 | # Jupyter Notebook
86 | .ipynb_checkpoints
87 |
88 | # IPython
89 | profile_default/
90 | ipython_config.py
91 |
92 | # pyenv
93 | # For a library or package, you might want to ignore these files since the code is
94 | # intended to run in multiple environments; otherwise, check them in:
95 | # .python-version
96 |
97 | # pipenv
98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
101 | # install all needed dependencies.
102 | #Pipfile.lock
103 |
104 | # poetry
105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
106 | # This is especially recommended for binary packages to ensure reproducibility, and is more
107 | # commonly ignored for libraries.
108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
109 | #poetry.lock
110 |
111 | # pdm
112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113 | #pdm.lock
114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
115 | # in version control.
116 | # https://pdm.fming.dev/#use-with-ide
117 | .pdm.toml
118 |
119 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
120 | __pypackages__/
121 |
122 | # Celery stuff
123 | celerybeat-schedule
124 | celerybeat.pid
125 |
126 | # SageMath parsed files
127 | *.sage.py
128 |
129 | # Environments
130 | .env
131 | .venv
132 | env/
133 | venv/
134 | ENV/
135 | env.bak/
136 | venv.bak/
137 |
138 | # Spyder project settings
139 | .spyderproject
140 | .spyproject
141 |
142 | # Rope project settings
143 | .ropeproject
144 |
145 | # mkdocs documentation
146 | /site
147 |
148 | # mypy
149 | .mypy_cache/
150 | .dmypy.json
151 | dmypy.json
152 |
153 | # Pyre type checker
154 | .pyre/
155 |
156 | # pytype static type analyzer
157 | .pytype/
158 |
159 | # Cython debug symbols
160 | cython_debug/
161 |
162 | # PyCharm
163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
165 | # and can be added to the global gitignore or merged into this file. For a more nuclear
166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
167 | #.idea/
168 |
--------------------------------------------------------------------------------
/CONFIG.mk:
--------------------------------------------------------------------------------
1 | ##################################################
2 | # Configurations #
3 | ##################################################
4 |
5 | #
6 | # Datasets
7 | #
8 |
9 | # label type
10 | LABEL_IMAGE = ++trainer.label_type="image"
11 | LABEL_HTML = ++trainer.label_type="html" "++trainer.train.loss_weights.html=1"
12 | LABEL_CELL = ++trainer.label_type="cell" "++trainer.train.loss_weights.cell=1"
13 | LABEL_BBOX = ++trainer.label_type="bbox" "++trainer.train.loss_weights.bbox=1"
14 | MEAN = [0.86597056,0.88463002,0.87491087]
15 | STD = [0.20686628,0.18201602,0.18485524]
16 |
17 | # augmentation
18 | AUG_VQVAE = dataset/augmentation=vqvae
19 | AUG_BEIT = dataset/augmentation=beit \
20 | ++dataset.augmentation.mean=$(MEAN) ++dataset.augmentation.std=$(STD)
21 | AUG_RESIZE_NORM = dataset/augmentation=resize_normalize \
22 | ++dataset.augmentation.transforms.2.mean=$(MEAN) ++dataset.augmentation.transforms.2.std=$(STD)
23 |
24 | # single dataset
25 | DATA_SINGLE = dataset=single_dataset
26 | PUBTABNET = $(DATA_SINGLE) \
27 | +dataset/pubtabnet@dataset.train_dataset=train_dataset \
28 | +dataset/pubtabnet@dataset.valid_dataset=valid_dataset \
29 | +dataset/pubtabnet@dataset.test_dataset=test_dataset
30 | MINIPUBTABNET = $(DATA_SINGLE) \
31 | +dataset/mini_pubtabnet@dataset.train_dataset=train_dataset \
32 | +dataset/mini_pubtabnet@dataset.valid_dataset=valid_dataset \
33 | +dataset/mini_pubtabnet@dataset.test_dataset=test_dataset
34 |
35 | # multiple datasets
36 | DATA_MULTI = dataset=concat_dataset
37 | PUBTABNET_M = +dataset/pubtabnet@dataset.train.d1=train_dataset \
38 | +dataset/pubtabnet@dataset.valid.d1=valid_dataset \
39 | +dataset/pubtabnet@dataset.test.d1=test_dataset
40 | SYN_MARKET_M = +dataset/synthtabnet_marketing@dataset.train.d2=train_dataset \
41 | +dataset/synthtabnet_marketing@dataset.valid.d2=valid_dataset \
42 | +dataset/synthtabnet_marketing@dataset.test.d2=test_dataset
43 | SYN_FIN_M = +dataset/synthtabnet_fintabnet@dataset.train.d3=train_dataset \
44 | +dataset/synthtabnet_fintabnet@dataset.valid.d3=valid_dataset \
45 | +dataset/synthtabnet_fintabnet@dataset.test.d3=test_dataset
46 | SYN_SPARSE_M = +dataset/synthtabnet_sparse@dataset.train.d4=train_dataset \
47 | +dataset/synthtabnet_sparse@dataset.valid.d4=valid_dataset \
48 | +dataset/synthtabnet_sparse@dataset.test.d4=test_dataset
49 | SYN_PUB_M = +dataset/synthtabnet_pubtabnet@dataset.train.d5=train_dataset \
50 | +dataset/synthtabnet_pubtabnet@dataset.valid.d5=valid_dataset \
51 | +dataset/synthtabnet_pubtabnet@dataset.test.d5=test_dataset
52 | PUBTABLES_M = +dataset/pubtables1m@dataset.train.d7=train_dataset \
53 | +dataset/pubtables1m@dataset.valid.d7=valid_dataset \
54 | +dataset/pubtables1m@dataset.test.d7=test_dataset
55 | TABLEBANK_M = +dataset/tablebank@dataset.train.d8=train_dataset \
56 | +dataset/tablebank@dataset.valid.d8=valid_dataset \
57 | +dataset/tablebank@dataset.test.d8=test_dataset
58 | FINTABNET_M = +dataset/fintabnet@dataset.train.d9=train_dataset \
59 | +dataset/fintabnet@dataset.valid.d9=valid_dataset \
60 | +dataset/fintabnet@dataset.test.d9=test_dataset
61 |
62 | DATA_VQVAE_1M = $(DATA_MULTI) \
63 | $(PUBTABNET_M) $(SYN_MARKET_M) $(SYN_FIN_M) $(SYN_SPARSE_M)
64 | DATA_VQVAE_2M = $(DATA_MULTI) \
65 | $(PUBTABNET_M) $(SYN_MARKET_M) $(SYN_FIN_M) $(SYN_SPARSE_M) $(SYN_PUB_M) \
66 | $(PUBTABLES_M) $(TABLEBANK_M)
67 |
68 | PUBTABLES1M = $(DATA_MULTI) $(PUBTABLES_M)
69 | FINTABNET = $(DATA_MULTI) $(FINTABNET_M)
70 |
71 | PUB_SYN = $(DATA_MULTI) \
72 | $(PUBTABNET_M) $(SYN_MARKET_M) $(SYN_FIN_M) $(SYN_SPARSE_M) $(SYN_PUB_M)
73 |
74 | PUB_SYN_FIN = $(DATA_MULTI) $(PUBTABNET_M) $(FINTABNET_M) \
75 | $(SYN_MARKET_M) $(SYN_FIN_M) $(SYN_SPARSE_M) $(SYN_PUB_M)
76 |
77 | PUB_SYN_PUB1M = $(DATA_MULTI) $(PUBTABNET_M) $(PUBTABLES_M) \
78 | $(SYN_MARKET_M) $(SYN_FIN_M) $(SYN_SPARSE_M) $(SYN_PUB_M)
79 |
80 | SYN = $(DATA_MULTI) $(SYN_MARKET_M) $(SYN_FIN_M) $(SYN_SPARSE_M) $(SYN_PUB_M)
81 |
82 | SYN_fin = $(DATA_MULTI) $(SYN_FIN_M)
83 | SYN_market = $(DATA_MULTI) $(SYN_MARKET_M)
84 | SYN_pub = $(DATA_MULTI) $(SYN_PUB_M)
85 | SYN_sparse = $(DATA_MULTI) $(SYN_SPARSE_M)
86 |
87 | #
88 | # Vocab
89 | #
90 | VOCAB_NONE = vocab=empty
91 | VOCAB_HTML = vocab=html
92 | VOCAB_BBOX = vocab=bbox
93 | VOCAB_CELL = vocab=cell
94 |
95 |
96 | #
97 | # Trainer
98 | #
99 |
100 | # trainer type
101 | TRAINER_VQVAE = trainer=vqvae
102 | TRAINER_BEIT = trainer=beit
103 | TRAINER_TABLE = trainer=table
104 |
105 | # input image size
106 | I224 = ++trainer.img_size=[224,224]
107 | I448 = ++trainer.img_size=[448,448]
108 | I112_448 = ++trainer.img_size=[112,448]
109 |
110 | # max sequence length
111 | SEQ200 = trainer.max_seq_len=200
112 | SEQ512 = trainer.max_seq_len=512
113 | SEQ1024 = trainer.max_seq_len=1024
114 |
115 | # batch size + epoch
116 | BATCH24 = ++trainer.train.dataloader.batch_size=24 ++trainer.valid.dataloader.batch_size=24
117 | BATCH48 = ++trainer.train.dataloader.batch_size=48 ++trainer.valid.dataloader.batch_size=48
118 | BATCH72 = ++trainer.train.dataloader.batch_size=72 ++trainer.valid.dataloader.batch_size=72
119 | BATCH80 = ++trainer.train.dataloader.batch_size=80 ++trainer.valid.dataloader.batch_size=80
120 | BATCH96 = ++trainer.train.dataloader.batch_size=96 ++trainer.valid.dataloader.batch_size=96
121 | BATCH256 = ++trainer.train.dataloader.batch_size=256 ++trainer.valid.dataloader.batch_size=256
122 | BATCH384 = ++trainer.train.dataloader.batch_size=384 ++trainer.valid.dataloader.batch_size=384
123 |
124 | EPOCH24 = ++trainer.train.epochs=24
125 | EPOCH30 = ++trainer.train.epochs=30
126 | EPOCH48 = ++trainer.train.epochs=48
127 |
128 | # optimizer
129 | OPT_ADAMW = trainer/train/optimizer=adamw
130 | OPT_WD5e2 = ++trainer.train.optimizer.weight_decay=5e-2
131 |
132 | # lr + scheduler
133 | LR_5e4 = ++trainer.train.optimizer.lr=5e-4
134 | LR_3e4 = ++trainer.train.optimizer.lr=3e-4
135 | LR_1e4 = ++trainer.train.optimizer.lr=1e-4
136 | LR_8e5 = ++trainer.train.optimizer.lr=8e-5
137 |
138 | LR_cosine = trainer/train/lr_scheduler=cosine ++trainer.train.lr_scheduler.lr_lambda.min_ratio=5e-3
139 | LR_cosine93k_warm6k = $(LR_cosine) ++trainer.train.lr_scheduler.lr_lambda.total_step=93400 ++trainer.train.lr_scheduler.lr_lambda.warmup=5800
140 | LR_cosine77k_warm8k = $(LR_cosine) ++trainer.train.lr_scheduler.lr_lambda.total_step=76600 ++trainer.train.lr_scheduler.lr_lambda.warmup=7660
141 | LR_cosine30k_warm4k = $(LR_cosine) ++trainer.train.lr_scheduler.lr_lambda.total_step=30500 ++trainer.train.lr_scheduler.lr_lambda.warmup=4000
142 | LR_cosine8k_warm1k = $(LR_cosine) ++trainer.train.lr_scheduler.lr_lambda.total_step=7600 ++trainer.train.lr_scheduler.lr_lambda.warmup=800
143 | LR_cosine44k_warm6k = $(LR_cosine) ++trainer.train.lr_scheduler.lr_lambda.total_step=44100 ++trainer.train.lr_scheduler.lr_lambda.warmup=5500
144 | LR_cosine118k_warm15k = $(LR_cosine) ++trainer.train.lr_scheduler.lr_lambda.total_step=117800 ++trainer.train.lr_scheduler.lr_lambda.warmup=14700
145 | LR_cosine216k_warm27k = $(LR_cosine) ++trainer.train.lr_scheduler.lr_lambda.total_step=216000 ++trainer.train.lr_scheduler.lr_lambda.warmup=27000
146 | LR_cosine32k = $(LR_cosine) ++trainer.train.lr_scheduler.lr_lambda.total_step=32000 ++trainer.train.lr_scheduler.lr_lambda.warmup=0
147 | LR_cosine118k = $(LR_cosine) ++trainer.train.lr_scheduler.lr_lambda.total_step=118000 ++trainer.train.lr_scheduler.lr_lambda.warmup=0
148 |
149 | GRAD_CLIP12 = ++trainer.train.grad_clip=12
150 |
151 | # vqvae
152 | VQVAE_TEMP_1M = ++trainer.train.starting_temp=1. \
153 | ++trainer.train.temp_min=5e-3 ++trainer.train.temp_anneal_rate=1e-3
154 | VQVAE_TEMP_2M = ++trainer.train.starting_temp=1. \
155 | ++trainer.train.temp_min=1e-3 ++trainer.train.temp_anneal_rate=2e-4
156 |
157 | # pretraining specific
158 | TRANS448_VQVAE224_GRID28_MASK300 = ++trainer.trans_size=[448,448] ++trainer.vqvae_size=[224,224] ++trainer.grid_size=28 ++trainer.num_mask_patches=300
159 | VQVAE1M_WEIGHTS = $(MODEL_VQVAE) ++trainer.vqvae_weights="../unitable_weights/vqvae_1m.pt"
160 | VQVAE2M_WEIGHTS = $(MODEL_VQVAE_L) ++trainer.vqvae_weights="../unitable_weights/vqvae_2m.pt"
161 |
162 | # finetuning specific
163 | WEIGHTS_mtim_1m_base = ++trainer.trainer.beit_pretrained_weights="../unitable_weights/ssp_1m_base.pt"
164 | WEIGHTS_mtim_1m_large = ++trainer.trainer.beit_pretrained_weights="../unitable_weights/ssp_1m_large.pt"
165 | WEIGHTS_mtim_2m_base = ++trainer.trainer.beit_pretrained_weights="../unitable_weights/ssp_2m_base.pt"
166 | WEIGHTS_mtim_2m_large = ++trainer.trainer.beit_pretrained_weights="../unitable_weights/ssp_2m_large.pt"
167 | LOCK_MTIM_4 = ++trainer.trainer.freeze_beit_epoch=4
168 |
169 | #
170 | # Models
171 | #
172 |
173 | # model type
174 | MODEL_VQVAE = model=vqvae
175 | MODEL_VQVAE_L = $(MODEL_VQVAE) ++model.codebook_tokens=16384 ++model.hidden_dim=512
176 | MODEL_BEIT = model=beit
177 | MODEL_ENCODER_DECODER = model=encoderdecoder
178 |
179 | # backbone for input preprocessing: resnet, linear projection, and convstem
180 | IMGCNN = model/model/backbone=imgcnn
181 | IMGLINEAR = model/model/backbone=imglinear
182 | IMGCONVSTEM = model/model/backbone=imgconvstem
183 |
184 | # number of layers
185 | E4 = ++model.model.encoder.nlayer=4
186 | E12 = ++model.model.encoder.nlayer=12
187 | E24 = ++model.model.encoder.nlayer=24
188 | D4 = ++model.model.decoder.nlayer=4
189 |
190 | # transformer layer: attention heads, hidden size, activation, norm
191 | FF4 = ++model.ff_ratio=4
192 |
193 | NHEAD8 = ++model.nhead=8
194 | NHEAD12 = ++model.nhead=12
195 |
196 | NORM_FIRST = ++model.norm_first=true
197 | NORM_LAST = ++model.norm_first=false
198 |
199 | ACT_RELU = ++model.activation="relu"
200 | ACT_GELU = ++model.activation="gelu"
201 |
202 | D_MODEL512 = ++model.d_model=512
203 | D_MODEL768 = ++model.d_model=768
204 |
205 | # regularization
206 | REG_d00 = ++model.dropout=0.0
207 | REG_d02 = ++model.dropout=0.2
208 |
209 | # linear projection patch size
210 | P16 = ++model.backbone_downsampling_factor=16
211 | P28 = ++model.backbone_downsampling_factor=28
212 | P32 = ++model.backbone_downsampling_factor=32
213 |
214 | # cnn backbone
215 | R18 = ++model.model.backbone.backbone._target_=torchvision.models.resnet18 \
216 | ++model.model.backbone.output_channels=512
217 |
218 | MTIM_BASE = $(MODEL_BEIT) $(IMGLINEAR) $(NHEAD8) $(FF4) $(ACT_GELU) \
219 | $(NORM_FIRST) $(D_MODEL512) $(REG_d02) $(P16) $(E4)
220 | MTIM_LARGE = $(MODEL_BEIT) $(IMGLINEAR) $(NHEAD12) $(FF4) $(ACT_GELU) \
221 | $(NORM_FIRST) $(D_MODEL768) $(REG_d02) $(P16) $(E12)
222 |
223 | ARCH_BASE = $(MTIM_BASE) $(MODEL_ENCODER_DECODER) $(D4)
224 | ARCH_LARGE = $(MTIM_LARGE) $(MODEL_ENCODER_DECODER) $(D4)
225 |
226 |
227 | ###############################################
228 | # Experiments #
229 | ###############################################
230 |
231 | TRAIN_vqvae := $(VOCAB_NONE) \
232 | $(LABEL_IMAGE) $(AUG_VQVAE) $(I224) \
233 | $(TRAINER_VQVAE) $(OPT_ADAMW) $(LR_1e4) $(EPOCH24)
234 |
235 | TRAIN_mtim := $(VOCAB_NONE) \
236 | $(LABEL_IMAGE) $(AUG_BEIT) \
237 | $(TRAINER_BEIT) $(OPT_ADAMW) $(OPT_WD5e2) $(LR_5e4) \
238 | $(TRANS448_VQVAE224_GRID28_MASK300)
239 |
240 | #
241 | # mini_pubtabnet pretraining example (dataset code: mini)
242 | #
243 |
244 | # vq-vae
245 | # > make experiments/vqvae_mini/.done_pretrain
246 | EXP_vqvae_mini := $(TRAIN_vqvae) $(MINIPUBTABNET) $(VQVAE_TEMP_2M) $(BATCH80) $(MODEL_VQVAE) $(LR_cosine32k)
247 |
248 | # visual encoder pretraining - masked tabular image modeling (MTIM)
249 | # > make experiments/mtim_mini_base/.done_pretrain
250 | EXP_mtim_mini_base := $(TRAIN_mtim) $(MINIPUBTABNET) $(VQVAE2M_WEIGHTS) $(MTIM_BASE) \
251 | $(BATCH384) $(LR_cosine8k_warm1k) $(EPOCH24)
252 |
253 | #
254 | # mini_pubtabnet finetuning example
255 | #
256 |
257 | # table structure (task code: html)
258 | # > make experiments/ssp_2m_mini_html_base/.done_finetune
259 | TRAIN_mini_html := $(VOCAB_HTML) \
260 | $(MINIPUBTABNET) $(LABEL_HTML) $(AUG_RESIZE_NORM) \
261 | $(TRAINER_TABLE) $(I448) $(SEQ512) \
262 | $(EPOCH48) $(OPT_ADAMW) $(OPT_WD5e2) $(LR_8e5)
263 |
264 | EXP_ssp_2m_mini_html_base := $(TRAIN_mini_html) $(ARCH_BASE) \
265 | $(WEIGHTS_mtim_2m_base) $(LOCK_MTIM_4) $(BATCH72) $(LR_cosine93k_warm6k)
266 |
267 | # table cell bbox (task code: bbox)
268 | # > make experiments/ssp_2m_mini_bbox_base/.done_finetune
269 | TRAIN_mini_bbox := $(VOCAB_BBOX) \
270 | $(MINIPUBTABNET) $(LABEL_BBOX) $(AUG_RESIZE_NORM) \
271 | $(TRAINER_TABLE) $(I448) $(SEQ1024) \
272 | $(EPOCH30) $(OPT_ADAMW) $(OPT_WD5e2) $(LR_3e4) $(GRAD_CLIP12)
273 |
274 | EXP_ssp_2m_mini_bbox_base := $(TRAIN_mini_bbox) $(ARCH_BASE) \
275 | $(WEIGHTS_mtim_2m_base) $(LOCK_MTIM_4) $(BATCH48) $(LR_cosine77k_warm8k)
276 |
277 | # table cell content (task code: cell)
278 | # > make experiments/ssp_2m_mini_cell_base/.done_finetune
279 | TRAIN_mini_cell := $(VOCAB_CELL) \
280 | $(MINIPUBTABNET) $(LABEL_CELL) $(AUG_RESIZE_NORM) \
281 | $(TRAINER_TABLE) $(I112_448) $(SEQ200) \
282 | $(EPOCH24) $(OPT_ADAMW) $(OPT_WD5e2) $(LR_8e5) $(GRAD_CLIP12)
283 |
284 | EXP_ssp_2m_mini_cell_base := $(TRAIN_mini_cell) $(ARCH_BASE) \
285 | $(WEIGHTS_mtim_2m_base) $(LOCK_MTIM_4) $(BATCH24) $(LR_cosine216k_warm27k)
286 |
287 | #
288 | # cross-dataset pretraining
289 | #
290 |
291 | # vq-vae
292 | EXP_vqvae_1M := $(TRAIN_vqvae) $(DATA_VQVAE_1M) $(VQVAE_TEMP_1M) $(BATCH80) $(MODEL_VQVAE) $(LR_cosine32k)
293 | EXP_vqvae_2M := $(TRAIN_vqvae) $(DATA_VQVAE_2M) $(VQVAE_TEMP_2M) $(BATCH48) $(MODEL_VQVAE_L) $(LR_cosine118k)
294 |
295 | # visual encoder pretraining
296 | EXP_mtim_1M_base := $(TRAIN_mtim) $(PUB_SYN) $(VQVAE1M_WEIGHTS) $(MTIM_BASE) \
297 | $(BATCH384) $(LR_cosine8k_warm1k) $(EPOCH24)
298 | EXP_mtim_1M_large := $(TRAIN_mtim) $(PUB_SYN) $(VQVAE1M_WEIGHTS) $(MTIM_LARGE) \
299 | $(BATCH96) $(LR_cosine30k_warm4k) $(EPOCH24)
300 | EXP_mtim_2M_base := $(TRAIN_mtim) $(DATA_VQVAE_2M) $(VQVAE2M_WEIGHTS) $(MTIM_BASE) \
301 | $(BATCH256) $(LR_cosine44k_warm6k) $(EPOCH48)
302 | EXP_mtim_2M_large := $(TRAIN_mtim) $(DATA_VQVAE_2M) $(VQVAE2M_WEIGHTS) $(MTIM_LARGE) \
303 | $(BATCH96) $(LR_cosine118k_warm15k) $(EPOCH48)
304 |
305 | #
306 | # cross-dataset finetuning
307 | #
308 |
309 | # table structure
310 | # > make experiments/ssp_2m_syn_pub_html_medium/.done_finetune
311 | TRAIN_syn_pub_html := $(VOCAB_HTML) \
312 | $(PUB_SYN) $(LABEL_HTML) $(AUG_RESIZE_NORM) \
313 | $(TRAINER_TABLE) $(I448) $(SEQ512) \
314 | $(EPOCH48) $(OPT_ADAMW) $(OPT_WD5e2) $(LR_8e5)
315 |
316 | EXP_ssp_2m_syn_pub_html_large := $(TRAIN_syn_pub_html) $(ARCH_LARGE) \
317 | $(WEIGHTS_mtim_2m_large) $(LOCK_MTIM_4) $(BATCH72) $(LR_cosine93k_warm6k)
318 |
319 | # table cell bbox
320 | # > make experiments/ssp_2m_syn_pub_bbox_medium/.done_finetune
321 | TRAIN_syn_pub_bbox := $(VOCAB_BBOX) \
322 | $(PUB_SYN) $(LABEL_BBOX) $(AUG_RESIZE_NORM) \
323 | $(TRAINER_TABLE) $(I448) $(SEQ1024) \
324 | $(EPOCH30) $(OPT_ADAMW) $(OPT_WD5e2) $(LR_3e4) $(GRAD_CLIP12)
325 |
326 | EXP_ssp_2m_syn_pub_bbox_large := $(TRAIN_syn_pub_bbox) $(ARCH_LARGE) \
327 | $(WEIGHTS_mtim_2m_large) $(LOCK_MTIM_4) $(BATCH48) $(LR_cosine77k_warm8k)
328 |
329 | # table cell content
330 | # > make experiments/syn_pub_pub1m_cell_medium/.done_finetune
331 | TRAIN_syn_pub_pub1m_cell := $(VOCAB_CELL) \
332 | $(PUB_SYN_PUB1M) $(LABEL_CELL) $(AUG_RESIZE_NORM) \
333 | $(TRAINER_TABLE) $(I112_448) $(SEQ200) \
334 | $(EPOCH24) $(OPT_ADAMW) $(OPT_WD5e2) $(LR_8e5) $(GRAD_CLIP12)
335 |
336 | EXP_ssp_2m_syn_pub_pub1m_cell_large := $(TRAIN_syn_pub_pub1m_cell) $(ARCH_LARGE) \
337 | $(WEIGHTS_mtim_2m_base) $(LOCK_MTIM_4) $(BATCH24) $(LR_cosine216k_warm27k)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 ShengYun (Anthony) Peng.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | SHELL := /bin/bash
2 | VENV_NAME := unitable
3 | CONDA_ACTIVATE := source $$(conda info --base)/etc/profile.d/conda.sh && conda activate $(VENV_NAME)
4 | PYTHON := $(CONDA_ACTIVATE) && python
5 | PIP := $(CONDA_ACTIVATE) && pip3
6 | # Stacked single-node multi-worker: https://pytorch.org/docs/stable/elastic/run.html#stacked-single-node-multi-worker
7 | TORCHRUN = $(CONDA_ACTIVATE) && torchrun --rdzv-backend=c10d --rdzv_endpoint localhost:0 --nnodes=1 --nproc_per_node=$(NGPU)
8 |
9 | # Taken from https://tech.davis-hansson.com/p/make/
10 | ifeq ($(origin .RECIPEPREFIX), undefined)
11 | $(error This Make does not support .RECIPEPREFIX. Please use GNU Make 4.0 or later)
12 | endif
13 | .RECIPEPREFIX = >
14 |
15 | #
16 | # Virtual Environment Targets
17 | #
18 | clean:
19 | > rm -f .venv_done
20 |
21 | .done_venv: clean
22 | > conda create -n $(VENV_NAME) python=3.9 -y
23 | > $(PIP) install -r requirements.txt
24 | > $(PIP) install -e .
25 | > touch $@
26 |
27 | #
28 | # Download pretrained and UniTable model weights
29 | #
30 | WEIGHTS_PATH = experiments/unitable_weights
31 | M_VQVAE_1M = $(WEIGHTS_PATH)/vqvae_1m.pt
32 | M_VQVAE_2M = $(WEIGHTS_PATH)/vqvae_2m.pt
33 | M_SSP_1M_BASE = $(WEIGHTS_PATH)/ssp_1m_base.pt
34 | M_SSP_1M_LARGE = $(WEIGHTS_PATH)/ssp_1m_large.pt
35 | M_SSP_2M_BASE = $(WEIGHTS_PATH)/ssp_2m_base.pt
36 | M_SSP_2M_LARGE = $(WEIGHTS_PATH)/ssp_2m_large.pt
37 | UNITABLE_HTML = $(WEIGHTS_PATH)/unitable_large_structure.pt
38 | UNITABLE_BBOX = $(WEIGHTS_PATH)/unitable_large_bbox.pt
39 | UNITABLE_CELL = $(WEIGHTS_PATH)/unitable_large_content.pt
40 |
41 | .done_download_weights:
42 | ifeq ("$(words $(wildcard $(WEIGHTS_PATH)/*.pt))", "9")
43 | > $(info All 9 model weights have already been downloaded to $(WEIGHTS_PATH).)
44 | else
45 | > $(info There should be 9 weights file under $(WEIGHTS_PATH), but only $(words $(wildcard $(WEIGHTS_PATH)/*.pt)) are found.)
46 | > $(info Begin downloading weights from HuggingFace ...)
47 | > wget -c https://huggingface.co/poloclub/UniTable/resolve/main/vqvae_1m.pt -P $(WEIGHTS_PATH)
48 | > wget -c https://huggingface.co/poloclub/UniTable/resolve/main/vqvae_2m.pt -P $(WEIGHTS_PATH)
49 | > wget -c https://huggingface.co/poloclub/UniTable/resolve/main/ssp_1m_base.pt -P $(WEIGHTS_PATH)
50 | > wget -c https://huggingface.co/poloclub/UniTable/resolve/main/ssp_1m_large.pt -P $(WEIGHTS_PATH)
51 | > wget -c https://huggingface.co/poloclub/UniTable/resolve/main/ssp_2m_base.pt -P $(WEIGHTS_PATH)
52 | > wget -c https://huggingface.co/poloclub/UniTable/resolve/main/ssp_2m_large.pt -P $(WEIGHTS_PATH)
53 | > wget -c https://huggingface.co/poloclub/UniTable/resolve/main/unitable_large_structure.pt -P $(WEIGHTS_PATH)
54 | > wget -c https://huggingface.co/poloclub/UniTable/resolve/main/unitable_large_bbox.pt -P $(WEIGHTS_PATH)
55 | > wget -c https://huggingface.co/poloclub/UniTable/resolve/main/unitable_large_content.pt -P $(WEIGHTS_PATH)
56 | > $(info Completed!)
57 | endif
58 |
59 | #
60 | # Python Targets
61 | #
62 | include CONFIG.mk
63 | SRC := src
64 | BEST_MODEL = "../$(word 1,$(subst -, ,$*))/model/best.pt"
65 | RESULT_JSON := html.json
66 | TEDS_STRUCTURE = -f "../experiments/$*/$(RESULT_JSON)" -s
67 |
68 | ######################
69 | NGPU := 1 # number of gpus used in the experiments
70 |
71 | .SECONDARY:
72 |
73 | # vq-vae and self-supervised pretraining
74 | experiments/%/.done_pretrain:
75 | > @echo "Using experiment configurations from variable EXP_$*"
76 | > cd $(SRC) && $(TORCHRUN) -m main ++name=$* $(EXP_$*) ++trainer.mode="train"
77 | > touch $@
78 |
79 | # finetuning from SSP weights for table structure, cell bbox and cell content
80 | experiments/%/.done_finetune:
81 | > @echo "Finetuning phase 1 - using experiment configurations from variable EXP_$*"
82 | > cd $(SRC) && $(TORCHRUN) -m main ++name=$* $(EXP_$*) ++trainer.mode="train"
83 | > @echo "Finetuning phase 2 - starting from epoch 4"
84 | > cd $(SRC) && $(TORCHRUN) -m main ++name=$* $(EXP_$*) ++trainer.mode="train" ++trainer.trainer.snapshot="epoch3_snapshot.pt" ++trainer.trainer.beit_pretrained_weights=null
85 | > touch $@
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # UniTable: Towards a Unified Table Foundation Model
2 |
3 |

4 |
5 | 1. 📈 [High-Performance Transformers for Table Structure Recognition Need Early Convolutions](https://arxiv.org/abs/2311.05565). ShengYun Peng, Seongmin Lee, Xiaojing Wang, Rajarajeswari Balasubramaniyan, Duen Horng Chau. In *NeurIPS Second Table Representation Learning Workshop*, 2023. (Oral)
6 | 2. 🚀 [Self-Supervised Pretraining for Table Structure Recognition Transformer](https://arxiv.org/abs/2402.15578). ShengYun Peng, Seongmin Lee, Xiaojing Wang, Rajarajeswari Balasubramaniyan, Duen Horng Chau. In *AAAI Scientific Document Understanding Workshop*, 2024. (Oral)
7 | 3. 🆕 [UniTable: Towards a Unified Framework for Table Structure Recognition via Self-Supervised Pretraining](https://arxiv.org/abs/2403.04822). ShengYun Peng, Seongmin Lee, Xiaojing Wang, Rajarajeswari Balasubramaniyan, Duen Horng Chau. ArXiv, 2024.
8 |
9 | Tables convey factual and quantitative data with implicit conventions created by humans that are often challenging for machines to parse. Prior work on table recognition (TR) has mainly centered around complex task-specific combinations of available inputs and tools. We present UniTable, a training framework that unifies training paradigm, training objective, and model architecture of TR.
10 | Its training paradigm combines the simplicity of purely pixel-level inputs with the effectiveness and scalability empowered by self-supervised pretraining (SSP) from diverse unannotated tabular images. Our framework unifies the training of all three TR tasks — extracting table structure, cell content, and cell bounding box (bbox) — into a unified task-agnostic training objective: language modeling. Extensive quantitative and qualitative analyses highlight UniTable’s state-of-the-art (SOTA) performance on four of the largest TR datasets. To promote reproducible research, enhance transparency, and SOTA innovations, we have released the first-of-its-kind [Jupyter Notebook](./notebooks/full_pipeline.ipynb) of the entire inference pipeline, fine-tuned across multiple TR datasets, supporting all three TR tasks.
11 |
12 |
13 | > This repo includes code for linear projection Transformers. For convolutional stem (early convolution) Transformers, please check out our [tsr-convstem repo](https://github.com/poloclub/tsr-convstem).
14 |
15 | # News
16 | `Apr. 2024` - You can fully digitalize your own tabular image in our [Jupyter Notebook](./notebooks/full_pipeline.ipynb).
17 |
18 | `Apr. 2024` - UniTable v1.0.0 is now online with model weights available at [HuggingFace](https://huggingface.co/poloclub/UniTable/tree/main).
19 |
20 | `Feb. 2024` - We presented "Self-Supervised Pretraining" paper at AAAI'24.
21 |
22 | `Jan. 2024` - "Self-Supervised Pretraining" paper was selected as [oral](https://sites.google.com/view/sdu-aaai24/schedule?authuser=0).
23 |
24 | `Dec. 2023` - "Self-Supervised Pretraining" paper was accepted by [AAAI'24 Scientific Document Understanding Workshop](https://sites.google.com/view/sdu-aaai24/schedule?authuser=0).
25 |
26 | `Dec. 2023` - We presented "Early Convolutions" paper at [NeurIPS'23](https://x.com/RealAnthonyPeng/status/1735715161476866135?s=20).
27 |
28 | `Oct. 2023` - "Early Convolutions" paper was selected as [oral](https://table-representation-learning.github.io/#accepted-papers).
29 |
30 | `Oct. 2023` - "Early Convolutions" paper was accepted by [NeurIPS'23 Table Representation Learning Workshop](https://table-representation-learning.github.io/).
31 |
32 | # Quick Start
33 | 1. Set up virtual environment (unitable) by running `make .done_venv` in your terminal.
34 | 2. Download all the model weights from [HuggingFace](https://huggingface.co/poloclub/UniTable/tree/main) by running `make .done_download_weights` in your terminal.
35 | 3. Try out our demo [Jupyter Notebook](./notebooks/full_pipeline.ipynb) with your own tabular image! Remember to select "unitable" as your notebook kernel.
36 |
37 | # Training
38 | Our code is driven by [Makefile targets](https://www.gnu.org/software/make/manual/make.html) and configured by [Hydra](https://hydra.cc/docs/intro/). Experiment names are defined as `EXP_` in [CONFIG.mk Sec. Experiments](CONFIG.mk). We have also provided how to launch the make target in the comment above each experiment.
39 | ## Dataset annotation format
40 | We provide a tiny portion (20 samples) of PubTabNet as an example for a quick walk through of the training process. The dataset (tabular image and annotation) is available at [dataset/mini_pubtabnet](./dataset/mini_pubtabnet/). The annotation for all images are in [mini_pubtabnet_examples.jsonl](./dataset/mini_pubtabnet/mini_pubtabnet_examples.jsonl). Each line is a `json` object that corresponds to a `png` image with the following structure:
41 |
42 | ```python
43 | "filename": "tabular image filename inside one of the 'train', 'val', or 'test'",
44 | "split": "One of 'train', 'val', or 'test'",
45 | "html": "table structure, cell content, and cell bbox",
46 | "cells": "Array with all cell content and bbox",
47 | "tokens": "Array with the content of the cell",
48 | "bbox": "The bounding bbox of the cell in [x1, y1, x2, y2] format. Only provided when cell is non-empty.",
49 | "structure": "Dict with table structure",
50 | "tokens": "Array with html tags that describe the table structure. '[]' represents non-empty cell",
51 | ```
52 |
53 | If you want to train on your own dataset, please convert your dataset based on the provided format.
54 | The five datasets we used in the paper are [PubTabNet](https://github.com/ibm-aur-nlp/PubTabNet), [SynthTabNet](https://github.com/IBM/SynthTabNet), [FinTabNet](https://developer.ibm.com/exchanges/data/all/fintabnet/), [ICDAR 2019 B2 Modern](https://github.com/cndplab-founder/ICDAR2019_cTDaR), and [PubTables-1M](https://huggingface.co/datasets/bsmock/pubtables-1m).
55 | After downloading these datasets, please update the `root_dir` for each dataset under [configs/dataset/](./configs/dataset/).
56 |
57 | ## Tracking your training progress
58 | Please register [Weights & Biases account](https://wandb.ai/site) if you want to visualize training curves and reconstructed tables (for pretraining VQ-VAE only). An example of reconstructed tables by VQ-VAE:
59 |
60 | 
61 |
62 |
63 | ## Finetuning
64 | We present finetuning on the provided mini-PubTabNet. For more details on cross dataset finetuning, please check [CONFIG.mk](CONFIG.mk).
65 | ```bash
66 | # table structure
67 | make experiments/ssp_2m_mini_html_base/.done_finetune
68 |
69 | # cell bbox
70 | make experiments/ssp_2m_mini_bbox_base/.done_finetune
71 |
72 | # cell content
73 | make experiments/ssp_2m_mini_cell_base/.done_finetune
74 | ```
75 |
76 | ## Pretraining
77 | We present training the VQ-VAE and pretraining the visual encoder on the provided mini-PubTabNet. For more details on cross dataset finetuning, please check [CONFIG.mk](CONFIG.mk).
78 |
79 | ### VQ-VAE
80 | ```bash
81 | make experiments/vqvae_mini/.done_pretrain
82 | ```
83 |
84 | ### SSP visual encoder - Masked tabular image modeling (MTIM)
85 | ```bash
86 | make experiments/mtim_mini_base/.done_pretrain
87 | ```
88 |
89 | ## Multi-GPU
90 | The default setting is a single gpu, i.e., `NGPU := 1` in [Makefile](Makefile). To enable multi-GPU, please launch the above command with the following format: `CUDA_VISIBLE_DEVICES=0,1,2,3 make NGPU=4 experiment//.done_`.
91 |
92 | ## Citation
93 | ```bibtex
94 | @article{peng2024unitable,
95 | title={UniTable: Towards a Unified Framework for Table Structure Recognition via Self-Supervised Pretraining},
96 | author={Peng, ShengYun and Lee, Seongmin and Wang, Xiaojing and Balasubramaniyan, Rajarajeswari and Chau, Duen Horng},
97 | journal={arXiv preprint},
98 | year={2024}
99 | }
100 |
101 | @article{peng2024self,
102 | title={Self-Supervised Pretraining for Table Structure Recognition Transformer},
103 | author={Peng, ShengYun and Lee, Seongmin and Wang, Xiaojing and Balasubramaniyan, Rajarajeswari and Chau, Duen Horng},
104 | journal={arXiv preprint},
105 | year={2024}
106 | }
107 |
108 | @inproceedings{peng2023high,
109 | title={High-Performance Transformers for Table Structure Recognition Need Early Convolutions},
110 | author={Peng, Anthony and Lee, Seongmin and Wang, Xiaojing and Balasubramaniyan, Rajarajeswari Raji and Chau, Duen Horng},
111 | booktitle={NeurIPS 2023 Second Table Representation Learning Workshop},
112 | year={2023}
113 | }
114 | ```
115 | ## Contact
116 | If you have any questions, feel free to contact [Anthony Peng](https://shengyun-peng.github.io/) (CS PhD @Georgia Tech).
117 |
118 |
--------------------------------------------------------------------------------
/configs/dataset/augmentation/beit.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.datamodule.augmentation.AugmentationForMIM
2 | mean: [0.86597056, 0.88463002, 0.87491087]
3 | std: [0.20686628, 0.18201602, 0.18485524]
4 | trans_size: ${trainer.trans_size}
5 | vqvae_size: ${trainer.vqvae_size}
6 | trans_interpolation: bicubic
7 | vqvae_interpolation: lanczos
--------------------------------------------------------------------------------
/configs/dataset/augmentation/resize_normalize.yaml:
--------------------------------------------------------------------------------
1 | _target_: torchvision.transforms.Compose
2 | transforms:
3 | - _target_: torchvision.transforms.Resize
4 | size: ${trainer.img_size}
5 | - _target_: torchvision.transforms.ToTensor
6 | - _target_: torchvision.transforms.Normalize
7 | mean: [0.86597056, 0.88463002, 0.87491087]
8 | std: [0.20686628, 0.18201602, 0.18485524]
9 |
10 |
--------------------------------------------------------------------------------
/configs/dataset/augmentation/vqvae.yaml:
--------------------------------------------------------------------------------
1 | _target_: torchvision.transforms.Compose
2 | transforms:
3 | - _target_: torchvision.transforms.Resize
4 | size: ${trainer.img_size}
5 | - _target_: torchvision.transforms.ToTensor
--------------------------------------------------------------------------------
/configs/dataset/concat_dataset.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - augmentation: beit
4 | # - pubtabnet@train.d1: train_dataset
5 | # - pubtabnet@valid.d1: valid_dataset
6 | # - synthtabnet_marketing@train.d2: train_dataset
7 | # - synthtabnet_marketing@valid.d2: valid_dataset
8 | # - synthtabnet_fintabnet@train.d3: train_dataset
9 | # - synthtabnet_fintabnet@valid.d3: valid_dataset
10 | # - synthtabnet_sparse@train.d4: train_dataset
11 | # - synthtabnet_sparse@valid.d4: valid_dataset
12 | # - synthtabnet_pubtabnet@train.d5: train_dataset
13 | # - synthtabnet_pubtabnet@valid.d5: valid_dataset
14 |
15 |
16 | label_type: ${trainer.label_type}
17 | cell_limit: 10
18 |
19 | train_dataset:
20 | _target_: torch.utils.data.ConcatDataset
21 | datasets: ${oc.dict.values:..train}
22 |
23 | valid_dataset:
24 | _target_: torch.utils.data.ConcatDataset
25 | datasets: ${oc.dict.values:..valid}
26 |
27 | test_dataset:
28 | _target_: torch.utils.data.ConcatDataset
29 | datasets: ${oc.dict.values:..test}
--------------------------------------------------------------------------------
/configs/dataset/fintabnet/test_dataset.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - valid_dataset
3 |
4 | jsonl_filename: clean_FinTabNet_1.0.0_cell_test.jsonl
--------------------------------------------------------------------------------
/configs/dataset/fintabnet/train_dataset.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.datamodule.FinTabNet
2 | root_dir: ../../../../DATASETS/finTabNet
3 | label_type: ${dataset.label_type}
4 | jsonl_filename: clean_FinTabNet_1.0.0_cell_train.jsonl
5 | transform: ${dataset.augmentation}
--------------------------------------------------------------------------------
/configs/dataset/fintabnet/valid_dataset.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - train_dataset
3 |
4 | jsonl_filename: clean_FinTabNet_1.0.0_cell_val.jsonl
--------------------------------------------------------------------------------
/configs/dataset/icdar/test_dataset.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - valid_dataset
3 |
4 | split: test
--------------------------------------------------------------------------------
/configs/dataset/icdar/train_dataset.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.datamodule.ICDAR
2 | root_dir: ../../../../DATASETS/ICDAR-2013
3 | label_type: ${dataset.label_type}
4 | split: train
5 | transform: ${dataset.augmentation}
--------------------------------------------------------------------------------
/configs/dataset/icdar/valid_dataset.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - train_dataset
3 |
4 | split: val
--------------------------------------------------------------------------------
/configs/dataset/mini_pubtabnet/test_dataset.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - valid_dataset
3 |
4 | cell_limit: 256
--------------------------------------------------------------------------------
/configs/dataset/mini_pubtabnet/train_dataset.yaml:
--------------------------------------------------------------------------------
1 |
2 | _target_: src.datamodule.pubtabnet.PubTabNet
3 | root_dir: ../../dataset/mini_pubtabnet
4 | label_type: ${dataset.label_type}
5 | split: train
6 | json_html: mini_pubtabnet_examples.jsonl
7 | transform: ${dataset.augmentation}
8 | cell_limit: 150
--------------------------------------------------------------------------------
/configs/dataset/mini_pubtabnet/valid_dataset.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - train_dataset
--------------------------------------------------------------------------------
/configs/dataset/pubtables1m/test_dataset.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - valid_dataset
3 |
4 | split: test
--------------------------------------------------------------------------------
/configs/dataset/pubtables1m/train_dataset.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.datamodule.PubTables
2 | root_dir: ../../../../DATASETS/pubtables1m/PubTables-1M-Structure
3 | label_type: ${dataset.label_type}
4 | split: train
5 | transform: ${dataset.augmentation}
6 | cell_limit: ${dataset.cell_limit}
--------------------------------------------------------------------------------
/configs/dataset/pubtables1m/valid_dataset.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - train_dataset
3 |
4 | split: val
--------------------------------------------------------------------------------
/configs/dataset/pubtabnet/test_dataset.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - valid_dataset
3 |
4 | cell_limit: 256
--------------------------------------------------------------------------------
/configs/dataset/pubtabnet/train_dataset.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.datamodule.PubTabNet
2 | root_dir: ../../../../DATASETS/pubtabnet
3 | label_type: ${dataset.label_type}
4 | split: train
5 | json_html: clean_html_PubTabNet_2.0.0.jsonl
6 | transform: ${dataset.augmentation}
7 | cell_limit: ${dataset.cell_limit}
--------------------------------------------------------------------------------
/configs/dataset/pubtabnet/valid_dataset.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - train_dataset
3 |
4 | split: val
--------------------------------------------------------------------------------
/configs/dataset/single_dataset.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - augmentation: beit
4 | # - pubtabnet@train_dataset: train_dataset
5 | # - pubtabnet@valid_dataset: valid_dataset
6 | # - pubtabnet@test_dataset: test_dataset
7 |
8 | label_type: ${trainer.label_type}
9 | cell_limit: 10
--------------------------------------------------------------------------------
/configs/dataset/synthtabnet_fintabnet/test_dataset.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - train_dataset
3 |
4 | split: test
--------------------------------------------------------------------------------
/configs/dataset/synthtabnet_fintabnet/train_dataset.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.datamodule.Synthtabnet
2 | root_dir: ../../../../DATASETS/synthtabnet/fintabnet
3 | label_type: ${dataset.label_type}
4 | split: train
5 | json_html: clean_html_synthetic_data.jsonl
6 | transform: ${dataset.augmentation}
7 | cell_limit: ${dataset.cell_limit}
--------------------------------------------------------------------------------
/configs/dataset/synthtabnet_fintabnet/valid_dataset.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - train_dataset
3 |
4 | split: val
--------------------------------------------------------------------------------
/configs/dataset/synthtabnet_marketing/test_dataset.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - train_dataset
3 |
4 | split: test
--------------------------------------------------------------------------------
/configs/dataset/synthtabnet_marketing/train_dataset.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.datamodule.Synthtabnet
2 | root_dir: ../../../../DATASETS/synthtabnet/marketing
3 | label_type: ${dataset.label_type}
4 | split: train
5 | json_html: clean_html_synthetic_data.jsonl
6 | transform: ${dataset.augmentation}
7 | cell_limit: ${dataset.cell_limit}
--------------------------------------------------------------------------------
/configs/dataset/synthtabnet_marketing/valid_dataset.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - train_dataset
3 |
4 | split: val
--------------------------------------------------------------------------------
/configs/dataset/synthtabnet_pubtabnet/test_dataset.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - train_dataset
3 |
4 | split: test
--------------------------------------------------------------------------------
/configs/dataset/synthtabnet_pubtabnet/train_dataset.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.datamodule.Synthtabnet
2 | root_dir: ../../../../DATASETS/synthtabnet/pubtabnet
3 | label_type: ${dataset.label_type}
4 | split: train
5 | json_html: clean_html_synthetic_data.jsonl
6 | transform: ${dataset.augmentation}
7 | cell_limit: ${dataset.cell_limit}
--------------------------------------------------------------------------------
/configs/dataset/synthtabnet_pubtabnet/valid_dataset.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - train_dataset
3 |
4 | split: val
--------------------------------------------------------------------------------
/configs/dataset/synthtabnet_sparse/test_dataset.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - train_dataset
3 |
4 | split: test
--------------------------------------------------------------------------------
/configs/dataset/synthtabnet_sparse/train_dataset.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.datamodule.Synthtabnet
2 | root_dir: ../../../../DATASETS/synthtabnet/sparse
3 | label_type: ${dataset.label_type}
4 | split: train
5 | json_html: clean_html_synthetic_data.jsonl
6 | transform: ${dataset.augmentation}
7 | cell_limit: ${dataset.cell_limit}
--------------------------------------------------------------------------------
/configs/dataset/synthtabnet_sparse/valid_dataset.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - train_dataset
3 |
4 | split: val
--------------------------------------------------------------------------------
/configs/dataset/tablebank/test_dataset.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - valid_dataset
3 |
4 | split: test
--------------------------------------------------------------------------------
/configs/dataset/tablebank/train_dataset.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.datamodule.TableBank
2 | root_dir: ../../../../DATASETS/tablebank/Recognition
3 | label_type: ${dataset.label_type}
4 | split: train
5 | transform: ${dataset.augmentation}
6 |
--------------------------------------------------------------------------------
/configs/dataset/tablebank/valid_dataset.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - train_dataset
3 |
4 | split: val
--------------------------------------------------------------------------------
/configs/main.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - dataset: mini_pubtabnet
4 | - model: encoderdecoder
5 | - trainer: table
6 | - vocab: html
7 | - override hydra/job_logging: colorlog
8 | - override hydra/hydra_logging: colorlog
9 |
10 |
11 | hydra:
12 | run:
13 | dir: ../experiments/${name}
14 | sweep:
15 | dir: ../experiments/${name}
16 | job:
17 | name: ${name}
18 | chdir: true
19 |
20 | wandb:
21 | project: UniTable
22 |
23 | seed: 1234
--------------------------------------------------------------------------------
/configs/model/beit.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - model/backbone: imglinear
4 | - model/encoder: transformer
5 |
6 | nhead: 12
7 | ff_ratio: 4
8 | activation: gelu
9 | norm_first: true
10 | d_model: 768
11 | dropout: 0.0
12 | backbone_downsampling_factor: 16
13 |
14 | codebook_tokens: 8192
15 | hidden_dim: 256
16 |
17 | model:
18 | _target_: src.model.beit.BeitEncoder
19 | d_model: ${model.d_model}
20 | codebook_tokens: ${model.codebook_tokens}
21 | dropout: ${model.dropout}
22 | norm_layer:
23 | _partial_: true
24 | _target_: torch.nn.LayerNorm
25 | eps: 1e-6
26 |
27 | model_vqvae:
28 | _target_: src.model.vqvae.DiscreteVAE
29 | image_size: ${trainer.vqvae_size}
30 | codebook_tokens: ${model.codebook_tokens}
31 | codebook_dim: 512
32 | num_layers: 3
33 | hidden_dim: ${model.hidden_dim}
34 | smooth_l1_loss: false
35 | kl_div_loss_weight: 0.0
--------------------------------------------------------------------------------
/configs/model/encoderdecoder.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - model/backbone: imgcnn
4 | - model/encoder: transformer
5 | - model/decoder: transformer
6 |
7 |
8 | nhead: 4
9 | ff_ratio: 2
10 | activation: relu
11 | norm_first: false
12 | d_model: 512
13 | dropout: 0.5
14 | backbone_downsampling_factor: 16
15 |
16 |
17 | model:
18 | _target_: src.model.EncoderDecoder
19 | vocab_size: -1
20 | d_model: ${model.d_model}
21 | padding_idx: -1
22 | max_seq_len: ${trainer.max_seq_len}
23 | dropout: ${model.dropout}
24 | norm_layer:
25 | _partial_: true
26 | _target_: torch.nn.LayerNorm
27 | eps: 1e-6
28 |
29 |
30 |
--------------------------------------------------------------------------------
/configs/model/model/backbone/imgcnn.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.model.components.ImgCnnBackbone
2 | backbone:
3 | _target_: torchvision.models.resnet18
4 | output_channels: 512
5 | d_model: ${model.d_model}
6 | drop_layer:
7 | - 3
8 | - 8
9 | - 9
--------------------------------------------------------------------------------
/configs/model/model/backbone/imgconvstem.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.model.components.ImgConvStemBackbone
2 | d_model: ${model.d_model}
3 | downsample_factor: ${model.backbone_downsampling_factor}
4 | output_channels: 192
5 | kernel_size: 3
--------------------------------------------------------------------------------
/configs/model/model/backbone/imglinear.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.model.components.ImgLinearBackbone
2 | d_model: ${model.d_model}
3 | patch_size: ${model.backbone_downsampling_factor}
--------------------------------------------------------------------------------
/configs/model/model/decoder/transformer.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.model.components.Decoder
2 | d_model: ${model.d_model}
3 | nhead: ${model.nhead}
4 | dropout: ${model.dropout}
5 | activation: ${model.activation}
6 | norm_first: ${model.norm_first}
7 | nlayer: 4
8 | ff_ratio: ${model.ff_ratio}
--------------------------------------------------------------------------------
/configs/model/model/encoder/transformer.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.model.components.Encoder
2 | d_model: ${model.d_model}
3 | nhead: ${model.nhead}
4 | dropout: ${model.dropout}
5 | activation: ${model.activation}
6 | norm_first: ${model.norm_first}
7 | nlayer: 2
8 | ff_ratio: ${model.ff_ratio}
--------------------------------------------------------------------------------
/configs/model/vqvae.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 |
4 | codebook_tokens: 8192
5 | hidden_dim: 256
6 |
7 | model:
8 | _target_: src.model.vqvae.DiscreteVAE
9 | image_size: ${trainer.img_size}
10 | codebook_tokens: ${model.codebook_tokens}
11 | codebook_dim: 512
12 | num_layers: 3
13 | hidden_dim: ${model.hidden_dim}
14 | smooth_l1_loss: false
15 | kl_div_loss_weight: 0.0
--------------------------------------------------------------------------------
/configs/trainer/beit.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - train/lr_scheduler: exponential
4 | - train/optimizer: adam
5 |
6 | mode: train
7 | trans_size: 448
8 | vqvae_size: 224
9 | grid_size: 28
10 | num_mask_patches: 300
11 | min_num_patches: 16
12 | max_seq_len: null
13 |
14 | vqvae_weights: null
15 |
16 | train:
17 | epochs: 20
18 | grad_clip: 5
19 | save_every: 3
20 | dataloader:
21 | _target_: src.datamodule.dataloader.dataloader_beit
22 | batch_size: 48
23 | grid_size: ${trainer.grid_size}
24 | num_mask_patches: ${trainer.num_mask_patches}
25 | min_num_patches: ${trainer.min_num_patches}
26 | valid:
27 | dataloader:
28 | _target_: src.datamodule.dataloader.dataloader_beit
29 | batch_size: 48
30 | grid_size: ${trainer.grid_size}
31 | num_mask_patches: ${trainer.num_mask_patches}
32 | min_num_patches: ${trainer.min_num_patches}
33 | test:
34 | metrics: null
35 |
36 |
37 | trainer:
38 | _target_: src.trainer.BeitTrainer
39 | snapshot: null
40 | model_weights: null
--------------------------------------------------------------------------------
/configs/trainer/table.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - train/lr_scheduler: step
4 | - train/optimizer: adam
5 |
6 |
7 | mode: train
8 | img_size: [448,448]
9 | max_seq_len: 512
10 | label_type: html+cell+bbox
11 |
12 | train:
13 | target: ${trainer.label_type}
14 | img_size: ${trainer.img_size}
15 | loss_weights:
16 | table: 0
17 | html: 0
18 | cell: 0
19 | bbox: 0
20 | grad_clip: 5
21 | epochs: 24
22 | save_every: 1
23 | max_seq_len: ${trainer.max_seq_len}
24 | dataloader:
25 | _target_: src.datamodule.dataloader_html
26 | batch_size: 48
27 | label_type: ${trainer.label_type}
28 | valid:
29 | target: ${trainer.label_type}
30 | img_size: ${trainer.img_size}
31 | loss_weights: ${trainer.train.loss_weights}
32 | max_seq_len: ${trainer.max_seq_len}
33 | dataloader:
34 | _target_: src.datamodule.dataloader_html
35 | batch_size: 48
36 | label_type: ${trainer.label_type}
37 | test:
38 | target: ${trainer.train.target}
39 | img_size: ${trainer.img_size}
40 | loss_weights: ${trainer.train.loss_weights}
41 | metrics: teds
42 | max_seq_len: ${trainer.max_seq_len}
43 | sampling: greedy
44 | save_to_prefix: html_table_result
45 | dataloader:
46 | _target_: src.datamodule.dataloader_html
47 | batch_size: 96
48 | label_type: ${trainer.label_type}
49 |
50 |
51 | trainer:
52 | _target_: src.trainer.TableTrainer
53 | snapshot: null
54 | model_weights: null
55 | beit_pretrained_weights: null
56 | freeze_beit_epoch: null
--------------------------------------------------------------------------------
/configs/trainer/train/lr_scheduler/cosine.yaml:
--------------------------------------------------------------------------------
1 | _target_: torch.optim.lr_scheduler.LambdaLR
2 | lr_lambda:
3 | _partial_: true
4 | _target_: src.utils.cosine_schedule_with_warmup
5 | warmup: 6
6 | min_ratio: 5e-3
7 | total_step: ${trainer.train.epochs}
8 |
--------------------------------------------------------------------------------
/configs/trainer/train/lr_scheduler/exponential.yaml:
--------------------------------------------------------------------------------
1 | _target_: torch.optim.lr_scheduler.ExponentialLR
2 | gamma: 0.98
--------------------------------------------------------------------------------
/configs/trainer/train/lr_scheduler/step.yaml:
--------------------------------------------------------------------------------
1 | _target_: torch.optim.lr_scheduler.StepLR
2 | step_size: 12
3 | gamma: 0.1
--------------------------------------------------------------------------------
/configs/trainer/train/optimizer/adam.yaml:
--------------------------------------------------------------------------------
1 | _target_: torch.optim.Adam
2 | lr: 1e-4
3 | weight_decay: 1e-4
--------------------------------------------------------------------------------
/configs/trainer/train/optimizer/adamw.yaml:
--------------------------------------------------------------------------------
1 | _target_: torch.optim.AdamW
2 | lr: 1e-4
3 | betas: [0.9, 0.999]
4 | weight_decay: 1e-4
--------------------------------------------------------------------------------
/configs/trainer/vqvae.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - train/lr_scheduler: exponential
4 | - train/optimizer: adam
5 |
6 | mode: train
7 | img_size: [256,256]
8 | label_type: image
9 | max_seq_len: null
10 |
11 | train:
12 | epochs: 20
13 | grad_clip: 0.2
14 | starting_temp: 1.
15 | temp_min: 0.06
16 | temp_anneal_rate: 1e-6
17 | save_every: 3
18 | dataloader:
19 | _target_: src.datamodule.dataloader.dataloader_vae
20 | batch_size: 48
21 | valid:
22 | dataloader:
23 | _target_: src.datamodule.dataloader.dataloader_vae
24 | batch_size: 48
25 | test:
26 | metrics: null
27 |
28 |
29 | trainer:
30 | _target_: src.trainer.VqvaeTrainer
31 | snapshot: null
32 | model_weights: null
33 |
34 |
--------------------------------------------------------------------------------
/configs/vocab/bbox.yaml:
--------------------------------------------------------------------------------
1 | need_vocab: true
2 | type: html
3 | dir: ${hydra:runtime.cwd}/../vocab/vocab_bbox.json
--------------------------------------------------------------------------------
/configs/vocab/cell.yaml:
--------------------------------------------------------------------------------
1 | need_vocab: true
2 | type: cell
3 | dir: ${hydra:runtime.cwd}/../vocab/vocab_cell_6k.json
--------------------------------------------------------------------------------
/configs/vocab/empty.yaml:
--------------------------------------------------------------------------------
1 | need_vocab: false
--------------------------------------------------------------------------------
/configs/vocab/html.yaml:
--------------------------------------------------------------------------------
1 | need_vocab: true
2 | type: html
3 | dir: ${hydra:runtime.cwd}/../vocab/vocab_html.json
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/train/PMC1626454_002_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC1626454_002_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/train/PMC2753619_002_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC2753619_002_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/train/PMC2759935_007_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC2759935_007_01.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/train/PMC2838834_005_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC2838834_005_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/train/PMC3519711_003_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC3519711_003_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/train/PMC3826085_003_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC3826085_003_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/train/PMC3907710_006_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC3907710_006_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/train/PMC4003957_018_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC4003957_018_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/train/PMC4172848_007_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC4172848_007_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/train/PMC4517499_004_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC4517499_004_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/train/PMC4682394_003_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC4682394_003_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/train/PMC4776821_005_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC4776821_005_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/train/PMC4840965_004_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC4840965_004_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/train/PMC5134617_013_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC5134617_013_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/train/PMC5198506_004_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC5198506_004_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/train/PMC5332562_005_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC5332562_005_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/train/PMC5402779_004_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC5402779_004_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/train/PMC5577841_001_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC5577841_001_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/train/PMC5679144_002_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC5679144_002_01.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/train/PMC5897438_004_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/train/PMC5897438_004_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/val/PMC1626454_002_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC1626454_002_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/val/PMC2753619_002_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC2753619_002_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/val/PMC2759935_007_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC2759935_007_01.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/val/PMC2838834_005_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC2838834_005_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/val/PMC3519711_003_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC3519711_003_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/val/PMC3826085_003_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC3826085_003_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/val/PMC3907710_006_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC3907710_006_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/val/PMC4003957_018_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC4003957_018_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/val/PMC4172848_007_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC4172848_007_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/val/PMC4517499_004_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC4517499_004_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/val/PMC4682394_003_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC4682394_003_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/val/PMC4776821_005_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC4776821_005_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/val/PMC4840965_004_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC4840965_004_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/val/PMC5134617_013_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC5134617_013_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/val/PMC5198506_004_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC5198506_004_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/val/PMC5332562_005_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC5332562_005_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/val/PMC5402779_004_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC5402779_004_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/val/PMC5577841_001_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC5577841_001_00.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/val/PMC5679144_002_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC5679144_002_01.png
--------------------------------------------------------------------------------
/dataset/mini_pubtabnet/val/PMC5897438_004_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/dataset/mini_pubtabnet/val/PMC5897438_004_00.png
--------------------------------------------------------------------------------
/experiments/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 | !*.ipynb
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | torchaudio
4 | torchtext
5 | jsonlines
6 | beautifulsoup4
7 | matplotlib
8 | hydra-core
9 | hydra_colorlog
10 | apted
11 | Distance
12 | lxml==4.9.3
13 | torchmetrics
14 | wandb
15 | einops
16 | ptflops
17 | tokenizers
18 | pycocotools
19 | torchmetrics
20 | faster-coco-eval
21 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages
2 | from setuptools import setup
3 |
4 | setup(name="unitable", version="1.0.0", packages=find_packages())
--------------------------------------------------------------------------------
/src/datamodule/__init__.py:
--------------------------------------------------------------------------------
1 | from .pubtabnet import PubTabNet
2 | from .synthtabnet import Synthtabnet
3 | from .dataloader import dataloader_vae, dataloader_beit, dataloader_html
4 | from .pubtables1m import PubTables
5 | from .tablebank import TableBank
6 | from .fintabnet import FinTabNet
7 |
--------------------------------------------------------------------------------
/src/datamodule/augmentation.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Any, Optional, Union
2 | from torch import Tensor
3 | import random
4 | from PIL import Image
5 | import torchvision.transforms.functional as F
6 | from torchvision import datasets, transforms
7 |
8 | from torchvision.transforms.transforms import _setup_size
9 |
10 |
11 | _PIL_INTERPOLATION = {
12 | "bilinear": Image.BILINEAR,
13 | "bicubic": Image.BICUBIC,
14 | "lanczos": Image.LANCZOS,
15 | "hamming": Image.HAMMING,
16 | }
17 |
18 | get_interpolation = lambda method: _PIL_INTERPOLATION.get(method, Image.BILINEAR)
19 |
20 |
21 | class RandomResizedCropAndInterpolationWithTwoPic(transforms.RandomResizedCrop):
22 | """Ensure both crops of vqvae and visual encoder have the same scale and size."""
23 |
24 | def __init__(
25 | self,
26 | size: Union[int, Tuple[int, int]], # transformer
27 | second_size: Union[int, Tuple[int, int]], # vqvae
28 | scale: Tuple[float, float] = (0.08, 1.0),
29 | ratio: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0),
30 | interpolation: str = "bilinear",
31 | second_interpolation: str = "lanczos",
32 | ):
33 | self.second_size = _setup_size(
34 | second_size,
35 | error_msg="Please provide only two dimensions (h, w) for second size.",
36 | )
37 |
38 | if interpolation == "random":
39 | interpolation = random.choice(
40 | [get_interpolation("bilinear"), get_interpolation("bicubic")]
41 | )
42 | else:
43 | interpolation = get_interpolation(interpolation)
44 | self.second_interpolation = get_interpolation(second_interpolation)
45 |
46 | super().__init__(
47 | size=size, scale=scale, ratio=ratio, interpolation=interpolation
48 | )
49 |
50 | def forward(self, img: Image):
51 | i, j, h, w = self.get_params(img, self.scale, self.ratio)
52 | out = F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
53 | out_second = F.resized_crop(
54 | img, i, j, h, w, self.second_size, self.second_interpolation
55 | )
56 |
57 | return out, out_second
58 |
59 |
60 | class AugmentationForMIM(object):
61 | def __init__(
62 | self,
63 | mean: float,
64 | std: float,
65 | trans_size: Union[int, Tuple[int, int]],
66 | vqvae_size: Union[int, Tuple[int, int]],
67 | trans_interpolation: str,
68 | vqvae_interpolation: str,
69 | ) -> None:
70 | self.common_transform = transforms.Compose(
71 | [
72 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
73 | transforms.RandomHorizontalFlip(p=0.5),
74 | RandomResizedCropAndInterpolationWithTwoPic(
75 | size=trans_size,
76 | second_size=vqvae_size,
77 | interpolation=trans_interpolation,
78 | second_interpolation=vqvae_interpolation,
79 | ),
80 | ]
81 | )
82 |
83 | self.trans_transform = transforms.Compose(
84 | [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]
85 | )
86 |
87 | self.vqvae_transform = transforms.ToTensor()
88 |
89 | def __call__(self, img: Image) -> Tuple[Tensor, Tensor]:
90 | trans_img, vqvae_img = self.common_transform(img)
91 | trans_img = self.trans_transform(trans_img)
92 | vqvae_img = self.vqvae_transform(vqvae_img)
93 |
94 | return trans_img, vqvae_img
95 |
96 |
97 | if __name__ == "__main__":
98 | mean = [240.380, 240.390, 240.486]
99 | std = [45.735, 45.785, 45.756]
100 |
101 | T = RandomResizedCropAndInterpolationWithTwoPic(
102 | size=(256, 256),
103 | second_size=(256, 256),
104 | interpolation="bicubic",
105 | second_interpolation="lanczos",
106 | )
107 |
108 | print(T)
109 |
--------------------------------------------------------------------------------
/src/datamodule/dataloader.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | from torch.utils.data import DataLoader, Dataset, Sampler
3 | from functools import partial
4 | import tokenizers as tk
5 | import torch
6 | from torch.utils.data import default_collate
7 | from src.utils.mask_generator import MaskGenerator
8 | from src.utils import (
9 | prepare_html_seq,
10 | prepare_cell_seq,
11 | prepare_bbox_seq,
12 | )
13 |
14 |
15 | class Collator:
16 | def __init__(
17 | self,
18 | vocab: tk.Tokenizer,
19 | max_seq_len: int,
20 | label_type: str,
21 | ) -> None:
22 | self.vocab = vocab
23 | self.vocab.enable_truncation(max_seq_len)
24 | self.label_type = label_type
25 |
26 | def __call__(self, batch) -> Any:
27 | return self._collate_batch(batch, self.vocab, self.label_type)
28 |
29 | def _collate_batch(
30 | self,
31 | batch: list[dict],
32 | vocab: tk.Tokenizer,
33 | label_type: str,
34 | ):
35 | if "cell" in label_type:
36 | image_list = [j for i in batch for j in i[0]]
37 | else:
38 | image_list = [i["image"] for i in batch]
39 | image_list = default_collate(image_list)
40 |
41 | if "cell" in label_type:
42 | filename = [(j["filename"], j["bbox_id"]) for i in batch for j in i[1]]
43 | else:
44 | filename = [i["filename"] for i in batch]
45 | label = dict(filename=filename)
46 |
47 | if "html" in label_type:
48 | html_list = ["".join(prepare_html_seq(i["html"])) for i in batch]
49 | label["html"] = vocab.encode_batch(html_list)
50 |
51 | if "cell" in label_type:
52 | cell_list = [
53 | " ".join(prepare_cell_seq(j["cell"])) for i in batch for j in i[1]
54 | ]
55 | label["cell"] = vocab.encode_batch(cell_list)
56 |
57 | if "bbox" in label_type:
58 | bbox_list = [" ".join(prepare_bbox_seq(i["bbox"])) for i in batch]
59 | label["bbox"] = vocab.encode_batch(bbox_list)
60 |
61 | return image_list, label
62 |
63 |
64 | def generate_mask_for_batch_samples(
65 | batch, grid_size: int, num_mask_patches: int, min_num_patches: int
66 | ):
67 | N = len(batch)
68 | mg = MaskGenerator(
69 | input_size=grid_size,
70 | num_mask_patches=num_mask_patches,
71 | min_num_patches=min_num_patches,
72 | )
73 | mask_list = [mg() for _ in range(N)]
74 | return default_collate(batch), default_collate(mask_list)
75 |
76 |
77 | def dataloader_vae(
78 | dataset: Dataset, batch_size: int, sampler: Sampler = None, **kwargs
79 | ) -> DataLoader:
80 | dataloader = DataLoader(
81 | dataset, batch_size, sampler=sampler, num_workers=8, pin_memory=True
82 | )
83 |
84 | return dataloader
85 |
86 |
87 | def dataloader_beit(
88 | dataset: Dataset,
89 | grid_size: int,
90 | num_mask_patches: int,
91 | min_num_patches: int,
92 | batch_size: int,
93 | sampler: Sampler = None,
94 | **kwargs
95 | ):
96 | dataloader = DataLoader(
97 | dataset,
98 | batch_size,
99 | sampler=sampler,
100 | collate_fn=partial(
101 | generate_mask_for_batch_samples,
102 | grid_size=grid_size,
103 | num_mask_patches=num_mask_patches,
104 | min_num_patches=min_num_patches,
105 | ),
106 | num_workers=8,
107 | pin_memory=True,
108 | )
109 |
110 | return dataloader
111 |
112 |
113 | def dataloader_html(
114 | dataset: Dataset,
115 | batch_size: int,
116 | vocab: tk.Tokenizer,
117 | max_seq_len: int,
118 | label_type: str,
119 | sampler=None,
120 | ) -> DataLoader:
121 | collate_fn = Collator(vocab, max_seq_len, label_type)
122 |
123 | dataloader = DataLoader(
124 | dataset,
125 | batch_size=batch_size,
126 | shuffle=False,
127 | num_workers=8,
128 | collate_fn=collate_fn,
129 | pin_memory=True,
130 | sampler=sampler,
131 | )
132 |
133 | return dataloader
134 |
--------------------------------------------------------------------------------
/src/datamodule/fintabnet.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Literal, Union
2 | from pathlib import Path
3 | import jsonlines
4 | from PIL import Image
5 | from torch import Tensor
6 | from torch.utils.data import Dataset
7 | import torchvision.transforms as transforms
8 |
9 |
10 | class FinTabNet(Dataset):
11 | """Load PubTabNet for different training purposes."""
12 |
13 | def __init__(
14 | self,
15 | root_dir: Union[Path, str],
16 | label_type: Literal["image", "html", "cell", "bbox"],
17 | transform: transforms = None,
18 | jsonl_filename: Union[Path, str] = None,
19 | ) -> None:
20 | super().__init__()
21 |
22 | self.root_dir = Path(root_dir)
23 | self.label_type = label_type
24 | self.transform = transform
25 |
26 | if label_type != "image":
27 | jsonl_file = self.root_dir / jsonl_filename
28 | with jsonlines.open(jsonl_file) as f:
29 | self.image_label_pair = list(f)
30 |
31 | def __len__(self):
32 | return len(self.image_label_pair)
33 |
34 | def __getitem__(self, index: int) -> Any:
35 | if self.label_type == "image":
36 | raise ValueError("FinTabNet is not used in pretraining.")
37 | else:
38 | obj = self.image_label_pair[index]
39 | img_name = f"{obj['table_id']}.png"
40 | img = Image.open(self.root_dir / "image" / img_name)
41 | if self.transform:
42 | img = self.transform(img)
43 |
44 | sample = dict(filename=obj["filename"], image=img)
45 |
46 | if self.label_type == "html":
47 | sample["html"] = obj["html"]["structure"]["tokens"]
48 | return sample
49 | else:
50 | raise ValueError("Task not supported in current dataset.")
51 |
--------------------------------------------------------------------------------
/src/datamodule/pubtables1m.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Literal, Union
2 | from pathlib import Path
3 | import jsonlines
4 | from PIL import Image
5 | from torch import Tensor
6 | from torch.utils.data import Dataset
7 | import torchvision.transforms as transforms
8 | import numpy as np
9 | import os
10 | import json
11 |
12 | from src.utils import bbox_augmentation_resize
13 |
14 |
15 | class PubTables(Dataset):
16 | """PubTables-1M-Structure"""
17 |
18 | def __init__(
19 | self,
20 | root_dir: Union[Path, str],
21 | label_type: Literal["image", "cell", "bbox"],
22 | split: Literal["train", "val", "test"],
23 | transform: transforms = None,
24 | cell_limit: int = 100,
25 | ) -> None:
26 | super().__init__()
27 |
28 | self.root_dir = Path(root_dir)
29 | self.split = split
30 | self.label_type = label_type
31 | self.transform = transform
32 | self.cell_limit = cell_limit
33 |
34 | tmp = os.listdir(self.root_dir / self.split)
35 |
36 | self.image_list = [i.split(".xml")[0] for i in tmp]
37 |
38 | def __len__(self):
39 | return len(self.image_list)
40 |
41 | def __getitem__(self, index: int) -> Any:
42 | name = self.image_list[index]
43 | img = Image.open(os.path.join(self.root_dir, "images", name + ".jpg"))
44 |
45 | if self.label_type == "image":
46 | if self.transform:
47 | img = self.transform(img)
48 | return img
49 | elif "bbox" in self.label_type:
50 | img_size = img.size
51 | if self.transform:
52 | img = self.transform(img)
53 | tgt_size = img.shape[-1]
54 | with open(
55 | os.path.join(self.root_dir, "words", name + "_words.json"), "r"
56 | ) as f:
57 | obj = json.load(f)
58 |
59 | obj[:] = [
60 | v
61 | for i in obj
62 | if "bbox" in i.keys()
63 | and all([i["bbox"][w + 2] > i["bbox"][w] for w in range(2)])
64 | for v in bbox_augmentation_resize(
65 | [
66 | min(max(i["bbox"][0], 0), img_size[0]),
67 | min(max(i["bbox"][1], 0), img_size[1]),
68 | min(max(i["bbox"][2], 0), img_size[0]),
69 | min(max(i["bbox"][3], 0), img_size[1]),
70 | ],
71 | img_size,
72 | tgt_size,
73 | )
74 | ]
75 |
76 | sample = {"filename": name, "image": img, "bbox": obj}
77 | return sample
78 |
79 | elif "cell" in self.label_type:
80 | img_size = img.size
81 | with open(
82 | os.path.join(self.root_dir, "words", name + "_words.json"), "r"
83 | ) as f:
84 | obj = json.load(f)
85 |
86 | bboxes_texts = [
87 | (i["bbox"], i["text"])
88 | for idx, i in enumerate(obj)
89 | if "bbox" in i
90 | and i["bbox"][0] < i["bbox"][2]
91 | and i["bbox"][1] < i["bbox"][3]
92 | and i["bbox"][0] >= 0
93 | and i["bbox"][1] >= 0
94 | and i["bbox"][2] < img_size[0]
95 | and i["bbox"][3] < img_size[1]
96 | and idx < self.cell_limit
97 | ]
98 |
99 | img_bboxes = [self.transform(img.crop(bbox[0])) for bbox in bboxes_texts]
100 |
101 | text_bboxes = [
102 | {"filename": name, "bbox_id": i, "cell": j[1]}
103 | for i, j in enumerate(bboxes_texts)
104 | ]
105 | return img_bboxes, text_bboxes
106 |
--------------------------------------------------------------------------------
/src/datamodule/pubtabnet.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Literal, Union
2 | from pathlib import Path
3 | from PIL import Image
4 | from torch import Tensor
5 | from torch.utils.data import Dataset
6 | import torchvision.transforms as transforms
7 | import numpy as np
8 | import os
9 |
10 | from src.utils import load_json_annotations, bbox_augmentation_resize
11 |
12 |
13 | # average html annotation length: train: 181.327 149.753
14 | # samples train: 500777, val: 9115
15 | class PubTabNet(Dataset):
16 | """Load PubTabNet for different training purposes."""
17 |
18 | def __init__(
19 | self,
20 | root_dir: Union[Path, str],
21 | label_type: Literal["image", "html", "cell", "bbox"],
22 | split: Literal["train", "val"],
23 | transform: transforms = None,
24 | json_html: Union[Path, str] = None,
25 | cell_limit: int = 150,
26 | ) -> None:
27 | super().__init__()
28 |
29 | self.root_dir = Path(root_dir)
30 | self.split = split
31 | self.label_type = label_type
32 | self.transform = transform
33 | self.cell_limit = cell_limit
34 |
35 | self.img_list = os.listdir(self.root_dir / self.split)
36 |
37 | if label_type != "image":
38 | self.image_label_pair = load_json_annotations(
39 | json_file_dir=Path(root_dir) / json_html, split=self.split
40 | )
41 |
42 | def __len__(self):
43 | return len(self.img_list)
44 |
45 | def __getitem__(self, index: int) -> Any:
46 | if self.label_type == "image":
47 | img = Image.open(self.root_dir / self.split / self.img_list[index])
48 | if self.transform:
49 | sample = self.transform(img)
50 | return sample
51 | else:
52 | obj = self.image_label_pair[index]
53 | img = Image.open(self.root_dir / self.split / obj[0])
54 |
55 | if self.label_type == "html":
56 | if self.transform:
57 | img = self.transform(img)
58 | sample = dict(
59 | filename=obj[0], image=img, html=obj[1]["structure"]["tokens"]
60 | )
61 | return sample
62 | elif self.label_type == "cell":
63 | bboxes_texts = [
64 | (i["bbox"], "".join(i["tokens"]))
65 | for idx, i in enumerate(obj[1]["cells"])
66 | if "bbox" in i
67 | and i["bbox"][0] < i["bbox"][2]
68 | and i["bbox"][1] < i["bbox"][3]
69 | and idx < self.cell_limit
70 | ]
71 |
72 | img_bboxes = [
73 | self.transform(img.crop(bbox[0])) for bbox in bboxes_texts
74 | ]
75 |
76 | text_bboxes = [
77 | {"filename": obj[0], "bbox_id": i, "cell": j[1]}
78 | for i, j in enumerate(bboxes_texts)
79 | ]
80 | return img_bboxes, text_bboxes
81 | else:
82 | img_size = img.size
83 | if self.transform:
84 | img = self.transform(img)
85 | tgt_size = img.shape[-1]
86 | sample = dict(filename=obj[0], image=img)
87 |
88 | bboxes = [
89 | entry["bbox"]
90 | for entry in obj[1]["cells"]
91 | if "bbox" in entry
92 | and entry["bbox"][0] < entry["bbox"][2]
93 | and entry["bbox"][1] < entry["bbox"][3]
94 | ]
95 |
96 | bboxes[:] = [
97 | i
98 | for entry in bboxes
99 | for i in bbox_augmentation_resize(entry, img_size, tgt_size)
100 | ]
101 |
102 | sample["bbox"] = bboxes
103 |
104 | return sample
105 |
--------------------------------------------------------------------------------
/src/datamodule/synthtabnet.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Literal, Union
2 | from pathlib import Path
3 | import jsonlines
4 | from PIL import Image
5 | from torch import Tensor
6 | from torch.utils.data import Dataset
7 | import torchvision.transforms as transforms
8 | import numpy as np
9 | import os
10 |
11 | from src.utils import load_json_annotations, bbox_augmentation_resize
12 |
13 | # invalid data pairs: image_000000_1634629424.098128.png has 4 channels
14 | INVALID_DATA = [
15 | {
16 | "dataset": "fintabnet",
17 | "split": "train",
18 | "image": "image_009379_1634631303.201671.png",
19 | },
20 | {
21 | "dataset": "marketing",
22 | "split": "train",
23 | "image": "image_000000_1634629424.098128.png",
24 | },
25 | ]
26 |
27 |
28 | class Synthtabnet(Dataset):
29 | def __init__(
30 | self,
31 | root_dir: Union[Path, str],
32 | label_type: Literal["image", "html", "all"],
33 | split: Literal["train", "val", "test"],
34 | transform: transforms = None,
35 | json_html: Union[Path, str] = None,
36 | cell_limit: int = 100,
37 | ) -> None:
38 | super().__init__()
39 |
40 | self.root_dir = Path(root_dir) / "images"
41 | self.split = split
42 | self.label_type = label_type
43 | self.transform = transform
44 | self.cell_limit = cell_limit
45 |
46 | # SSP only needs image
47 | self.img_list = os.listdir(self.root_dir / self.split)
48 | if label_type != "image":
49 | self.image_label_pair = load_json_annotations(
50 | json_file_dir=Path(root_dir) / json_html, split=split
51 | )
52 |
53 | def __len__(self):
54 | return len(self.img_list)
55 |
56 | def __getitem__(self, index: int) -> Any:
57 | if self.label_type == "image":
58 | img = Image.open(self.root_dir / self.split / self.img_list[index])
59 | if self.transform:
60 | sample = self.transform(img)
61 | return sample
62 | else:
63 | obj = self.image_label_pair[index]
64 | img = Image.open(self.root_dir / self.split / obj[0])
65 |
66 | if self.label_type == "html":
67 | if self.transform:
68 | img = self.transform(img)
69 | sample = dict(
70 | filename=obj[0], image=img, html=obj[1]["structure"]["tokens"]
71 | )
72 | return sample
73 | elif self.label_type == "cell":
74 | bboxes_texts = [
75 | (i["bbox"], "".join(i["tokens"]))
76 | for idx, i in enumerate(obj[1]["cells"])
77 | if "bbox" in i
78 | and i["bbox"][0] < i["bbox"][2]
79 | and i["bbox"][1] < i["bbox"][3]
80 | and idx < self.cell_limit
81 | ]
82 |
83 | img_bboxes = [
84 | self.transform(img.crop(bbox[0])) for bbox in bboxes_texts
85 | ] # you can limit the total cropped cells to lower gpu memory
86 |
87 | text_bboxes = [
88 | {"filename": obj[0], "bbox_id": i, "cell": j[1]}
89 | for i, j in enumerate(bboxes_texts)
90 | ]
91 | return img_bboxes, text_bboxes
92 | else:
93 | img_size = img.size
94 | if self.transform:
95 | img = self.transform(img)
96 | tgt_size = img.shape[-1]
97 | sample = dict(filename=obj[0], image=img)
98 |
99 | bboxes = [
100 | entry["bbox"]
101 | for entry in obj[1]["cells"]
102 | if "bbox" in entry
103 | and entry["bbox"][0] < entry["bbox"][2]
104 | and entry["bbox"][1] < entry["bbox"][3]
105 | ]
106 |
107 | bboxes[:] = [
108 | i
109 | for entry in bboxes
110 | for i in bbox_augmentation_resize(entry, img_size, tgt_size)
111 | ]
112 |
113 | sample["bbox"] = bboxes
114 |
115 | return sample
116 |
--------------------------------------------------------------------------------
/src/datamodule/tablebank.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Literal, Union
2 | from pathlib import Path
3 | import jsonlines
4 | from PIL import Image
5 | from torch import Tensor
6 | from torch.utils.data import Dataset
7 | import torchvision.transforms as transforms
8 | import numpy as np
9 | import os
10 | import json
11 |
12 |
13 | class TableBank(Dataset):
14 | """tablebank recognition"""
15 |
16 | def __init__(
17 | self,
18 | root_dir: Union[Path, str],
19 | label_type: Literal["image"],
20 | split: Literal["train", "val", "test"],
21 | transform: transforms = None,
22 | ) -> None:
23 | super().__init__()
24 |
25 | assert label_type == "image", "No annotations"
26 |
27 | self.root_dir = Path(root_dir)
28 | self.label_type = label_type
29 | self.transform = transform
30 | self.image_list = os.listdir(self.root_dir / "images")
31 |
32 | if split == "val" or split == "test":
33 | self.image_list = self.image_list[:1000]
34 |
35 | def __len__(self):
36 | return len(self.image_list)
37 |
38 | def __getitem__(self, index: int) -> Any:
39 | name = self.image_list[index]
40 | img = Image.open(os.path.join(self.root_dir, "images", name))
41 | if self.transform:
42 | img = self.transform(img)
43 |
44 | if self.label_type == "image":
45 | return img
46 | else:
47 | raise ValueError("TableBank doesn't have HTML annotations.")
48 |
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | import hydra
3 | import logging
4 | import os
5 | import wandb
6 | import torch
7 | import tokenizers as tk
8 | from omegaconf import DictConfig, OmegaConf
9 | from hydra.utils import get_original_cwd, instantiate
10 | from pathlib import Path
11 | import torch.multiprocessing as mp
12 | from torch.utils.data.distributed import DistributedSampler
13 | from torch.distributed import init_process_group, destroy_process_group
14 |
15 | from src.utils import printer, count_total_parameters
16 |
17 | log = logging.getLogger(__name__)
18 |
19 |
20 | @hydra.main(config_path="../configs", config_name="main", version_base="1.3")
21 | def main(cfg: DictConfig):
22 | torch.manual_seed(cfg.seed)
23 | ddp_setup()
24 | device = int(os.environ["LOCAL_RANK"])
25 | cwd = Path(get_original_cwd())
26 | exp_dir = Path(os.getcwd()) # experiment directory
27 |
28 | if cfg.trainer.mode == "train":
29 | (exp_dir / "snapshot").mkdir(parents=True, exist_ok=True)
30 | (exp_dir / "model").mkdir(parents=True, exist_ok=True)
31 | if device == 0:
32 | wandb.init(project=cfg.wandb.project, name=cfg.name, resume=True)
33 |
34 | # vocab is used in finetuning, not in self-supervised pretraining
35 | vocab = None
36 | if cfg.vocab.need_vocab:
37 | log.info(
38 | printer(
39 | device,
40 | f"Loading {cfg.vocab.type} vocab from {(cwd / cfg.vocab.dir).resolve()}",
41 | )
42 | )
43 | vocab = tk.Tokenizer.from_file(str(cwd / cfg.vocab.dir))
44 |
45 | # dataset
46 | if cfg.trainer.mode == "train":
47 | log.info(printer(device, "Loading training dataset"))
48 | train_dataset = instantiate(cfg.dataset.train_dataset)
49 |
50 | log.info(printer(device, "Loading validation dataset"))
51 | valid_dataset = instantiate(cfg.dataset.valid_dataset)
52 |
53 | train_kwargs = {
54 | "dataset": train_dataset,
55 | "sampler": DistributedSampler(train_dataset),
56 | "vocab": vocab,
57 | "max_seq_len": cfg.trainer.max_seq_len,
58 | }
59 |
60 | valid_kwargs = {
61 | "dataset": valid_dataset,
62 | "sampler": DistributedSampler(valid_dataset),
63 | "vocab": vocab,
64 | "max_seq_len": cfg.trainer.max_seq_len,
65 | }
66 |
67 | train_dataloader = instantiate(cfg.trainer.train.dataloader, **train_kwargs)
68 | valid_dataloader = instantiate(cfg.trainer.valid.dataloader, **valid_kwargs)
69 | elif cfg.trainer.mode == "test":
70 | # load testing dataset, same as valid for ssl
71 | log.info(printer(device, "Loading testing dataset"))
72 | test_dataset = instantiate(cfg.dataset.test_dataset)
73 |
74 | test_kwargs = {
75 | "dataset": test_dataset,
76 | "sampler": DistributedSampler(test_dataset),
77 | "vocab": vocab,
78 | "max_seq_len": cfg.trainer.max_seq_len,
79 | }
80 |
81 | test_dataloader = instantiate(cfg.trainer.test.dataloader, **test_kwargs)
82 |
83 | # model
84 | log.info(printer(device, "Loading model ..."))
85 | model_name = str(cfg.model.model._target_).split(".")[-1]
86 | if model_name == "DiscreteVAE":
87 | model = instantiate(cfg.model.model)
88 | elif model_name == "BeitEncoder":
89 | max_seq_len = (
90 | cfg.trainer.trans_size[0] // cfg.model.backbone_downsampling_factor
91 | ) * (cfg.trainer.trans_size[1] // cfg.model.backbone_downsampling_factor)
92 | model = instantiate(
93 | cfg.model.model,
94 | max_seq_len=max_seq_len,
95 | )
96 | # load pretrained vqvae
97 | model_vqvae = instantiate(cfg.model.model_vqvae)
98 |
99 | log.info(printer(device, "Loading pretrained VQVAE model ..."))
100 | assert Path(
101 | cfg.trainer.vqvae_weights
102 | ).is_file(), f"VQVAE weights doesn't exist: {cfg.trainer.vqvae_weights}"
103 | model_vqvae.load_state_dict(
104 | torch.load(cfg.trainer.vqvae_weights, map_location="cpu")
105 | )
106 | elif model_name == "EncoderDecoder":
107 | max_seq_len = max(
108 | (cfg.trainer.img_size[0] // cfg.model.backbone_downsampling_factor)
109 | * (cfg.trainer.img_size[1] // cfg.model.backbone_downsampling_factor),
110 | cfg.trainer.max_seq_len,
111 | ) # for positional embedding
112 | model = instantiate(
113 | cfg.model.model,
114 | max_seq_len=max_seq_len,
115 | vocab_size=vocab.get_vocab_size(),
116 | padding_idx=vocab.token_to_id(""),
117 | )
118 |
119 | log.info(
120 | printer(device, f"Total parameters: {count_total_parameters(model) / 1e6:.2f}M")
121 | )
122 |
123 | # trainer
124 | log.info(printer(device, "Loading trainer ..."))
125 | trainer_name = str(cfg.trainer.trainer._target_).split(".")[-1]
126 | trainer_kwargs = {
127 | "device": device,
128 | "model": model,
129 | "log": log,
130 | "exp_dir": exp_dir,
131 | "snapshot": (
132 | exp_dir / "snapshot" / cfg.trainer.trainer.snapshot
133 | if cfg.trainer.trainer.snapshot
134 | else None
135 | ),
136 | }
137 |
138 | if trainer_name == "VqvaeTrainer":
139 | trainer = instantiate(cfg.trainer.trainer, **trainer_kwargs)
140 | elif trainer_name == "BeitTrainer":
141 | trainer_kwargs["model_vqvae"] = model_vqvae
142 | trainer = instantiate(cfg.trainer.trainer, **trainer_kwargs)
143 | elif trainer_name == "TableTrainer":
144 | trainer_kwargs["vocab"] = vocab
145 | trainer = instantiate(cfg.trainer.trainer, **trainer_kwargs)
146 | else:
147 | raise ValueError(f"The provided trainer type {trainer_name} is not supported.")
148 |
149 | if cfg.trainer.mode == "train":
150 | log.info(printer(device, "Training starts ..."))
151 | trainer.train(
152 | train_dataloader, valid_dataloader, cfg.trainer.train, cfg.trainer.valid
153 | )
154 | elif cfg.trainer.mode == "test":
155 | log.info(printer(device, "Evaluation starts ..."))
156 | save_to = exp_dir / cfg.name
157 | save_to.mkdir(parents=True, exist_ok=True)
158 | trainer.test(test_dataloader, cfg.trainer.test, save_to=save_to)
159 | else:
160 | raise NotImplementedError
161 |
162 | destroy_process_group()
163 |
164 |
165 | def ddp_setup():
166 | init_process_group(backend="nccl")
167 |
168 |
169 | if __name__ == "__main__":
170 | main()
171 |
--------------------------------------------------------------------------------
/src/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .beit import BeitEncoder
2 | from .vqvae import DiscreteVAE
3 | from .encoderdecoder import EncoderDecoder
4 | from .components import *
--------------------------------------------------------------------------------
/src/model/beit.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn, Tensor
4 | from functools import partial
5 |
6 | from src.model.components import ImgLinearBackbone, PositionEmbedding, Encoder
7 |
8 |
9 | class BeitEncoder(nn.Module):
10 | def __init__(
11 | self,
12 | d_model: int, # embed_dim
13 | backbone: nn.Module,
14 | max_seq_len: int, # for positional embedding
15 | codebook_tokens: int,
16 | dropout: float,
17 | encoder: Encoder,
18 | norm_layer: nn.Module,
19 | init_std: float = 0.02,
20 | ) -> None:
21 | super().__init__()
22 |
23 | self.d_model = d_model
24 | self.init_std = init_std
25 |
26 | self.backbone = backbone
27 | self.pos_embed = PositionEmbedding(
28 | max_seq_len=max_seq_len, d_model=d_model, dropout=dropout
29 | )
30 |
31 | self.encoder = encoder
32 | self.norm = norm_layer(d_model)
33 | self.generator = nn.Linear(d_model, codebook_tokens)
34 |
35 | self.trunc_normal = partial(
36 | nn.init.trunc_normal_, std=init_std, a=-init_std, b=init_std
37 | )
38 | self.apply(self._init_weights)
39 |
40 | self.mask_token = nn.Parameter(torch.zeros(1, 1, d_model))
41 |
42 | def _init_weights(self, m: nn.Module):
43 | if isinstance(m, nn.Linear):
44 | self.trunc_normal(m.weight)
45 | if m.bias is not None:
46 | nn.init.constant_(m.bias, 0.0)
47 | elif isinstance(m, nn.LayerNorm):
48 | nn.init.constant_(m.weight, 1.0)
49 | nn.init.constant_(m.bias, 0.0)
50 | elif isinstance(m, nn.Conv2d):
51 | self.trunc_normal(m.weight)
52 | if m.bias is not None:
53 | nn.init.constant_(m.bias, 0.0)
54 | elif isinstance(m, PositionEmbedding):
55 | self.trunc_normal(m.embedding.weight)
56 |
57 | @torch.jit.ignore
58 | def no_weight_decay(self):
59 | return {"pos_embed"}
60 |
61 | def forward(
62 | self, x: Tensor, bool_masked_pos: Tensor, return_all_tokens: bool = False
63 | ):
64 | x = self.backbone(x)
65 | B, S, E = x.shape
66 | assert E == self.d_model
67 |
68 | mask_token = self.mask_token.expand(B, S, -1)
69 |
70 | w = bool_masked_pos.unsqueeze(-1).type_as(mask_token)
71 | x = x * (1 - w) + mask_token * w
72 |
73 | x = self.pos_embed(x)
74 |
75 | x = self.encoder(x)
76 | x = self.norm(x)
77 |
78 | if return_all_tokens:
79 | return self.generator(x)
80 | else:
81 | return self.generator(x[bool_masked_pos])
82 |
83 |
84 | if __name__ == "__main__":
85 | d_model = 512
86 | patch_size = 16
87 | nhead = 8
88 | dropout = 0.0
89 | acitvation = "gelu"
90 | norm_first = True
91 | nlayer = 12
92 | ff_ratio = 4
93 | norm_layer = partial(nn.LayerNorm, eps=1e-6)
94 | codebook_tokens = 8192
95 |
96 | img_size = 448
97 |
98 | max_seq_len = (img_size // patch_size) ** 2
99 |
100 | backbone = ImgLinearBackbone(d_model=d_model, patch_size=patch_size)
101 | encoder = Encoder(
102 | d_model=d_model,
103 | nhead=nhead,
104 | dropout=dropout,
105 | activation=acitvation,
106 | norm_first=norm_first,
107 | nlayer=nlayer,
108 | ff_ratio=ff_ratio,
109 | )
110 |
111 | model = BeitEncoder(
112 | d_model=d_model,
113 | backbone=backbone,
114 | max_seq_len=max_seq_len,
115 | codebook_tokens=codebook_tokens,
116 | dropout=dropout,
117 | encoder=encoder,
118 | norm_layer=norm_layer,
119 | )
120 |
121 | print(model)
122 |
123 | x = torch.rand((1, 3, img_size, img_size))
124 | bool_masked_pos = torch.rand((1, (img_size // patch_size) ** 2)) < 0.5
125 | y = model(x, bool_masked_pos)
126 | print(torch.sum(bool_masked_pos))
127 | print(y.shape)
128 |
--------------------------------------------------------------------------------
/src/model/components.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 | import torch
3 | from torch import nn, Tensor
4 | from torchvision.ops.misc import Conv2dNormActivation
5 |
6 |
7 | __all__ = [
8 | "ImgCnnBackbone",
9 | "ImgLinearBackbone",
10 | "ImgConvStemBackbone",
11 | "PositionEmbedding",
12 | "Encoder",
13 | "Decoder",
14 | "TokenEmbedding",
15 | ]
16 |
17 |
18 | class ImgCnnBackbone(nn.Module):
19 | def __init__(
20 | self,
21 | backbone: nn.Module,
22 | output_channels: int,
23 | d_model: int,
24 | drop_layer: Tuple = None,
25 | ) -> None:
26 | super().__init__()
27 |
28 | # drop layers for classification & maxpooling for higher feature resolution
29 | layers = list(backbone.children())
30 | nlayer = len(layers)
31 | keep_layer = set([i for i in range(nlayer)]) - set(drop_layer)
32 | backbone = [layers[i] for i in keep_layer]
33 | self.backbone = nn.Sequential(*backbone)
34 | self.proj = nn.Linear(output_channels, d_model)
35 | self.channels = output_channels
36 |
37 | def forward(self, x: Tensor) -> Tensor:
38 | x = self.backbone(x)
39 | x = x.flatten(start_dim=-2).transpose(1, 2)
40 | assert x.shape[-1] == self.channels, "Image channels size mismatch."
41 | x = self.proj(x)
42 | return x
43 |
44 |
45 | class ImgLinearBackbone(nn.Module):
46 | def __init__(
47 | self,
48 | d_model: int,
49 | patch_size: int,
50 | in_chan: int = 3,
51 | ) -> None:
52 | super().__init__()
53 |
54 | self.conv_proj = nn.Conv2d(
55 | in_chan, out_channels=d_model, kernel_size=patch_size, stride=patch_size
56 | )
57 | self.d_model = d_model
58 |
59 | def forward(self, x: Tensor) -> Tensor:
60 | x = self.conv_proj(x)
61 | x = x.flatten(start_dim=-2).transpose(1, 2)
62 | return x
63 |
64 |
65 | class ImgConvStemBackbone(nn.Module):
66 | def __init__(
67 | self,
68 | d_model: int,
69 | downsample_factor: int,
70 | output_channels: int,
71 | kernel_size: int,
72 | ) -> None:
73 | super().__init__()
74 |
75 | assert downsample_factor % 2 == 0
76 | assert output_channels % (downsample_factor // 2) == 0
77 | input_channels = output_channels // (downsample_factor // 2)
78 |
79 | layers = [
80 | Conv2dNormActivation(
81 | 3, input_channels, kernel_size=kernel_size, stride=2, padding=1
82 | )
83 | ]
84 |
85 | while input_channels != output_channels:
86 | layers.append(
87 | Conv2dNormActivation(
88 | input_channels,
89 | input_channels * 2,
90 | kernel_size=kernel_size,
91 | stride=2,
92 | padding=1,
93 | )
94 | )
95 | input_channels = input_channels * 2
96 |
97 | layers.append(nn.Conv2d(output_channels, d_model, kernel_size=1))
98 |
99 | self.conv_stem = nn.Sequential(*layers)
100 |
101 | def forward(self, x: Tensor) -> Tensor:
102 | x = self.conv_stem(x)
103 | x = x.flatten(start_dim=-2).transpose(1, 2)
104 | return x
105 |
106 |
107 | class Encoder(nn.Module):
108 | def __init__(
109 | self,
110 | d_model: int,
111 | nhead: int,
112 | dropout: float,
113 | activation: str,
114 | norm_first: bool,
115 | nlayer: int,
116 | ff_ratio: int = 4,
117 | ) -> None:
118 | super().__init__()
119 |
120 | encoder_layer = nn.TransformerEncoderLayer(
121 | d_model,
122 | nhead=nhead,
123 | dim_feedforward=ff_ratio * d_model,
124 | dropout=dropout,
125 | activation=activation,
126 | batch_first=True,
127 | norm_first=norm_first,
128 | )
129 |
130 | self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=nlayer)
131 |
132 | def forward(self, x: Tensor) -> Tensor:
133 | x = self.encoder(x)
134 | return x
135 |
136 |
137 | class Decoder(nn.Module):
138 | def __init__(
139 | self,
140 | d_model: int,
141 | nhead: int,
142 | dropout: float,
143 | activation: str,
144 | norm_first: bool,
145 | nlayer: int,
146 | ff_ratio: int = 4,
147 | ) -> None:
148 | super().__init__()
149 | decoder_layer = nn.TransformerDecoderLayer(
150 | d_model,
151 | nhead,
152 | dim_feedforward=ff_ratio * d_model,
153 | dropout=dropout,
154 | activation=activation,
155 | batch_first=True,
156 | norm_first=norm_first,
157 | )
158 |
159 | self.decoder = nn.TransformerDecoder(decoder_layer, nlayer)
160 |
161 | def forward(
162 | self, x: Tensor, memory: Tensor, tgt_mask: Tensor, tgt_padding_mask: Tensor
163 | ) -> Tensor:
164 | x = self.decoder(
165 | x, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_padding_mask
166 | )
167 | return x
168 |
169 |
170 | class PositionEmbedding(nn.Module):
171 | def __init__(self, max_seq_len: int, d_model: int, dropout: float) -> None:
172 | super().__init__()
173 | self.embedding = nn.Embedding(max_seq_len, d_model)
174 | self.dropout = nn.Dropout(dropout)
175 |
176 | def forward(self, x: Tensor) -> Tensor:
177 | # assume x is batch first
178 | out = self.embedding(torch.arange(x.shape[1], device=x.device))
179 | return self.dropout(out + x)
180 |
181 |
182 | class TokenEmbedding(nn.Module):
183 | def __init__(
184 | self,
185 | vocab_size: int,
186 | d_model: int,
187 | padding_idx: int,
188 | ) -> None:
189 | super().__init__()
190 | assert vocab_size > 0
191 | self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
192 |
193 | def forward(self, x: Tensor) -> Tensor:
194 | return self.embedding(x)
195 |
196 |
197 | class PrintLayer(nn.Module):
198 | """Only for debugging when loss is nan."""
199 |
200 | def __init__(self):
201 | super().__init__()
202 |
203 | def forward(self, x):
204 | print(
205 | "torch.isfinite(x).all(): {}, min. {:.5f}, max. {:.5f}".format(
206 | torch.isfinite(x).all(), x.min(), x.max()
207 | )
208 | )
209 | return x
210 |
211 |
212 | if __name__ == "__main__":
213 | from torchvision import models
214 |
215 | x = torch.rand(1, 3, 392, 392)
216 | model = ImgConvStemBackbone(
217 | d_model=512, downsample_factor=16, output_channels=64, kernel_size=5
218 | )
219 | y = model(x)
220 | print(model)
221 | print(y.shape)
222 |
223 | model = ImgCnnBackbone(
224 | backbone=models.resnet34(),
225 | output_channels=512,
226 | d_model=512,
227 | drop_layer=(3, 8, 9),
228 | )
229 |
230 | # print(model)
231 | y = model(x)
232 | print(y.shape)
233 |
--------------------------------------------------------------------------------
/src/model/encoderdecoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor, nn
3 | from functools import partial
4 |
5 | from src.model.components import (
6 | ImgCnnBackbone,
7 | ImgLinearBackbone,
8 | ImgConvStemBackbone,
9 | Encoder,
10 | Decoder,
11 | PositionEmbedding,
12 | TokenEmbedding,
13 | )
14 |
15 |
16 | class EncoderDecoder(nn.Module):
17 | """Encoder decoder architecture that takes in a tabular image and generates the text output.
18 | Backbone serves as the image processor. There are three types of backbones: CNN, linear projection, and ConvStem.
19 |
20 | Args:
21 | ----
22 | backbone: tabular image processor
23 | encoder: transformer encoder
24 | decoder: transformer decoder
25 | vocab_size: size of the vocabulary
26 | d_model: feature size
27 | padding_idx: index of in the vocabulary
28 | max_seq_len: max sequence length of generated text
29 | dropout: dropout rate
30 | norm_layer: layernorm
31 | init_std: std in weights initialization
32 | """
33 |
34 | def __init__(
35 | self,
36 | backbone: nn.Module,
37 | encoder: nn.Module,
38 | decoder: nn.Module,
39 | vocab_size: int,
40 | d_model: int,
41 | padding_idx: int,
42 | max_seq_len: int,
43 | dropout: float,
44 | norm_layer: nn.Module,
45 | init_std: float = 0.02,
46 | ):
47 | super().__init__()
48 |
49 | self.backbone = backbone
50 | self.encoder = encoder
51 | self.decoder = decoder
52 | self.norm = norm_layer(d_model)
53 | self.token_embed = TokenEmbedding(
54 | vocab_size=vocab_size, d_model=d_model, padding_idx=padding_idx
55 | )
56 | self.pos_embed = PositionEmbedding(
57 | max_seq_len=max_seq_len, d_model=d_model, dropout=dropout
58 | )
59 | self.generator = nn.Linear(d_model, vocab_size)
60 |
61 | self.trunc_normal = partial(
62 | nn.init.trunc_normal_, std=init_std, a=-init_std, b=init_std
63 | )
64 | self.apply(self._init_weights)
65 |
66 | def _init_weights(self, m: nn.Module):
67 | if isinstance(m, nn.Linear):
68 | self.trunc_normal(m.weight)
69 | if m.bias is not None:
70 | nn.init.constant_(m.bias, 0.0)
71 | elif isinstance(m, nn.LayerNorm):
72 | nn.init.constant_(m.weight, 1.0)
73 | nn.init.constant_(m.bias, 0.0)
74 | elif isinstance(m, nn.Conv2d):
75 | self.trunc_normal(m.weight)
76 | if m.bias is not None:
77 | nn.init.constant_(m.bias, 0.0)
78 | elif isinstance(m, PositionEmbedding):
79 | self.trunc_normal(m.embedding.weight)
80 | elif isinstance(m, TokenEmbedding):
81 | self.trunc_normal(m.embedding.weight)
82 |
83 | @torch.jit.ignore
84 | def no_weight_decay(self):
85 | return {"token_embed", "pos_embed"}
86 |
87 | def encode(self, src: Tensor) -> Tensor:
88 | src_feature = self.backbone(src)
89 | src_feature = self.pos_embed(src_feature)
90 | memory = self.encoder(src_feature)
91 | memory = self.norm(memory)
92 | return memory
93 |
94 | def decode(
95 | self, memory: Tensor, tgt: Tensor, tgt_mask: Tensor, tgt_padding_mask: Tensor
96 | ) -> Tensor:
97 | tgt_feature = self.pos_embed(self.token_embed(tgt))
98 | tgt = self.decoder(tgt_feature, memory, tgt_mask, tgt_padding_mask)
99 |
100 | return tgt
101 |
102 | def forward(
103 | self, src: Tensor, tgt: Tensor, tgt_mask: Tensor, tgt_padding_mask: Tensor
104 | ) -> Tensor:
105 | memory = self.encode(src)
106 | tgt = self.decode(memory, tgt, tgt_mask, tgt_padding_mask)
107 | tgt = self.generator(tgt)
108 |
109 | return tgt
110 |
--------------------------------------------------------------------------------
/src/model/vqvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor, einsum
3 | from typing import Optional, Tuple
4 | import math
5 | from functools import partial
6 | from collections import OrderedDict
7 | import torch.nn.functional as F
8 | from einops import rearrange
9 |
10 |
11 | def exists(val):
12 | return val is not None
13 |
14 |
15 | def default(val, d):
16 | return val if exists(val) else d
17 |
18 |
19 | def eval_decorator(fn):
20 | def inner(model, *args, **kwargs):
21 | was_training = model.training
22 | model.eval()
23 | out = fn(model, *args, **kwargs)
24 | model.train(was_training)
25 | return out
26 |
27 | return inner
28 |
29 |
30 | class ResBlock(nn.Module):
31 | def __init__(self, chan_in, hidden_size, chan_out):
32 | super().__init__()
33 | self.net = nn.Sequential(
34 | nn.Conv2d(chan_in, hidden_size, 3, padding=1),
35 | nn.ReLU(),
36 | nn.Conv2d(hidden_size, hidden_size, 3, padding=1),
37 | nn.ReLU(),
38 | nn.Conv2d(hidden_size, chan_out, 1),
39 | )
40 |
41 | def forward(self, x):
42 | return self.net(x) + x
43 |
44 |
45 | class BasicVAE(nn.Module):
46 | def get_codebook_indices(self, images):
47 | raise NotImplementedError()
48 |
49 | def decode(self, img_seq):
50 | raise NotImplementedError()
51 |
52 | def get_codebook_probs(self, img_seq):
53 | raise NotImplementedError()
54 |
55 | def get_image_tokens_size(self):
56 | pass
57 |
58 | def get_image_size(self):
59 | pass
60 |
61 |
62 | class DiscreteVAE(BasicVAE):
63 | def __init__(
64 | self,
65 | image_size: Tuple[int, int] = [256, 256], # input image size
66 | codebook_tokens: int = 512, # codebook vocab size
67 | codebook_dim: int = 512, # codebook embedding dimension
68 | num_layers: int = 3, # layers of resnet blocks in encoder/decoder
69 | hidden_dim: int = 64, # dimension in resnet blocks
70 | channels: int = 3, # input channels
71 | smooth_l1_loss: bool = False, # prevents exploding gradients
72 | temperature: float = 0.9, # tau in gumbel softmax
73 | straight_through: bool = False, # if True, the returned samples will be discretized as one-hot vectors, but will be differentiated as if it is the soft sample in autograd
74 | kl_div_loss_weight: float = 0.0,
75 | ):
76 | super().__init__()
77 | assert num_layers >= 1, "number of layers must be greater than or equal to 1"
78 |
79 | self.image_size = image_size
80 | self.codebook_tokens = codebook_tokens
81 | self.num_layers = num_layers
82 | self.temperature = temperature
83 | self.straight_through = straight_through
84 | self.codebook = nn.Embedding(codebook_tokens, codebook_dim)
85 |
86 | encoder_layers = list()
87 | decoder_layers = list()
88 |
89 | encoder_in = channels
90 | decoder_in = codebook_dim
91 |
92 | for _ in range(num_layers):
93 | encoder_layers.append(
94 | nn.Sequential(
95 | nn.Conv2d(encoder_in, hidden_dim, 4, stride=2, padding=1), nn.ReLU()
96 | )
97 | )
98 | encoder_layers.append(
99 | ResBlock(
100 | chan_in=hidden_dim, hidden_size=hidden_dim, chan_out=hidden_dim
101 | )
102 | )
103 | encoder_in = hidden_dim
104 |
105 | decoder_layers.append(
106 | nn.Sequential(
107 | nn.ConvTranspose2d(decoder_in, hidden_dim, 4, stride=2, padding=1),
108 | nn.ReLU(),
109 | )
110 | )
111 | decoder_layers.append(
112 | ResBlock(
113 | chan_in=hidden_dim, hidden_size=hidden_dim, chan_out=hidden_dim
114 | )
115 | )
116 | decoder_in = hidden_dim
117 |
118 | encoder_layers.append(nn.Conv2d(hidden_dim, codebook_tokens, 1))
119 | decoder_layers.append(nn.Conv2d(hidden_dim, channels, 1))
120 |
121 | self.encoder = nn.Sequential(*encoder_layers)
122 | self.decoder = nn.Sequential(*decoder_layers)
123 |
124 | self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
125 | self.kl_div_loss_weight = kl_div_loss_weight
126 |
127 | def get_image_size(self):
128 | return self.image_size
129 |
130 | def get_image_tokens_size(self) -> int:
131 | ds_ratio = math.pow(2, self.num_layers)
132 | return int((self.image_size[0] // ds_ratio) * (self.image_size[1] // ds_ratio))
133 |
134 | @torch.no_grad()
135 | @eval_decorator
136 | def get_codebook_indices(self, images: Tensor):
137 | logits = self.forward(images, return_logits=True)
138 | codebook_indices = logits.argmax(dim=1)
139 | return codebook_indices
140 |
141 | @torch.no_grad()
142 | @eval_decorator
143 | def get_codebook_probs(self, images: Tensor):
144 | logits = self.forward(images, return_logits=True)
145 | return nn.Softmax(dim=1)(logits)
146 |
147 | def decode(self, img_seq: Tensor):
148 | image_embeds = self.codebook(img_seq)
149 | image_embeds = image_embeds.permute((0, 3, 1, 2)).contiguous()
150 |
151 | # image_embeds = rearrange(image_embeds, "b h w d -> b d h w", h=h, w=w)
152 | images = self.decoder(image_embeds)
153 | return images
154 |
155 | def forward(
156 | self,
157 | img: Tensor,
158 | return_loss: bool = False,
159 | return_recons: bool = False,
160 | return_logits: bool = False,
161 | temp=None,
162 | ) -> Tuple[Tensor, Optional[Tensor]]:
163 | assert (
164 | img.shape[-1] == self.image_size[0] and img.shape[-2] == self.image_size[1]
165 | ), f"input must have the correct image size {self.image_size}"
166 |
167 | logits = self.encoder(img)
168 |
169 | if return_logits:
170 | return logits # return logits for getting hard image indices for DALL-E training
171 |
172 | temp = default(temp, self.temperature)
173 | soft_one_hot = F.gumbel_softmax(
174 | logits, tau=temp, dim=1, hard=self.straight_through
175 | )
176 | sampled = einsum(
177 | "b n h w, n d -> b d h w", soft_one_hot, self.codebook.weight
178 | ).contiguous()
179 | out = self.decoder(sampled)
180 |
181 | if not return_loss:
182 | return out
183 |
184 | # reconstruction loss
185 | recon_loss = self.loss_fn(img, out)
186 |
187 | # kl divergence
188 | logits = rearrange(logits, "b n h w -> b (h w) n").contiguous()
189 | qy = F.softmax(logits, dim=-1)
190 |
191 | log_qy = torch.log(qy + 1e-10)
192 | log_uniform = torch.log(
193 | torch.tensor([1.0 / self.codebook_tokens], device=img.device)
194 | )
195 | kl_div = F.kl_div(log_uniform, log_qy, None, None, "batchmean", log_target=True)
196 |
197 | loss = recon_loss + (kl_div * self.kl_div_loss_weight)
198 |
199 | if not return_recons:
200 | return loss
201 |
202 | return loss, out
203 |
204 |
205 | if __name__ == "__main__":
206 | input = torch.rand(1, 3, 256, 256)
207 | model = DiscreteVAE()
208 | loss, output = model(input, return_loss=True, return_recons=True)
209 |
210 | print(model)
211 | print(model.get_image_tokens_size())
212 | print(model.get_codebook_indices(input).shape)
213 | print(loss, output.shape, output.max(), output.min())
214 |
--------------------------------------------------------------------------------
/src/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .train_beit import BeitTrainer
2 | from .train_vqvae import VqvaeTrainer
3 | from .train_table import TableTrainer
--------------------------------------------------------------------------------
/src/trainer/train_beit.py:
--------------------------------------------------------------------------------
1 | import wandb
2 | from pathlib import Path
3 | from typing import Tuple, List, Union, Dict
4 | from omegaconf import DictConfig
5 | from hydra.utils import instantiate
6 | import logging
7 | import torch
8 | import time
9 | from torch import nn, Tensor, autograd
10 | from torch.utils.data import DataLoader
11 | from torch.nn.parallel import DistributedDataParallel as DDP
12 |
13 | from src.utils import printer, compute_grad_norm
14 | from src.trainer.utils import configure_optimizer_weight_decay
15 |
16 | SNAPSHOT_KEYS = set(["EPOCH", "STEP", "OPTIMIZER", "LR_SCHEDULER", "MODEL", "LOSS"])
17 |
18 |
19 | class BeitTrainer:
20 | def __init__(
21 | self,
22 | device: int,
23 | model: nn.Module,
24 | model_vqvae: nn.Module,
25 | log: logging.Logger,
26 | exp_dir: Path,
27 | snapshot: Path = None,
28 | model_weights: Path = None, # only for testing
29 | ) -> None:
30 | self.device = device
31 | self.log = log
32 | self.exp_dir = exp_dir
33 | self.criterion = nn.CrossEntropyLoss()
34 | assert (
35 | snapshot is None or model_weights is None
36 | ), "Snapshot and model weights cannot be set at the same time."
37 |
38 | self.model = model
39 | if snapshot is not None and snapshot.is_file():
40 | self.snapshot = self.load_snapshot(snapshot)
41 | self.model.load_state_dict(self.snapshot["MODEL"])
42 | self.start_epoch = self.snapshot["EPOCH"]
43 | self.global_step = self.snapshot["STEP"]
44 | elif model_weights is not None and model_weights.is_file():
45 | self.load_model(model_weights)
46 | else:
47 | self.snapshot = None
48 | self.start_epoch = 0
49 | self.global_step = 0
50 |
51 | self.model = self.model.to(device)
52 | self.model = DDP(self.model, device_ids=[device])
53 | self.model_vqvae = model_vqvae.to(device)
54 |
55 | # https://discuss.pytorch.org/t/extra-10gb-memory-on-gpu-0-in-ddp-tutorial/118113
56 | torch.cuda.set_device(device) # master gpu takes up extra memory
57 | torch.cuda.empty_cache()
58 |
59 | def train_epoch(self, epoch: int, grad_clip: float = None):
60 | start = time.time()
61 | total_loss = 0.0
62 | total_samples = 0
63 |
64 | for i, obj in enumerate(self.train_dataloader):
65 | (trans_image, vqvae_image), bool_mask_pos = obj
66 | trans_image, vqvae_image, bool_mask_pos = (
67 | trans_image.to(self.device),
68 | vqvae_image.to(self.device),
69 | bool_mask_pos.to(self.device),
70 | )
71 |
72 | with torch.no_grad():
73 | input_ids = self.model_vqvae.get_codebook_indices(vqvae_image).flatten(
74 | 1
75 | )
76 | bool_mask_pos = bool_mask_pos.flatten(1).to(torch.bool)
77 | labels = input_ids[bool_mask_pos]
78 |
79 | with autograd.detect_anomaly():
80 | outputs = self.model(
81 | trans_image, bool_mask_pos, return_all_tokens=False
82 | )
83 | loss = self.criterion(outputs, labels)
84 |
85 | self.optimizer.zero_grad()
86 | loss.backward()
87 | if grad_clip:
88 | nn.utils.clip_grad_norm_(
89 | self.model.parameters(), max_norm=grad_clip
90 | )
91 | self.optimizer.step()
92 |
93 | loss = loss.detach().cpu().data
94 | total_loss += loss * trans_image.shape[0]
95 | total_samples += trans_image.shape[0]
96 |
97 | self.lr_scheduler.step()
98 | self.global_step += 1
99 |
100 | if i % 10 == 0:
101 | grad_norm = compute_grad_norm(self.model)
102 | lr = self.optimizer.param_groups[0]["lr"]
103 | elapsed = time.time() - start
104 | self.log.info(
105 | printer(
106 | self.device,
107 | f"Epoch {epoch} Step {i + 1}/{len(self.train_dataloader)} | Loss {loss:.4f} ({total_loss / total_samples:.4f}) | Grad norm {grad_norm:.3f} | {total_samples / elapsed:4.1f} images/s | lr {lr:5.1e}",
108 | )
109 | )
110 |
111 | if i % 100 == 0 and self.device == 0:
112 | lr = self.optimizer.param_groups[0]["lr"]
113 | log_info = {
114 | "epoch": epoch,
115 | "train_loss": loss,
116 | "learning rate": lr,
117 | "grad_norm": grad_norm,
118 | }
119 |
120 | wandb.log(
121 | log_info,
122 | step=self.global_step,
123 | )
124 |
125 | return total_loss / total_samples
126 |
127 | def train(
128 | self,
129 | train_dataloader: DataLoader,
130 | valid_dataloader: DataLoader,
131 | train_cfg: DictConfig,
132 | valid_cfg: DictConfig,
133 | ):
134 | self.train_dataloader = train_dataloader
135 | self.valid_dataloader = valid_dataloader
136 |
137 | # ensure correct weight decay: https://github.com/karpathy/minGPT/blob/37baab71b9abea1b76ab957409a1cc2fbfba8a26/mingpt/model.py#L215
138 | optim_params = configure_optimizer_weight_decay(
139 | self.model.module, weight_decay=train_cfg.optimizer.weight_decay
140 | )
141 | self.optimizer = instantiate(train_cfg.optimizer, optim_params)
142 |
143 | self.lr_scheduler = instantiate(
144 | train_cfg.lr_scheduler, optimizer=self.optimizer
145 | )
146 |
147 | if self.snapshot is not None:
148 | self.optimizer.load_state_dict(self.snapshot["OPTIMIZER"])
149 | self.lr_scheduler.load_state_dict(self.snapshot["LR_SCHEDULER"])
150 |
151 | best_loss = float("inf")
152 | self.model.train()
153 | for epoch in range(self.start_epoch, train_cfg.epochs):
154 | train_dataloader.sampler.set_epoch(epoch)
155 | train_loss = self.train_epoch(epoch, grad_clip=train_cfg.grad_clip)
156 |
157 | torch.cuda.empty_cache()
158 |
159 | valid_loss = self.valid(valid_cfg)
160 |
161 | if self.device == 0:
162 | wandb.log(
163 | {
164 | "train loss (epoch)": train_loss,
165 | "valid loss (epoch)": valid_loss,
166 | },
167 | step=self.global_step,
168 | )
169 |
170 | if epoch % train_cfg.save_every == 0:
171 | self.save_snapshot(epoch, best_loss)
172 | if valid_loss < best_loss:
173 | self.save_model(epoch)
174 | best_loss = valid_loss
175 |
176 | def valid(self, cfg: DictConfig):
177 | total_samples = 0
178 | total_loss = 0.0
179 |
180 | self.model.eval()
181 | for i, obj in enumerate(self.valid_dataloader):
182 | (trans_image, vqvae_image), bool_mask_pos = obj
183 | trans_image, vqvae_image, bool_mask_pos = (
184 | trans_image.to(self.device),
185 | vqvae_image.to(self.device),
186 | bool_mask_pos.to(self.device),
187 | )
188 |
189 | with torch.no_grad():
190 | input_ids = self.model_vqvae.get_codebook_indices(vqvae_image).flatten(
191 | 1
192 | )
193 | bool_mask_pos = bool_mask_pos.flatten(1).to(torch.bool)
194 | labels = input_ids[bool_mask_pos]
195 |
196 | outputs = self.model(
197 | trans_image, bool_mask_pos, return_all_tokens=False
198 | )
199 | loss = self.criterion(outputs, labels)
200 |
201 | loss = loss.detach().cpu().data
202 | total_loss += loss * trans_image.shape[0]
203 | total_samples += trans_image.shape[0]
204 |
205 | if i % 10 == 0:
206 | self.log.info(
207 | printer(
208 | self.device,
209 | f"Valid: Step {i + 1}/{len(self.valid_dataloader)} | Loss {loss:.4f} ({total_loss / total_samples:.4f})",
210 | )
211 | )
212 |
213 | return total_loss / total_samples
214 |
215 | def save_model(self, epoch: int):
216 | filename = Path(self.exp_dir) / "model" / f"epoch{epoch}_model.pt"
217 | torch.save(self.model.module.state_dict(), filename)
218 | self.log.info(printer(self.device, f"Saving model to {filename}"))
219 | filename = Path(self.exp_dir) / "model" / f"best.pt"
220 | torch.save(self.model.module.state_dict(), filename)
221 |
222 | def load_model(self, path: Union[str, Path]):
223 | self.model.load_state_dict(torch.load(path, map_location="cpu"))
224 | self.log.info(printer(self.device, f"Loading model from {path}"))
225 |
226 | def save_snapshot(self, epoch: int, best_loss: float):
227 | state_info = {
228 | "EPOCH": epoch + 1,
229 | "STEP": self.global_step,
230 | "OPTIMIZER": self.optimizer.state_dict(),
231 | "LR_SCHEDULER": self.lr_scheduler.state_dict(),
232 | "MODEL": self.model.module.state_dict(),
233 | "LOSS": best_loss,
234 | }
235 |
236 | snapshot_path = Path(self.exp_dir) / "snapshot" / f"epoch{epoch}_snapshot.pt"
237 | torch.save(state_info, snapshot_path)
238 |
239 | self.log.info(printer(self.device, f"Saving snapshot to {snapshot_path}"))
240 |
241 | def load_snapshot(self, path: Path):
242 | self.log.info(printer(self.device, f"Loading snapshot from {path}"))
243 | snapshot = torch.load(path, map_location="cpu")
244 | assert SNAPSHOT_KEYS.issubset(snapshot.keys())
245 | return snapshot
246 |
--------------------------------------------------------------------------------
/src/trainer/train_table.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, List, Union, Dict, Optional
2 | import torch
3 | import wandb
4 | import json
5 | import os
6 | from torch import nn, Tensor, autograd
7 | from torch.utils.data import DataLoader
8 | from omegaconf import DictConfig
9 | from hydra.utils import instantiate
10 | import logging
11 | from pathlib import Path
12 | from torch.nn.parallel import DistributedDataParallel as DDP
13 | import tokenizers as tk
14 | import torch.nn.functional as F
15 |
16 | from src.trainer.utils import (
17 | Batch,
18 | configure_optimizer_weight_decay,
19 | turn_off_beit_grad,
20 | VALID_HTML_TOKEN,
21 | INVALID_CELL_TOKEN,
22 | VALID_BBOX_TOKEN,
23 | )
24 | from src.utils import (
25 | printer,
26 | compute_grad_norm,
27 | count_total_parameters,
28 | batch_autoregressive_decode,
29 | combine_filename_pred_gt,
30 | )
31 |
32 | SNAPSHOT_KEYS = set(["EPOCH", "STEP", "OPTIMIZER", "LR_SCHEDULER", "MODEL", "LOSS"])
33 |
34 |
35 | class TableTrainer:
36 | """A trainer for table recognition. The supported tasks are:
37 | 1) table structure extraction
38 | 2) table cell bbox detection
39 | 3) table cell content recognition
40 |
41 | Args:
42 | ----
43 | device: gpu id
44 | vocab: a vocab shared among all tasks
45 | model: model architecture
46 | log: logger
47 | exp_dir: the experiment directory that saves logs, wandb files, model weights, and checkpoints (snapshots)
48 | snapshot: specify which snapshot to use, only used in training
49 | model_weights: specify which model weight to use, only used in testing
50 | beit_pretrained_weights: load SSL pretrained visual encoder
51 | freeze_beit_epoch: freeze beit weights for the first {freeze_beit_epoch} epochs
52 | """
53 |
54 | def __init__(
55 | self,
56 | device: int,
57 | vocab: tk.Tokenizer,
58 | model: nn.Module,
59 | log: logging.Logger,
60 | exp_dir: Path,
61 | snapshot: Path = None,
62 | model_weights: str = None,
63 | beit_pretrained_weights: str = None,
64 | freeze_beit_epoch: int = None,
65 | ) -> None:
66 | self.device = device
67 | self.log = log
68 | self.exp_dir = exp_dir
69 | self.vocab = vocab
70 | self.padding_idx = vocab.token_to_id("")
71 | self.freeze_beit_epoch = freeze_beit_epoch
72 |
73 | # loss for training html, cell
74 | self.criterion = nn.CrossEntropyLoss(ignore_index=self.padding_idx)
75 |
76 | self.model = model
77 |
78 | if (
79 | beit_pretrained_weights is not None
80 | and Path(beit_pretrained_weights).is_file()
81 | ):
82 | self.load_pretrained_beit(Path(beit_pretrained_weights))
83 |
84 | assert (
85 | snapshot is None or model_weights is None
86 | ), "Cannot set snapshot and model_weights at the same time!"
87 |
88 | if snapshot is not None and snapshot.is_file():
89 | self.snapshot = self.load_snapshot(snapshot)
90 | self.model.load_state_dict(self.snapshot["MODEL"])
91 | self.start_epoch = self.snapshot["EPOCH"]
92 | self.global_step = self.snapshot["STEP"]
93 | elif model_weights is not None and Path(model_weights).is_file():
94 | self.load_model(Path(model_weights))
95 | else:
96 | self.snapshot = None
97 | self.start_epoch = 0
98 | self.global_step = 0
99 |
100 | if freeze_beit_epoch and freeze_beit_epoch > 0:
101 | self._freeze_beit()
102 |
103 | self.model = self.model.to(device)
104 | self.model = DDP(self.model, device_ids=[device])
105 |
106 | # https://discuss.pytorch.org/t/extra-10gb-memory-on-gpu-0-in-ddp-tutorial/118113
107 | torch.cuda.set_device(device) # master gpu takes up extra memory
108 | torch.cuda.empty_cache()
109 |
110 | def _freeze_beit(self):
111 | if self.start_epoch < self.freeze_beit_epoch:
112 | turn_off_beit_grad(self.model)
113 | self.log.info(
114 | printer(
115 | self.device,
116 | f"Lock SSL params for {self.freeze_beit_epoch} epochs (params: {count_total_parameters(self.model) / 1e6:.2f}M) - Current epoch {self.start_epoch + 1}",
117 | )
118 | )
119 | else:
120 | self.log.info(
121 | printer(
122 | self.device,
123 | f"Unlock all weights (params: {count_total_parameters(self.model) / 1e6:.2f}M) - Current epoch {self.start_epoch + 1}",
124 | )
125 | )
126 |
127 | def train_epoch(
128 | self,
129 | epoch: int,
130 | target: str,
131 | loss_weights: List[float],
132 | grad_clip: float = None,
133 | ):
134 | avg_loss = 0.0
135 |
136 | # load data from dataloader
137 | for i, obj in enumerate(self.train_dataloader):
138 | batch = Batch(device=self.device, target=target, vocab=self.vocab, obj=obj)
139 |
140 | with autograd.detect_anomaly():
141 | loss, _ = batch.inference(
142 | self.model,
143 | criterion=self.criterion,
144 | criterion_bbox=self.criterion_bbox,
145 | loss_weights=loss_weights,
146 | )
147 |
148 | total_loss = loss["total"]
149 |
150 | self.optimizer.zero_grad()
151 | total_loss.backward()
152 | if grad_clip:
153 | nn.utils.clip_grad_norm_(
154 | self.model.parameters(), max_norm=grad_clip
155 | )
156 | self.optimizer.step()
157 |
158 | total_loss = total_loss.detach().cpu().data
159 | avg_loss += total_loss
160 | self.lr_scheduler.step()
161 | self.global_step += 1
162 |
163 | if i % 10 == 0:
164 | grad_norm = compute_grad_norm(self.model)
165 | lr = self.optimizer.param_groups[0]["lr"]
166 | # elapsed = time.time() - start
167 |
168 | loss_info = f"Loss {total_loss:.3f} ({avg_loss / (i + 1):.3f})"
169 | if not isinstance(loss["html"], int):
170 | loss_info += f" Html {loss['html'].detach().cpu().data:.3f}"
171 | if not isinstance(loss["cell"], int):
172 | loss_info += f" Cell {loss['cell'].detach().cpu().data:.3f}"
173 | if not isinstance(loss["bbox"], int):
174 | loss_info += f" Bbox {loss['bbox'].detach().cpu().data:.3f}"
175 | self.log.info(
176 | printer(
177 | self.device,
178 | f"Epoch {epoch} Step {i + 1}/{len(self.train_dataloader)} | {loss_info} | Grad norm {grad_norm:.3f} | lr {lr:5.1e}",
179 | )
180 | )
181 |
182 | if i % 100 == 0 and self.device == 0:
183 | log_info = {
184 | "epoch": epoch,
185 | "train_total_loss": total_loss,
186 | "learning rate": lr,
187 | "grad_norm": grad_norm,
188 | }
189 |
190 | wandb.log(
191 | log_info,
192 | step=self.global_step,
193 | )
194 |
195 | def train(
196 | self,
197 | train_dataloader: DataLoader,
198 | valid_dataloader: DataLoader,
199 | train_cfg: DictConfig,
200 | valid_cfg: DictConfig,
201 | ):
202 | self.train_dataloader = train_dataloader
203 | self.valid_dataloader = valid_dataloader
204 |
205 | # ensure correct weight decay: https://github.com/karpathy/minGPT/blob/37baab71b9abea1b76ab957409a1cc2fbfba8a26/mingpt/model.py#L215
206 | optim_params = configure_optimizer_weight_decay(
207 | self.model.module, weight_decay=train_cfg.optimizer.weight_decay
208 | )
209 |
210 | self.optimizer = instantiate(train_cfg.optimizer, optim_params)
211 |
212 | self.lr_scheduler = instantiate(
213 | train_cfg.lr_scheduler, optimizer=self.optimizer
214 | )
215 |
216 | if self.snapshot is not None:
217 | self.optimizer.load_state_dict(self.snapshot["OPTIMIZER"])
218 | self.lr_scheduler.load_state_dict(self.snapshot["LR_SCHEDULER"])
219 |
220 | self.criterion_bbox = None
221 | if "bbox" in train_cfg.target:
222 | tmp = [
223 | self.vocab.token_to_id(i)
224 | for i in VALID_BBOX_TOKEN[
225 | : train_cfg.img_size[0] + 2
226 | ] # +1 for +1 for bbox == img_size
227 | ]
228 | tmp = [1.0 if i in tmp else 0.0 for i in range(self.vocab.get_vocab_size())]
229 | self.criterion_bbox = nn.CrossEntropyLoss(
230 | weight=torch.tensor(tmp, device=self.device),
231 | ignore_index=self.padding_idx,
232 | )
233 |
234 | best_loss = float("inf")
235 | self.model.train()
236 |
237 | if self.freeze_beit_epoch and self.start_epoch < self.freeze_beit_epoch:
238 | max_epoch = self.freeze_beit_epoch
239 | else:
240 | max_epoch = train_cfg.epochs
241 | for epoch in range(self.start_epoch, max_epoch):
242 | train_dataloader.sampler.set_epoch(epoch)
243 |
244 | self.train_epoch(
245 | epoch,
246 | grad_clip=train_cfg.grad_clip,
247 | target=train_cfg.target,
248 | loss_weights=train_cfg.loss_weights,
249 | )
250 |
251 | torch.cuda.empty_cache()
252 |
253 | valid_loss = self.valid(valid_cfg)
254 |
255 | if self.device == 0:
256 | wandb.log(
257 | {"valid loss (epoch)": valid_loss},
258 | step=self.global_step,
259 | )
260 |
261 | if epoch % train_cfg.save_every == 0:
262 | self.save_snapshot(epoch, best_loss)
263 | if valid_loss < best_loss:
264 | self.save_model(epoch)
265 | best_loss = valid_loss
266 |
267 | def valid(self, cfg: DictConfig):
268 | total_loss = 0.0
269 | avg_loss = 0.0
270 | total_samples = 0
271 |
272 | self.model.eval()
273 | for i, obj in enumerate(self.valid_dataloader):
274 | batch = Batch(
275 | device=self.device, target=cfg.target, vocab=self.vocab, obj=obj
276 | )
277 | with torch.no_grad():
278 | loss, _ = batch.inference(
279 | self.model,
280 | criterion=self.criterion,
281 | criterion_bbox=self.criterion_bbox,
282 | loss_weights=cfg.loss_weights,
283 | )
284 |
285 | total_loss = loss["total"]
286 | total_loss = total_loss.detach().cpu().data
287 | avg_loss += total_loss * batch.image.shape[0]
288 | total_samples += batch.image.shape[0]
289 |
290 | if i % 10 == 0:
291 | loss_info = f"Loss {total_loss:.3f} ({avg_loss / total_samples:.3f})"
292 | if not isinstance(loss["html"], int):
293 | loss_info += f" Html {loss['html'].detach().cpu().data:.3f}"
294 | if not isinstance(loss["cell"], int):
295 | loss_info += f" Cell {loss['cell'].detach().cpu().data:.3f}"
296 | if not isinstance(loss["bbox"], int):
297 | loss_info += f" Bbox {loss['bbox'].detach().cpu().data:.3f}"
298 | self.log.info(
299 | printer(
300 | self.device,
301 | f"Valid: Step {i + 1}/{len(self.valid_dataloader)} | {loss_info}",
302 | )
303 | )
304 |
305 | return avg_loss / total_samples
306 |
307 | def test(self, test_dataloader: DataLoader, cfg: DictConfig, save_to: str):
308 | total_result = dict()
309 | for i, obj in enumerate(test_dataloader):
310 | batch = Batch(
311 | device=self.device, target=cfg.target, vocab=self.vocab, obj=obj
312 | )
313 |
314 | if cfg.target == "html":
315 | prefix = [self.vocab.token_to_id("[html]")]
316 | valid_token_whitelist = [
317 | self.vocab.token_to_id(i) for i in VALID_HTML_TOKEN
318 | ]
319 | valid_token_blacklist = None
320 | elif cfg.target == "cell":
321 | prefix = [self.vocab.token_to_id("[cell]")]
322 | valid_token_whitelist = None
323 | valid_token_blacklist = [
324 | self.vocab.token_to_id(i) for i in INVALID_CELL_TOKEN
325 | ]
326 | elif cfg.target == "bbox":
327 | prefix = [self.vocab.token_to_id("[bbox]")]
328 | valid_token_whitelist = [
329 | self.vocab.token_to_id(i)
330 | for i in VALID_BBOX_TOKEN[: cfg.img_size[0]]
331 | ]
332 | valid_token_blacklist = None
333 | else:
334 | raise NotImplementedError
335 |
336 | pred_id = batch_autoregressive_decode(
337 | device=self.device,
338 | model=self.model,
339 | batch_data=batch,
340 | prefix=prefix,
341 | max_decode_len=cfg.max_seq_len,
342 | eos_id=self.vocab.token_to_id(""),
343 | valid_token_whitelist=valid_token_whitelist,
344 | valid_token_blacklist=valid_token_blacklist,
345 | sampling=cfg.sampling,
346 | )
347 |
348 | if cfg.target == "html":
349 | result = combine_filename_pred_gt(
350 | filename=batch.name,
351 | pred_id=pred_id,
352 | gt_id=batch.html_tgt,
353 | vocab=self.vocab,
354 | type="html",
355 | )
356 | elif cfg.target == "cell":
357 | result = combine_filename_pred_gt(
358 | filename=batch.name,
359 | pred_id=pred_id,
360 | gt_id=batch.cell_tgt,
361 | vocab=self.vocab,
362 | type="cell",
363 | )
364 | elif cfg.target == "bbox":
365 | result = combine_filename_pred_gt(
366 | filename=batch.name,
367 | pred_id=pred_id,
368 | gt_id=batch.bbox_tgt,
369 | vocab=self.vocab,
370 | type="bbox",
371 | )
372 | else:
373 | raise NotImplementedError
374 |
375 | total_result.update(result)
376 |
377 | if i % 10 == 0:
378 | self.log.info(
379 | printer(
380 | self.device,
381 | f"Test: Step {i + 1}/{len(test_dataloader)}",
382 | )
383 | )
384 |
385 | self.log.info(
386 | printer(
387 | self.device,
388 | f"Converting {len(total_result)} samples to html tables ...",
389 | )
390 | )
391 |
392 | with open(
393 | os.path.join(save_to, cfg.save_to_prefix + f"_{self.device}.json"),
394 | "w",
395 | encoding="utf-8",
396 | ) as f:
397 | json.dump(total_result, f, indent=4)
398 |
399 | return total_result
400 |
401 | def save_model(self, epoch: int):
402 | filename = Path(self.exp_dir) / "model" / f"epoch{epoch}_model.pt"
403 | torch.save(self.model.module.state_dict(), filename)
404 | self.log.info(printer(self.device, f"Saving model to {filename}"))
405 | filename = Path(self.exp_dir) / "model" / "best.pt"
406 | torch.save(self.model.module.state_dict(), filename)
407 |
408 | def load_model(self, path: Union[str, Path]):
409 | self.model.load_state_dict(torch.load(path, map_location="cpu"))
410 | self.log.info(printer(self.device, f"Loading model from {path}"))
411 |
412 | def save_snapshot(self, epoch: int, best_loss: float):
413 | state_info = {
414 | "EPOCH": epoch + 1,
415 | "STEP": self.global_step,
416 | "OPTIMIZER": self.optimizer.state_dict(),
417 | "LR_SCHEDULER": self.lr_scheduler.state_dict(),
418 | "MODEL": self.model.module.state_dict(),
419 | "LOSS": best_loss,
420 | }
421 |
422 | snapshot_path = Path(self.exp_dir) / "snapshot" / f"epoch{epoch}_snapshot.pt"
423 | torch.save(state_info, snapshot_path)
424 |
425 | self.log.info(printer(self.device, f"Saving snapshot to {snapshot_path}"))
426 |
427 | def load_snapshot(self, path: Path):
428 | self.log.info(printer(self.device, f"Loading snapshot from {path}"))
429 | snapshot = torch.load(path, map_location="cpu")
430 | assert SNAPSHOT_KEYS.issubset(snapshot.keys())
431 | return snapshot
432 |
433 | def load_pretrained_beit(self, path: Path):
434 | self.log.info(printer(self.device, f"Loading pretrained BEiT from {path}"))
435 | beit = torch.load(path, map_location="cpu")
436 | redundant_keys_in_beit = [
437 | "cls_token",
438 | "mask_token",
439 | "generator.weight",
440 | "generator.bias",
441 | ]
442 | for key in redundant_keys_in_beit:
443 | if key in beit:
444 | del beit[key]
445 |
446 | # max_seq_len in finetuning may go beyond the length in pretraining
447 | if (
448 | self.model.pos_embed.embedding.weight.shape[0]
449 | != beit["pos_embed.embedding.weight"].shape[0]
450 | ):
451 | emb_shape = self.model.pos_embed.embedding.weight.shape
452 | ckpt_emb = beit["pos_embed.embedding.weight"].clone()
453 | assert emb_shape[1] == ckpt_emb.shape[1]
454 |
455 | ckpt_emb = ckpt_emb.unsqueeze(0).permute(0, 2, 1)
456 | ckpt_emb = F.interpolate(ckpt_emb, emb_shape[0], mode="nearest")
457 | beit["pos_embed.embedding.weight"] = ckpt_emb.permute(0, 2, 1).squeeze()
458 |
459 | out = self.model.load_state_dict(beit, strict=False)
460 |
461 | # ensure missing keys are just token_embed, decoder, and generator
462 | missing_keys_prefix = ("token_embed", "decoder", "generator")
463 | for key in out[0]:
464 | assert key.startswith(
465 | missing_keys_prefix
466 | ), f"Key {key} should be loaded from BEiT, but missing in current state dict."
467 | assert len(out[1]) == 0, f"Unexpected keys from BEiT: {out[1]}"
468 |
--------------------------------------------------------------------------------
/src/trainer/train_vqvae.py:
--------------------------------------------------------------------------------
1 | import math
2 | import wandb
3 | from pathlib import Path
4 | from typing import Tuple, List, Union, Dict
5 | from omegaconf import DictConfig
6 | from hydra.utils import instantiate
7 | import logging
8 | import torch
9 | import time
10 | from functools import partial
11 | from torch import nn, Tensor, autograd
12 | from torch.utils.data import DataLoader
13 | from torch.optim import Adam
14 | from torch.nn.parallel import DistributedDataParallel as DDP
15 | import torch.distributed as dist
16 | from torchvision.utils import make_grid
17 |
18 | from src.utils import printer, compute_grad_norm
19 |
20 | SNAPSHOT_KEYS = set(["EPOCH", "STEP", "OPTIMIZER", "LR_SCHEDULER", "MODEL", "LOSS"])
21 |
22 |
23 | class VqvaeTrainer:
24 | def __init__(
25 | self,
26 | device: int,
27 | model: nn.Module,
28 | log: logging.Logger,
29 | exp_dir: Path,
30 | snapshot: Path = None,
31 | model_weights: Path = None, # only for testing
32 | ) -> None:
33 | self.device = device
34 | self.log = log
35 | self.exp_dir = exp_dir
36 | assert (
37 | snapshot is None or model_weights is None
38 | ), "Snapshot and model weights cannot be set at the same time."
39 |
40 | self.model = model
41 | if snapshot is not None and snapshot.is_file():
42 | self.snapshot = self.load_snapshot(snapshot)
43 | self.model.load_state_dict(self.snapshot["MODEL"])
44 | self.start_epoch = self.snapshot["EPOCH"]
45 | self.global_step = self.snapshot["STEP"]
46 | elif model_weights is not None and model_weights.is_file():
47 | self.load_model(model_weights)
48 | else:
49 | self.snapshot = None
50 | self.start_epoch = 0
51 |
52 | self.model = self.model.to(device)
53 | self.model = DDP(self.model, device_ids=[device])
54 |
55 | # https://discuss.pytorch.org/t/extra-10gb-memory-on-gpu-0-in-ddp-tutorial/118113
56 | torch.cuda.set_device(device) # master gpu takes up extra memory
57 | torch.cuda.empty_cache()
58 |
59 | def train_epoch(
60 | self,
61 | epoch: int,
62 | starting_temp: float,
63 | anneal_rate: float,
64 | temp_min: float,
65 | grad_clip: float = None,
66 | ):
67 | start = time.time()
68 | total_loss = 0.0
69 | total_samples = 0
70 |
71 | # load data from dataloader
72 | for i, obj in enumerate(self.train_dataloader):
73 | if isinstance(obj, Tensor):
74 | img = obj.to(self.device)
75 | elif isinstance(obj, (list, tuple)):
76 | img = obj[0].to(self.device)
77 | else:
78 | raise ValueError(f"Unrecognized object type {type(obj)}")
79 |
80 | # temperature annealing
81 | self.temp = max(
82 | starting_temp * math.exp(-anneal_rate * self.global_step), temp_min
83 | )
84 |
85 | with autograd.detect_anomaly():
86 | loss, soft_recons = self.model(
87 | img, return_loss=True, return_recons=True, temp=self.temp
88 | )
89 |
90 | self.optimizer.zero_grad()
91 | loss.backward()
92 | if grad_clip:
93 | nn.utils.clip_grad_norm_(
94 | self.model.parameters(), max_norm=grad_clip
95 | )
96 | self.optimizer.step()
97 |
98 | loss = loss.detach().cpu().data
99 | total_loss += loss * img.shape[0]
100 | total_samples += img.shape[0]
101 |
102 | self.lr_scheduler.step()
103 | self.global_step += 1
104 |
105 | if i % 10 == 0:
106 | grad_norm = compute_grad_norm(self.model)
107 | lr = self.optimizer.param_groups[0]["lr"]
108 | elapsed = time.time() - start
109 | self.log.info(
110 | printer(
111 | self.device,
112 | f"Epoch {epoch} Step {i + 1}/{len(self.train_dataloader)} | Loss {loss:.4f} ({total_loss / total_samples:.4f}) | Grad norm {grad_norm:.3f} | {total_samples / elapsed:4.1f} images/s | lr {lr:5.1e} | Temp {self.temp:.2e}",
113 | )
114 | )
115 |
116 | # visualize reconstruction images
117 | if i % 100 == 0 and self.device == 0:
118 | lr = self.optimizer.param_groups[0]["lr"]
119 | k = 4 # num of images saved for visualization
120 | codes = self.model.module.get_codebook_indices(img[:k])
121 | hard_recons = self.model.module.decode(codes)
122 |
123 | img = img[:k].detach().cpu()
124 | soft_recons = soft_recons[:k].detach().cpu()
125 | codes = codes.flatten(start_dim=1).detach().cpu()
126 | hard_recons = hard_recons.detach().cpu()
127 |
128 | make_vis = partial(make_grid, nrow=int(math.sqrt(k)), normalize=True)
129 | img, soft_recons, hard_recons = map(
130 | make_vis, (img, soft_recons, hard_recons)
131 | )
132 |
133 | log_info = {
134 | "epoch": epoch,
135 | "train_loss": loss,
136 | "temperature": self.temp,
137 | "learning rate": lr,
138 | "original images": wandb.Image(
139 | img, caption=f"step: {self.global_step}"
140 | ),
141 | "soft reconstruction": wandb.Image(
142 | soft_recons, caption=f"step: {self.global_step}"
143 | ),
144 | "hard reconstruction": wandb.Image(
145 | hard_recons, caption=f"step: {self.global_step}"
146 | ),
147 | "codebook_indices": wandb.Histogram(codes),
148 | }
149 |
150 | wandb.log(
151 | log_info,
152 | step=self.global_step,
153 | )
154 |
155 | return total_loss, total_samples
156 |
157 | def train(
158 | self,
159 | train_dataloader: DataLoader,
160 | valid_dataloader: DataLoader,
161 | train_cfg: DictConfig,
162 | valid_cfg: DictConfig,
163 | ):
164 | self.train_dataloader = train_dataloader
165 | self.valid_dataloader = valid_dataloader
166 | self.optimizer = instantiate(
167 | train_cfg.optimizer, params=self.model.parameters()
168 | )
169 |
170 | self.lr_scheduler = instantiate(
171 | train_cfg.lr_scheduler, optimizer=self.optimizer
172 | )
173 |
174 | if self.snapshot is not None:
175 | self.optimizer.load_state_dict(self.snapshot["OPTIMIZER"])
176 | self.lr_scheduler.load_state_dict(self.snapshot["LR_SCHEDULER"])
177 |
178 | best_loss = float("inf")
179 | self.model.train()
180 | self.global_step = 0
181 | # self.temp = train_cfg.starting_temp
182 | for epoch in range(self.start_epoch, train_cfg.epochs):
183 | train_dataloader.sampler.set_epoch(epoch)
184 | epoch_loss, epoch_samples = self.train_epoch(
185 | epoch,
186 | starting_temp=train_cfg.starting_temp,
187 | anneal_rate=train_cfg.temp_anneal_rate,
188 | temp_min=train_cfg.temp_min,
189 | grad_clip=train_cfg.grad_clip,
190 | )
191 |
192 | torch.cuda.empty_cache()
193 |
194 | valid_loss, valid_samples = self.valid(valid_cfg)
195 |
196 | # reduce loss to gpu 0
197 | training_info = torch.tensor(
198 | [epoch_loss, epoch_samples, valid_loss, valid_samples],
199 | device=self.device,
200 | )
201 |
202 | dist.reduce(
203 | training_info,
204 | dst=0,
205 | op=dist.ReduceOp.SUM,
206 | )
207 |
208 | if self.device == 0:
209 | grad_norm = compute_grad_norm(self.model)
210 | epoch_loss, epoch_samples, valid_loss, valid_samples = training_info
211 | epoch_loss, valid_loss = (
212 | float(epoch_loss) / epoch_samples,
213 | float(valid_loss) / valid_samples,
214 | )
215 |
216 | log_info = {
217 | "train loss (epoch)": epoch_loss,
218 | "valid loss (epoch)": valid_loss,
219 | "train_samples": epoch_samples,
220 | "valid_samples": valid_samples,
221 | "grad_norm": grad_norm,
222 | }
223 |
224 | wandb.log(
225 | log_info,
226 | step=self.global_step,
227 | )
228 |
229 | if epoch % train_cfg.save_every == 0:
230 | self.save_snapshot(epoch, best_loss)
231 | if valid_loss < best_loss:
232 | self.save_model(epoch)
233 | best_loss = valid_loss
234 |
235 | def valid(self, cfg: DictConfig):
236 | total_samples = 0
237 | total_loss = 0.0
238 |
239 | self.model.eval()
240 | for i, obj in enumerate(self.valid_dataloader):
241 | if isinstance(obj, Tensor):
242 | img = obj.to(self.device)
243 | elif isinstance(obj, (list, tuple)):
244 | img = obj[0].to(self.device)
245 | else:
246 | raise ValueError(f"Unrecognized object type {type(obj)}")
247 |
248 | with torch.no_grad():
249 | loss = self.model(
250 | img, return_loss=True, return_recons=False, temp=self.temp
251 | )
252 |
253 | loss = loss.detach().cpu().data
254 | total_loss += loss * img.shape[0]
255 | total_samples += img.shape[0]
256 |
257 | if i % 10 == 0:
258 | self.log.info(
259 | printer(
260 | self.device,
261 | f"Valid: Step {i + 1}/{len(self.valid_dataloader)} | Loss {loss:.4f} ({total_loss / total_samples:.4f})",
262 | )
263 | )
264 |
265 | return total_loss, total_samples
266 |
267 | def save_model(self, epoch: int):
268 | filename = Path(self.exp_dir) / "model" / f"epoch{epoch}_model.pt"
269 | torch.save(self.model.module.state_dict(), filename)
270 | self.log.info(printer(self.device, f"Saving model to {filename}"))
271 | filename = Path(self.exp_dir) / "model" / f"best.pt"
272 | torch.save(self.model.module.state_dict(), filename)
273 |
274 | def load_model(self, path: Union[str, Path]):
275 | self.model.load_state_dict(torch.load(path, map_location="cpu"))
276 | self.log.info(printer(self.device, f"Loading model from {path}"))
277 |
278 | def save_snapshot(self, epoch: int, best_loss: float):
279 | state_info = {
280 | "EPOCH": epoch + 1,
281 | "STEP": self.global_step,
282 | "OPTIMIZER": self.optimizer.state_dict(),
283 | "LR_SCHEDULER": self.lr_scheduler.state_dict(),
284 | "MODEL": self.model.module.state_dict(),
285 | "LOSS": best_loss,
286 | }
287 |
288 | snapshot_path = Path(self.exp_dir) / "snapshot" / f"epoch{epoch}_snapshot.pt"
289 | torch.save(state_info, snapshot_path)
290 |
291 | self.log.info(printer(self.device, f"Saving snapshot to {snapshot_path}"))
292 |
293 | def load_snapshot(self, path: Path):
294 | self.log.info(printer(self.device, f"Loading snapshot from {path}"))
295 | snapshot = torch.load(path, map_location="cpu")
296 | assert SNAPSHOT_KEYS.issubset(snapshot.keys())
297 | return snapshot
298 |
--------------------------------------------------------------------------------
/src/trainer/utils.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple, Dict
2 | import torch
3 | from torch import Tensor, nn
4 | from torchtext.vocab import Vocab
5 | import tokenizers as tk
6 |
7 | from src.utils import pred_token_within_range, subsequent_mask
8 | from src.vocab import (
9 | HTML_TOKENS,
10 | TASK_TOKENS,
11 | RESERVED_TOKENS,
12 | BBOX_TOKENS,
13 | )
14 |
15 |
16 | VALID_HTML_TOKEN = [""] + HTML_TOKENS
17 | INVALID_CELL_TOKEN = (
18 | ["", "", "", ""] + TASK_TOKENS + RESERVED_TOKENS
19 | )
20 | VALID_BBOX_TOKEN = [
21 | ""
22 | ] + BBOX_TOKENS # image size will be addressed after instantiation
23 |
24 |
25 | class Batch:
26 | """Wrap up a batch of training samples with different training targets.
27 | The input is not torch tensor
28 | Shape of the image (src): B, S, E
29 | Shape of the text (tgt): B, N, S, E (M includes 1 table detection, 1 structure, 1 cell, and multiple bbox)
30 | Reshape text to (B * N, S, E) and inflate the image to match the shape of the text
31 |
32 | Args:
33 | ----
34 | device: gpu id
35 | """
36 |
37 | def __init__(
38 | self,
39 | device: torch.device,
40 | target: str,
41 | vocab: Vocab,
42 | obj: List,
43 | ) -> None:
44 | self.device = device
45 | self.image = obj[0].to(device)
46 | self.name = obj[1]["filename"]
47 | self.target = target
48 | self.vocab = vocab
49 | self.image_size = self.image.shape[-1]
50 |
51 | if "table" in target:
52 | raise NotImplementedError
53 |
54 | if "html" in target:
55 | self.valid_html_token = [vocab.token_to_id(i) for i in VALID_HTML_TOKEN]
56 | (
57 | self.html_src,
58 | self.html_tgt,
59 | self.html_casual_mask,
60 | self.html_padding_mask,
61 | ) = self._prepare_transformer_input(obj[1]["html"])
62 |
63 | if "cell" in target:
64 | self.invalid_cell_token = [vocab.token_to_id(i) for i in INVALID_CELL_TOKEN]
65 | (
66 | self.cell_src,
67 | self.cell_tgt,
68 | self.cell_casual_mask,
69 | self.cell_padding_mask,
70 | ) = self._prepare_transformer_input(obj[1]["cell"])
71 |
72 | if "bbox" in target:
73 | (
74 | self.bbox_src,
75 | self.bbox_tgt,
76 | self.bbox_casual_mask,
77 | self.bbox_padding_mask,
78 | ) = self._prepare_transformer_input(obj[1]["bbox"])
79 |
80 | def _prepare_transformer_input(
81 | self, seq: List[tk.Encoding]
82 | ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
83 | tmp = [i.ids for i in seq]
84 | tmp = torch.tensor(tmp, dtype=torch.int32)
85 | src = tmp[:, :-1].to(self.device)
86 | tgt = tmp[:, 1:].type(torch.LongTensor).to(self.device)
87 | casual_mask = subsequent_mask(src.shape[-1]).to(self.device)
88 | tmp = [i.attention_mask[:-1] for i in seq] # padding mask
89 | tmp = torch.tensor(tmp, dtype=torch.bool)
90 | padding_mask = (~tmp).to(self.device)
91 |
92 | return src, tgt, casual_mask, padding_mask
93 |
94 | def _inference_one_task(
95 | self, model, memory, src, casual_mask, padding_mask, use_ddp
96 | ):
97 | if use_ddp:
98 | out = model.module.decode(memory, src, casual_mask, padding_mask)
99 | out = model.module.generator(out)
100 | else:
101 | out = model.decode(memory, src, casual_mask, padding_mask)
102 | out = model.generator(out)
103 |
104 | return out
105 |
106 | def inference(
107 | self,
108 | model: nn.Module,
109 | criterion: nn.Module,
110 | criterion_bbox: nn.Module = None,
111 | loss_weights: dict = None,
112 | use_ddp: bool = True,
113 | ) -> Tuple[Dict, Dict]:
114 | pred = dict()
115 | loss = dict(table=0, html=0, cell=0, bbox=0)
116 |
117 | if use_ddp:
118 | memory = model.module.encode(self.image)
119 | else:
120 | memory = model.encode(self.image)
121 |
122 | # inference + suppress invalid logits + compute loss
123 | if "html" in self.target:
124 | out_html = self._inference_one_task(
125 | model,
126 | memory,
127 | self.html_src,
128 | self.html_casual_mask,
129 | self.html_padding_mask,
130 | use_ddp,
131 | )
132 |
133 | pred["html"] = pred_token_within_range(
134 | out_html, white_list=self.valid_html_token
135 | ).permute(0, 2, 1)
136 | loss["html"] = criterion(pred["html"], self.html_tgt)
137 |
138 | if "cell" in self.target:
139 | out_cell = self._inference_one_task(
140 | model,
141 | memory,
142 | self.cell_src,
143 | self.cell_casual_mask,
144 | self.cell_padding_mask,
145 | use_ddp,
146 | )
147 |
148 | pred["cell"] = pred_token_within_range(
149 | out_cell, black_list=self.invalid_cell_token
150 | ).permute(0, 2, 1)
151 | loss["cell"] = criterion(pred["cell"], self.cell_tgt)
152 |
153 | if "bbox" in self.target:
154 | assert criterion_bbox is not None
155 |
156 | out_bbox = self._inference_one_task(
157 | model,
158 | memory,
159 | self.bbox_src,
160 | self.bbox_casual_mask,
161 | self.bbox_padding_mask,
162 | use_ddp,
163 | )
164 | pred["bbox"] = out_bbox.permute(0, 2, 1)
165 | loss["bbox"] = criterion_bbox(pred["bbox"], self.bbox_tgt)
166 |
167 | total = 0.0
168 | for k, v in loss_weights.items():
169 | total += loss[k] * v
170 | loss["total"] = total
171 |
172 | return loss, pred
173 |
174 |
175 | def configure_optimizer_weight_decay(
176 | model: nn.Module, weight_decay: float
177 | ) -> List[Dict]:
178 | weight_decay_blacklist = (nn.LayerNorm, nn.BatchNorm2d, nn.Embedding)
179 |
180 | if hasattr(model, "no_weight_decay"):
181 | skip_list = model.no_weight_decay()
182 | decay = set()
183 | no_decay = set()
184 | for mn, m in model.named_modules():
185 | for pn, p in m.named_parameters():
186 | fpn = "%s.%s" % (mn, pn) if mn else pn # full param name
187 | if pn.endswith("bias"):
188 | no_decay.add(fpn)
189 | elif pn.endswith("weight") and isinstance(m, weight_decay_blacklist):
190 | no_decay.add(fpn)
191 | elif pn in skip_list:
192 | no_decay.add(fpn)
193 |
194 | param_dict = {pn: p for pn, p in model.named_parameters()}
195 | decay = param_dict.keys() - no_decay
196 |
197 | optim_groups = [
198 | {
199 | "params": [param_dict[pn] for pn in sorted(list(decay))],
200 | "weight_decay": weight_decay,
201 | },
202 | {
203 | "params": [param_dict[pn] for pn in sorted(list(no_decay))],
204 | "weight_decay": 0.0,
205 | },
206 | ]
207 |
208 | return optim_groups
209 |
210 |
211 | def turn_off_beit_grad(model: nn.Module):
212 | "Freeze BEiT pretrained weights."
213 | for param in model.encoder.parameters():
214 | param.requires_grad = False
215 |
216 | for param in model.backbone.parameters():
217 | param.requires_grad = False
218 |
219 | for param in model.pos_embed.parameters():
220 | param.requires_grad = False
221 |
222 |
223 | def turn_on_beit_grad(model: nn.Module):
224 | for param in model.parameters():
225 | param.requires_grad = True
226 |
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .visualization import *
2 | from .data import *
3 | from .mask_generator import *
4 | from .misc import *
5 |
--------------------------------------------------------------------------------
/src/utils/coco_map.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchmetrics.detection import MeanAveragePrecision
3 | from pprint import pprint
4 |
5 |
6 | def compute_coco_map(file):
7 | coco_pred = list()
8 | coco_gt = list()
9 | for _, obj in file.items():
10 | tmp_pred = {
11 | "boxes": torch.tensor(obj["pred"], device=0),
12 | "labels": torch.tensor([0] * len(obj["pred"]), device=0),
13 | "scores": torch.tensor([0.999] * len(obj["pred"]), device=0),
14 | }
15 |
16 | tmp_gt = {
17 | "boxes": torch.tensor(obj["gt"], device=0),
18 | "labels": torch.tensor([0] * len(obj["gt"]), device=0),
19 | }
20 |
21 | coco_pred.append(tmp_pred)
22 | coco_gt.append(tmp_gt)
23 |
24 | metric = MeanAveragePrecision(
25 | iou_type="bbox",
26 | max_detection_thresholds=[1, 10, 1000],
27 | backend="faster_coco_eval",
28 | )
29 | metric.update(coco_pred, coco_gt)
30 | pprint(metric.compute())
31 |
32 |
33 | if __name__ == "__main__":
34 | import json
35 | import argparse
36 |
37 | parser = argparse.ArgumentParser(description="mAP Computation")
38 |
39 | parser.add_argument("-f", "--file", help="path to html table results in json file")
40 | args = parser.parse_args()
41 |
42 |
43 | results_file = args.file
44 | with open(results_file, "r") as f:
45 | results_json = json.load(f)
46 |
47 | compute_coco_map(results_json)
48 |
--------------------------------------------------------------------------------
/src/utils/data.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 | import random
3 | import tokenizers as tk
4 | import torch
5 | from torch import Tensor, nn
6 | import torch.nn.functional as F
7 |
8 | from src.vocab import TASK_TOKENS, CELL_SPECIAL
9 | from src.model.encoderdecoder import EncoderDecoder
10 | from .misc import html_table_template
11 |
12 | __all__ = [
13 | "subsequent_mask",
14 | "combine_cell_char_seq",
15 | "random_continuous_sequence",
16 | "prepare_html_seq",
17 | "prepare_cell_seq",
18 | "prepare_bbox_seq",
19 | "html_str_to_token_list",
20 | "cell_str_to_token_list",
21 | "bbox_str_to_token_list",
22 | "build_table_from_html_and_cell",
23 | "pred_token_within_range",
24 | "batch_autoregressive_decode",
25 | "greedy_sampling",
26 | "combine_filename_pred_gt",
27 | ]
28 |
29 |
30 | def subsequent_mask(size: int, pad: int = 0):
31 | attn_shape = (size, size)
32 | output = torch.triu(torch.ones(attn_shape), diagonal=1).to(torch.bool)
33 | if pad and pad > 0:
34 | output[:pad] = False
35 | return output
36 |
37 |
38 | def combine_cell_char_seq(seq: List[str]) -> str:
39 | """Replace empty token with in vocab. combine characters into a str"""
40 | if seq:
41 | out = "".join(seq)
42 | else:
43 | out = ""
44 | return out
45 |
46 |
47 | def prepare_html_seq(seq: List[str]) -> List[str]:
48 | """Convert html annotations to html training template."""
49 | out = ["[html]", *seq, ""]
50 | return out
51 |
52 |
53 | def prepare_cell_seq(seq: str) -> List[str]:
54 | """Convert cell sequence to training template."""
55 | for black in CELL_SPECIAL:
56 | seq = seq.replace(black, "")
57 | out = ["[cell]", seq, ""]
58 |
59 | return out
60 |
61 |
62 | def prepare_bbox_seq(seq: List[dict]):
63 | tmp = [f"bbox-{round(i)}" for i in seq]
64 | out = ["[bbox]"] + tmp + [""]
65 |
66 | return out
67 |
68 |
69 | def random_continuous_sequence(seq: List, N: int, length: int = 10) -> List:
70 | """Randomly sample a continuous sub-sequence from a sequence for N times."""
71 | start_idx = [random.randrange(len(seq)) for _ in range(N)]
72 | subseq_len = [random.randrange(1, length) for _ in range(N)]
73 | output = [(i, min(i + j, len(seq))) for i, j in zip(start_idx, subseq_len)]
74 |
75 | return output
76 |
77 |
78 | # def prepare_bbox_seq(
79 | # seq: List[dict],
80 | # N: int,
81 | # delimiter: str = "",
82 | # ) -> List[List[str]]:
83 | # """Convert the annotation to bbox input/output sequence."""
84 | # out = list()
85 | # # bbox_loss_start_idx = list()
86 |
87 | # subseq_idx = random_continuous_sequence(seq, N)
88 |
89 | # for idx in subseq_idx:
90 | # entry = seq[idx[0] : idx[1]]
91 | # tmp = list()
92 | # bbox_seq = list()
93 | # for i in entry:
94 | # if "tokens" in i.keys():
95 | # # pubtabnet and synthtabnet
96 | # tmp.append(combine_cell_char_seq(i["tokens"]))
97 | # if "bbox" in i.keys():
98 | # bbox_seq.extend([f"bbox-{round(j)}" for j in i["bbox"]])
99 | # elif "text" in i.keys():
100 | # # pubtables and icdar
101 | # tmp.append(i["text"])
102 | # if "bbox" in i.keys():
103 | # bbox_seq.extend([f"bbox-{round(j)}" for j in i["bbox"]])
104 |
105 | # cell_seq = [delimiter] * len(tmp)
106 | # cell_seq = [q for pair in zip(tmp, cell_seq) for q in pair]
107 | # cell_seq = ["[bbox]", f"{len(entry)}-cell(s)", delimiter] + cell_seq
108 |
109 | # bbox_seq.append("")
110 | # # bbox_loss_start_idx.append(len(cell_seq))
111 | # out.append(cell_seq + bbox_seq)
112 |
113 | # return out
114 |
115 |
116 | def html_str_to_token_list(
117 | seq: str, splitter: tk.pre_tokenizers.PreTokenizer = None
118 | ) -> List[str]:
119 | """Convert decode output (str) to a list of tokens for constructing html table code"""
120 |
121 | # works for no
122 | seq = seq.split("")[0]
123 |
124 | token_black_list = ["", "", *TASK_TOKENS]
125 | for i in token_black_list:
126 | seq = seq.replace(i, "")
127 |
128 | if not splitter:
129 | splitter = tk.pre_tokenizers.Split(pattern=" ", behavior="contiguous")
130 |
131 | seq = splitter.pre_tokenize_str(seq)
132 | # only preserve the space for spanning cell tokens
133 | seq = [i[0] for i in seq if len(i[0].strip()) != 0 or i[1][1] - i[1][0] != 1]
134 |
135 | return seq
136 |
137 |
138 | def cell_str_to_token_list(seq: str) -> List[str]:
139 | seq = seq.split("")[0]
140 |
141 | token_black_list = ["", "", *TASK_TOKENS]
142 | for i in token_black_list:
143 | seq = seq.replace(i, "")
144 |
145 | seq = seq.strip()
146 |
147 | return seq
148 |
149 |
150 | def build_table_from_html_and_cell(
151 | structure: List[str], content: List[str] = None
152 | ) -> List[str]:
153 | """Build table from html and cell token list"""
154 | assert structure is not None
155 | html_code = list()
156 |
157 | # deal with empty table
158 | if content is None:
159 | content = ["placeholder"] * len(structure)
160 |
161 | for tag in structure:
162 | if tag in ("[] | ", ">[]"):
163 | if len(content) == 0:
164 | continue
165 | cell = content.pop(0)
166 | html_code.append(tag.replace("[]", cell))
167 | else:
168 | html_code.append(tag)
169 |
170 | return html_code
171 |
172 |
173 | def bbox_str_to_token_list(
174 | seq: str, splitter: tk.pre_tokenizers.PreTokenizer = None
175 | ) -> List[List[int]]:
176 | """
177 | Note the out could be an empty list
178 |
179 | return
180 | [[ymin, xmin, ymax, xmax],
181 | [ymin, xmin, ymax, xmax],
182 | ...
183 | ]
184 | """
185 |
186 | seq = seq.split("")[0]
187 |
188 | token_black_list = ["", "", *TASK_TOKENS]
189 | for i in token_black_list:
190 | seq = seq.replace(i, "")
191 |
192 | if not splitter:
193 | splitter = tk.pre_tokenizers.Split(pattern=" ", behavior="removed")
194 |
195 | seq = splitter.pre_tokenize_str(seq)
196 | seq = [int(i[0].split("-")[1]) for i in seq]
197 |
198 | rounded_seq_len = len(seq) // 4 * 4
199 | out = [seq[i : i + 4] for i in range(0, rounded_seq_len, 4)]
200 | return out
201 |
202 |
203 | def pred_token_within_range(
204 | pred: Tensor,
205 | white_list: List[int] = None,
206 | black_list: List[int] = None,
207 | ) -> Tensor:
208 | assert white_list is None or black_list is None
209 | if white_list:
210 | total = set([i for i in range(pred.shape[-1])])
211 | black_list = list(total.difference(set(white_list)))
212 |
213 | pred[..., black_list] = -float("inf")
214 |
215 | return pred
216 |
217 |
218 | def greedy_sampling(logits: Tensor):
219 | """logits should have shape [B, |V|]."""
220 | probs = F.softmax(logits, dim=-1)
221 | next_probs, next_tokens = probs.topk(1)
222 |
223 | return next_probs, next_tokens
224 |
225 |
226 | def batch_autoregressive_decode(
227 | device: int,
228 | model: EncoderDecoder,
229 | batch_data,
230 | prefix: List[int],
231 | max_decode_len: int,
232 | eos_id: int,
233 | valid_token_whitelist: List[int] = None,
234 | valid_token_blacklist: List[int] = None,
235 | sampling: str = "greedy",
236 | use_ddp: bool = True,
237 | ) -> Tensor:
238 | """Auto-regressively generate the output."""
239 |
240 | model.eval()
241 | with torch.no_grad():
242 | if use_ddp:
243 | memory = model.module.encode(batch_data.image)
244 | else:
245 | memory = model.encode(batch_data.image)
246 |
247 | B = batch_data.image.shape[0]
248 |
249 | context = torch.tensor(prefix, dtype=torch.int32).repeat(B, 1).to(device)
250 |
251 | for _ in range(max_decode_len):
252 | eos_flag = [eos_id in k for k in context]
253 | if all(eos_flag):
254 | break
255 |
256 | # as long as one sample hasn't reached , continue decoding until the max seq len
257 | causal_mask = subsequent_mask(context.shape[1]).to(device)
258 |
259 | with torch.no_grad():
260 | if use_ddp:
261 | logits = model.module.decode(
262 | memory, context, tgt_mask=causal_mask, tgt_padding_mask=None
263 | )
264 | logits = model.module.generator(logits)[:, -1, :]
265 | else:
266 | logits = model.decode(
267 | memory, context, tgt_mask=causal_mask, tgt_padding_mask=None
268 | )
269 | logits = model.generator(logits)[:, -1, :]
270 |
271 | logits = pred_token_within_range(
272 | logits.detach(),
273 | white_list=valid_token_whitelist if valid_token_whitelist else None,
274 | black_list=valid_token_blacklist if valid_token_blacklist else None,
275 | )
276 |
277 | if sampling == "greedy":
278 | next_probs, next_tokens = greedy_sampling(logits)
279 | else:
280 | raise NotImplementedError
281 |
282 | context = torch.cat([context, next_tokens], dim=1)
283 |
284 | return context
285 |
286 |
287 | def combine_filename_pred_gt(
288 | filename: List[str], pred_id: Tensor, gt_id: Tensor, vocab: tk.Tokenizer, type: str
289 | ) -> dict:
290 | out = dict()
291 |
292 | assert len(filename) == len(pred_id)
293 |
294 | pred_id = pred_id.detach().cpu().numpy()
295 | gt_id = gt_id.detach().cpu().numpy()
296 |
297 | pred_token = vocab.decode_batch(pred_id, skip_special_tokens=False)
298 | gt_token = vocab.decode_batch(gt_id, skip_special_tokens=False)
299 |
300 | for idx, name in enumerate(filename):
301 | if type == "html":
302 | pred_token_list = html_str_to_token_list(pred_token[idx])
303 | gt_token_list = html_str_to_token_list(gt_token[idx])
304 | elif type == "cell":
305 | pred_token_list = cell_str_to_token_list(pred_token[idx])
306 | gt_token_list = cell_str_to_token_list(gt_token[idx])
307 | elif type == "bbox":
308 | pred_token_list = bbox_str_to_token_list(pred_token[idx])
309 | gt_token_list = bbox_str_to_token_list(gt_token[idx])
310 | else:
311 | raise ValueError(
312 | f"The supported tasks are html, cell and bbox, while {type} is provided."
313 | )
314 |
315 | out[name] = dict(pred=pred_token_list, gt=gt_token_list)
316 |
317 | return out
318 |
--------------------------------------------------------------------------------
/src/utils/engine.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | from pathlib import Path
5 | import glob
6 |
7 | from src.utils import build_table_from_html_and_cell, html_table_template
8 |
9 |
10 | def combine_all_json(file_dir: str) -> dict:
11 | total_result = dict()
12 | files = os.listdir(file_dir)
13 | try:
14 | files.remove("final.json")
15 | except ValueError:
16 | pass
17 | for file in files:
18 | with open(os.path.join(file_dir, file), "r") as f:
19 | result = json.load(f)
20 | total_result.update(result)
21 |
22 | print(f"Combined to a json with {len(total_result)} entries.")
23 |
24 | return total_result
25 |
26 |
27 | def json_to_final(file_dir: str, type: str):
28 | if type == "html" or type == "bbox":
29 | result = combine_all_json(file_dir)
30 | elif type == "html+cell":
31 | result_cell = combine_all_json(file_dir)
32 | result_html_file = os.path.join(
33 | Path(file_dir).parent,
34 | Path(file_dir).name.split("-")[0].replace("cell", "html") + "-html",
35 | )
36 | assert Path(result_html_file).is_dir(), f"{result_html_file} does not exist."
37 | result = combine_all_json(result_html_file)
38 | assert len(result) == len(result_cell)
39 | else:
40 | # assert html and cell json files have the same length
41 | raise NotImplementedError
42 |
43 | out = dict()
44 |
45 | if type == "bbox":
46 | out = result
47 | else:
48 | for filename, obj in result.items():
49 | if type == "html":
50 | pred_html = "".join(obj["pred"])
51 | gt_html = "".join(obj["gt"])
52 |
53 | out[filename] = dict(
54 | pred=html_table_template(pred_html), gt=html_table_template(gt_html)
55 | )
56 | elif type == "html+cell":
57 | pred_html_cell = build_table_from_html_and_cell(
58 | obj["pred"], result_cell[filename]["pred"]
59 | )
60 | gt_html_cell = build_table_from_html_and_cell(
61 | obj["gt"], result_cell[filename]["gt"]
62 | )
63 | out[filename] = dict(
64 | pred=html_table_template(pred_html_cell),
65 | gt=html_table_template(gt_html_cell),
66 | )
67 | else:
68 | raise NotImplementedError
69 |
70 | # write to file
71 | with open(os.path.join(file_dir, f"final.json"), "w", encoding="utf-8") as f:
72 | json.dump(out, f, indent=4)
73 |
74 |
75 | if __name__ == "__main__":
76 | parser = argparse.ArgumentParser(description="postprecess")
77 |
78 | parser.add_argument(
79 | "-f", "--file", help="path to all json files from difference devices"
80 | )
81 | parser.add_argument("-t", "--type", help="html, html+cell")
82 | args = parser.parse_args()
83 |
84 | json_to_final(args.file, args.type)
85 |
--------------------------------------------------------------------------------
/src/utils/mask_generator.py:
--------------------------------------------------------------------------------
1 | import random
2 | import math
3 | from typing import Any
4 | import numpy as np
5 |
6 | """
7 | Code adapted from beit mask generator: https://github.com/microsoft/unilm/blob/ecff36188001e9b12a90b01bbbaf9058d2b8bda6/beit/masking_generator.py .
8 | """
9 |
10 | __all__ = ["MaskGenerator"]
11 |
12 |
13 | class MaskGenerator:
14 | def __init__(
15 | self,
16 | input_size: int,
17 | num_mask_patches: int,
18 | min_num_patches: int = 4,
19 | max_num_patches: int = None,
20 | min_aspect: float = 0.3,
21 | max_aspect: float = None,
22 | ) -> None:
23 | if not isinstance(input_size, tuple):
24 | input_size = (input_size,) * 2
25 | self.height, self.width = input_size
26 |
27 | self.num_patches = self.height * self.width
28 |
29 | self.num_mask_patches = num_mask_patches
30 |
31 | self.min_num_patches = min_num_patches
32 | self.max_num_patches = (
33 | num_mask_patches if max_num_patches is None else max_num_patches
34 | )
35 |
36 | max_aspect = max_aspect or 1 / min_aspect
37 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
38 |
39 | def __repr__(self):
40 | repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
41 | self.height,
42 | self.width,
43 | self.min_num_patches,
44 | self.max_num_patches,
45 | self.num_mask_patches,
46 | self.log_aspect_ratio[0],
47 | self.log_aspect_ratio[1],
48 | )
49 | return repr_str
50 |
51 | def get_shape(self):
52 | return self.height, self.width
53 |
54 | def _mask(self, mask: np.array, max_mask_patches: int) -> int:
55 | delta = 0
56 | for _ in range(10):
57 | target_area = random.uniform(self.min_num_patches, max_mask_patches)
58 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
59 | h = int(round(math.sqrt(target_area * aspect_ratio)))
60 | w = int(round(math.sqrt(target_area / aspect_ratio)))
61 | if w < self.width and h < self.height:
62 | top = random.randint(0, self.height - h)
63 | left = random.randint(0, self.width - w)
64 |
65 | num_masked = mask[top : top + h, left : left + w].sum()
66 | if 0 < h * w - num_masked <= max_mask_patches:
67 | for i in range(top, top + h):
68 | for j in range(left, left + w):
69 | if mask[i, j] == 0:
70 | mask[i, j] = 1
71 | delta += 1
72 | if delta > 0:
73 | break
74 | return delta
75 |
76 | def __call__(self) -> Any:
77 | mask = np.zeros((self.height, self.width), dtype=np.int32)
78 | mask_count = 0
79 | while mask_count < self.num_mask_patches:
80 | max_mask_patches = self.num_mask_patches - mask_count
81 | max_mask_patches = min(max_mask_patches, self.max_num_patches)
82 |
83 | delta = self._mask(mask, max_mask_patches)
84 | if delta == 0:
85 | break
86 | else:
87 | mask_count += delta
88 |
89 | return mask
90 |
91 |
92 | if __name__ == "__main__":
93 | mg = MaskGenerator(input_size=14, num_mask_patches=75)
94 | mask = mg()
95 | print(mask)
96 | print(mg, mask.sum())
97 |
--------------------------------------------------------------------------------
/src/utils/misc.py:
--------------------------------------------------------------------------------
1 | import math
2 | import jsonlines
3 | from pathlib import Path
4 | from typing import Dict, Tuple, List, Union
5 | from torch import Tensor, nn
6 |
7 | __all__ = [
8 | "cosine_schedule_with_warmup",
9 | "load_json_annotations",
10 | "bbox_augmentation_resize",
11 | "count_total_parameters",
12 | "compute_grad_norm",
13 | "printer",
14 | "html_table_template",
15 | ]
16 |
17 | printer = lambda device, output: f"[GPU {device}] " + output
18 |
19 | html_table_template = (
20 | lambda table: f"""
21 |
22 |
28 |
29 | """
32 | )
33 |
34 |
35 | # adpated from https://github.com/huggingface/transformers/blob/v4.33.0/src/transformers/optimization.py
36 | def cosine_schedule_with_warmup(
37 | step: int,
38 | *,
39 | warmup: int,
40 | min_ratio: float,
41 | total_step: int,
42 | cycle: float = 0.5,
43 | ):
44 | if step < warmup:
45 | if step == 0:
46 | step = 1
47 | return float(step) / float(max(1, warmup))
48 |
49 | if step >= total_step:
50 | step = total_step
51 | progress = float(step - warmup) / float(max(1, total_step - warmup))
52 | return max(
53 | min_ratio, 0.5 * (1.0 + math.cos(math.pi * float(cycle) * 2.0 * progress))
54 | )
55 |
56 |
57 | def load_json_annotations(json_file_dir: Path, split: str):
58 | """Preprocess jsonl in dataset."""
59 | image_label_pair = list()
60 | with jsonlines.open(json_file_dir) as f:
61 | for obj in f:
62 | if obj["split"] == split:
63 | image_label_pair.append((obj["filename"], obj["html"]))
64 |
65 | return image_label_pair
66 |
67 |
68 | def bbox_augmentation_resize(
69 | bbox: List[int], image_size: List[int], target_size: int
70 | ) -> List[int]:
71 | """Modify the bbox coordinates according to the image resizing."""
72 | # Assuming the bbox is [xmin, ymin, xmax, ymax]
73 | assert len(image_size) == 2
74 | ratio = [target_size / i for i in image_size]
75 | ratio = ratio * 2
76 | bbox = [int(round(i * j)) for i, j in zip(bbox, ratio)]
77 | return bbox
78 |
79 |
80 | def count_total_parameters(model: nn.Module) -> int:
81 | """Count total parameters that need training."""
82 | total_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
83 | return total_parameters
84 |
85 |
86 | def compute_grad_norm(model: nn.Module) -> float:
87 | total_norm = 0.0
88 | for p in model.parameters():
89 | if p.grad is not None and p.requires_grad:
90 | param_norm = p.grad.detach().data.norm(2)
91 | total_norm += param_norm.item() ** 2
92 | total_norm = total_norm**0.5
93 | return total_norm
94 |
--------------------------------------------------------------------------------
/src/utils/teds.py:
--------------------------------------------------------------------------------
1 | # code adapted from https://github.com/ibm-aur-nlp/PubTabNet/blob/master/src/metric.py
2 | # tree edit distance video explanation: https://www.youtube.com/watch?v=6Ur8B35xCj8
3 | import apted
4 | import distance
5 | from collections import deque
6 | from lxml import etree, html
7 | from tqdm import tqdm
8 | from concurrent.futures import ProcessPoolExecutor, as_completed
9 | from typing import Tuple
10 |
11 |
12 | class TableTree(apted.helpers.Tree):
13 | def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
14 | self.tag = tag
15 | self.colspan = colspan
16 | self.rowspan = rowspan
17 | self.content = content
18 | self.children = list(children)
19 |
20 | def bracket(self):
21 | """Show tree using brackets notation."""
22 | if self.tag == "td":
23 | result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % (
24 | self.tag,
25 | self.colspan,
26 | self.rowspan,
27 | self.content,
28 | )
29 | else:
30 | result = '"tag": %s' % self.tag
31 | for child in self.children:
32 | result += child.bracket()
33 | return "{{{}}}".format(result)
34 |
35 |
36 | class CustomConfig(apted.Config):
37 | @staticmethod
38 | def maximum(*sequences):
39 | """Get maximum possible value."""
40 | return max(map(len, sequences))
41 |
42 | def normalized_distance(self, *sequences):
43 | """Get distance from 0 to 1."""
44 | return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
45 |
46 | def rename(self, node1, node2):
47 | """Compares attributes of trees"""
48 | if (
49 | (node1.tag != node2.tag)
50 | or (node1.colspan != node2.colspan)
51 | or (node1.rowspan != node2.rowspan)
52 | ):
53 | return 1.0
54 | if node1.tag == "td":
55 | if node1.content or node2.content:
56 | return self.normalized_distance(node1.content, node2.content)
57 | return 0.0
58 |
59 |
60 | class TEDS(object):
61 | """Tree Edit Distance basead Similarity"""
62 |
63 | def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
64 | assert isinstance(n_jobs, int) and (
65 | n_jobs >= 1
66 | ), "n_jobs must be an integer greather than 1"
67 | self.structure_only = structure_only
68 | self.n_jobs = n_jobs
69 | self.ignore_nodes = ignore_nodes
70 | self.__tokens__ = []
71 |
72 | def tokenize(self, node):
73 | """Tokenizes table cells"""
74 | self.__tokens__.append("<%s>" % node.tag)
75 | if node.text is not None:
76 | self.__tokens__ += list(node.text)
77 | for n in node.getchildren():
78 | self.tokenize(n)
79 | if node.tag != "unk":
80 | self.__tokens__.append("%s>" % node.tag)
81 | if node.tag != "td" and node.tail is not None:
82 | self.__tokens__ += list(node.tail)
83 |
84 | def load_html_tree(self, node, parent=None):
85 | """Converts HTML tree to the format required by apted"""
86 | global __tokens__
87 | if node.tag == "td":
88 | if self.structure_only:
89 | cell = []
90 | else:
91 | self.__tokens__ = []
92 | self.tokenize(node)
93 | cell = self.__tokens__[1:-1].copy()
94 | new_node = TableTree(
95 | node.tag,
96 | int(node.attrib.get("colspan", "1")),
97 | int(node.attrib.get("rowspan", "1")),
98 | cell,
99 | *deque(),
100 | )
101 | else:
102 | new_node = TableTree(node.tag, None, None, None, *deque())
103 | if parent is not None:
104 | parent.children.append(new_node)
105 | if node.tag != "td":
106 | for n in node.getchildren():
107 | self.load_html_tree(n, new_node)
108 | if parent is None:
109 | return new_node
110 |
111 | def evaluate(self, pred, true):
112 | """Computes TEDS score between the prediction and the ground truth of a
113 | given sample
114 | """
115 | if (not pred) or (not true):
116 | return 0.0
117 | parser = html.HTMLParser(remove_comments=True, encoding="utf-8")
118 | pred = html.fromstring(pred, parser=parser)
119 | true = html.fromstring(true, parser=parser)
120 | if pred.xpath("body/table") and true.xpath("body/table"):
121 | pred = pred.xpath("body/table")[0]
122 | true = true.xpath("body/table")[0]
123 | if self.ignore_nodes:
124 | etree.strip_tags(pred, *self.ignore_nodes)
125 | etree.strip_tags(true, *self.ignore_nodes)
126 | n_nodes_pred = len(pred.xpath(".//*"))
127 | n_nodes_true = len(true.xpath(".//*"))
128 | n_nodes = max(n_nodes_pred, n_nodes_true)
129 | tree_pred = self.load_html_tree(pred)
130 | tree_true = self.load_html_tree(true)
131 | distance = apted.APTED(
132 | tree_pred, tree_true, CustomConfig()
133 | ).compute_edit_distance()
134 | return 1.0 - (float(distance) / n_nodes)
135 | else:
136 | return 0.0
137 |
138 | def batch_evaluate(self, results_json):
139 | """Computes TEDS score between the prediction and the ground truth of
140 | a batch of samples
141 | @params pred_json: {'FILENAME': 'HTML CODE', ...}
142 | @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
143 | @output: {'FILENAME': 'TEDS SCORE', ...}
144 | """
145 | samples = results_json.keys()
146 | print(f"Total samples: {len(samples)}")
147 | if self.n_jobs == 1:
148 | scores = [
149 | self.evaluate(
150 | results_json[filename]["pred"],
151 | results_json[filename]["gt"],
152 | )
153 | for filename in tqdm(samples)
154 | ]
155 | else:
156 | inputs = [
157 | {
158 | "pred": results_json[filename]["pred"],
159 | "true": results_json[filename]["gt"],
160 | }
161 | for filename in samples
162 | ]
163 | scores = parallel_process(
164 | inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1
165 | )
166 | output = dict()
167 | for i, j in zip(samples, scores):
168 | if "span" in results_json[i]["gt"]:
169 | output[i] = dict(scores=j, type="complex")
170 | else:
171 | output[i] = dict(scores=j, type="simple")
172 | # scores = dict(zip(samples, scores))
173 | return output
174 |
175 |
176 | def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=0):
177 | """
178 | A parallel version of the map function with a progress bar.
179 |
180 | Args:
181 | array (array-like): An array to iterate over.
182 | function (function): A python function to apply to the elements of array
183 | n_jobs (int, default=16): The number of cores to use
184 | use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of
185 | keyword arguments to function
186 | front_num (int, default=3): The number of iterations to run serially before kicking off the parallel job.
187 | Useful for catching bugs
188 | Returns:
189 | [function(array[0]), function(array[1]), ...]
190 | """
191 | # We run the first few iterations serially to catch bugs
192 | if front_num > 0:
193 | front = [
194 | function(**a) if use_kwargs else function(a) for a in array[:front_num]
195 | ]
196 | else:
197 | front = []
198 | # If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging.
199 | if n_jobs == 1:
200 | return front + [
201 | function(**a) if use_kwargs else function(a)
202 | for a in tqdm(array[front_num:])
203 | ]
204 | # Assemble the workers
205 | with ProcessPoolExecutor(max_workers=n_jobs) as pool:
206 | # Pass the elements of array into function
207 | if use_kwargs:
208 | futures = [pool.submit(function, **a) for a in array[front_num:]]
209 | else:
210 | futures = [pool.submit(function, a) for a in array[front_num:]]
211 | kwargs = {
212 | "total": len(futures),
213 | "unit": "it",
214 | "unit_scale": True,
215 | "leave": True,
216 | }
217 | # Print out the progress as tasks complete
218 | for f in tqdm(as_completed(futures), **kwargs):
219 | pass
220 | out = []
221 | # Get the results from the futures.
222 | for i, future in tqdm(enumerate(futures)):
223 | try:
224 | out.append(future.result())
225 | except Exception as e:
226 | out.append(e)
227 | return front + out
228 |
229 |
230 | if __name__ == "__main__":
231 | import json
232 | import pprint
233 | import numpy as np
234 | import argparse
235 |
236 | parser = argparse.ArgumentParser(description="TEDS Computation")
237 |
238 | parser.add_argument("-f", "--file", help="path to html table results in json file")
239 | parser.add_argument("-t", "--type", help="html, html+cell")
240 | parser.add_argument("-n", "--njob", default=200, help="number of jobs in parallel")
241 | args = parser.parse_args()
242 |
243 | results_file = args.file
244 | with open(results_file, "r") as f:
245 | results_json = json.load(f)
246 |
247 | if args.type == "html":
248 | s_only = True
249 | else:
250 | s_only = False
251 | teds = TEDS(structure_only=s_only, n_jobs=args.njob)
252 | scores = teds.batch_evaluate(results_json)
253 | pp = pprint.PrettyPrinter()
254 | pp.pprint(scores)
255 |
256 | # compute teds for simple and complex tables
257 | total, simple, complex = list(), list(), list()
258 | for _, obj in scores.items():
259 | if obj["type"] == "simple":
260 | simple.append(obj["scores"])
261 | elif obj["type"] == "complex":
262 | complex.append(obj["scores"])
263 | total.append(obj["scores"])
264 |
265 | total, simple, complex = np.array(total), np.array(simple), np.array(complex)
266 | print(
267 | f"Simple: {np.mean(simple)} \nComplex: {np.mean(complex)} \nTotal: {np.mean(total)}"
268 | )
269 |
--------------------------------------------------------------------------------
/src/utils/visualization.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | import numpy as np
3 |
4 |
5 | def normalize_image_for_visualization(mean: float, std: float):
6 | invNormalization = transforms.Compose(
7 | [
8 | transforms.Normalize(mean=[0.0] * 3, std=1.0 / np.array(std)),
9 | transforms.Normalize(mean=-1.0 * np.array(mean), std=[1.0] * 3),
10 | ]
11 | )
12 |
13 | return invNormalization
14 |
--------------------------------------------------------------------------------
/src/vocab/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 | !constant.py
--------------------------------------------------------------------------------
/src/vocab/__init__.py:
--------------------------------------------------------------------------------
1 | from .constant import *
2 |
--------------------------------------------------------------------------------
/src/vocab/constant.py:
--------------------------------------------------------------------------------
1 | SPECIAL_TOKENS = ["", "", "", "", "", ""]
2 | TASK_TOKENS = ["[table]", "[html]", "[cell]", "[bbox]", "[cell+bbox]"]
3 | RESERVED_TOKENS = [
4 | f"reserved {i+1}" for i in range(20 - len(SPECIAL_TOKENS) - len(TASK_TOKENS))
5 | ]
6 | CELL_NUM_TOKENS = [f"{i+1}-cell(s)" for i in range(100)]
7 | BBOX_TOKENS = [f"bbox-{i}" for i in range(880)]
8 |
9 | HTML_TOKENS = [
10 | " | ",
11 | "[] | ",
12 | " | ",
14 | ">[]",
15 | "",
16 | "
",
17 | "",
18 | "",
19 | "",
20 | "",
21 | ' rowspan="2"',
22 | ' rowspan="3"',
23 | ' rowspan="4"',
24 | ' rowspan="5"',
25 | ' rowspan="6"',
26 | ' rowspan="7"',
27 | ' rowspan="8"',
28 | ' rowspan="9"',
29 | ' rowspan="10"',
30 | ' rowspan="11"',
31 | ' rowspan="12"',
32 | ' rowspan="13"',
33 | ' rowspan="14"',
34 | ' rowspan="15"',
35 | ' rowspan="16"',
36 | ' rowspan="17"',
37 | ' rowspan="18"',
38 | ' rowspan="19"',
39 | ' colspan="2"',
40 | ' colspan="3"',
41 | ' colspan="4"',
42 | ' colspan="5"',
43 | ' colspan="6"',
44 | ' colspan="7"',
45 | ' colspan="8"',
46 | ' colspan="9"',
47 | ' colspan="10"',
48 | ' colspan="11"',
49 | ' colspan="12"',
50 | ' colspan="13"',
51 | ' colspan="14"',
52 | ' colspan="15"',
53 | ' colspan="16"',
54 | ' colspan="17"',
55 | ' colspan="18"',
56 | ' colspan="19"',
57 | ' colspan="25"',
58 | ]
59 |
60 | CELL_SPECIAL = ["", "", "", "", "", "", "", ""]
61 |
--------------------------------------------------------------------------------
/vocab/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 | !vocab_html.json
4 | !vocab_bbox.json
5 | !vocab_cell_6k.json
--------------------------------------------------------------------------------
/vocab/vocab_html.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "1.0",
3 | "truncation": null,
4 | "padding": {
5 | "strategy": "BatchLongest",
6 | "direction": "Right",
7 | "pad_to_multiple_of": null,
8 | "pad_id": 2,
9 | "pad_type_id": 0,
10 | "pad_token": ""
11 | },
12 | "added_tokens": [
13 | {
14 | "id": 0,
15 | "content": "",
16 | "single_word": false,
17 | "lstrip": false,
18 | "rstrip": false,
19 | "normalized": false,
20 | "special": true
21 | },
22 | {
23 | "id": 1,
24 | "content": "",
25 | "single_word": false,
26 | "lstrip": false,
27 | "rstrip": false,
28 | "normalized": false,
29 | "special": true
30 | },
31 | {
32 | "id": 2,
33 | "content": "",
34 | "single_word": false,
35 | "lstrip": false,
36 | "rstrip": false,
37 | "normalized": false,
38 | "special": true
39 | },
40 | {
41 | "id": 3,
42 | "content": "",
43 | "single_word": false,
44 | "lstrip": false,
45 | "rstrip": false,
46 | "normalized": false,
47 | "special": true
48 | },
49 | {
50 | "id": 4,
51 | "content": "",
52 | "single_word": false,
53 | "lstrip": false,
54 | "rstrip": false,
55 | "normalized": false,
56 | "special": true
57 | },
58 | {
59 | "id": 5,
60 | "content": "",
61 | "single_word": false,
62 | "lstrip": false,
63 | "rstrip": false,
64 | "normalized": false,
65 | "special": true
66 | },
67 | {
68 | "id": 6,
69 | "content": "[table]",
70 | "single_word": false,
71 | "lstrip": false,
72 | "rstrip": false,
73 | "normalized": false,
74 | "special": true
75 | },
76 | {
77 | "id": 7,
78 | "content": "[html]",
79 | "single_word": false,
80 | "lstrip": false,
81 | "rstrip": false,
82 | "normalized": false,
83 | "special": true
84 | },
85 | {
86 | "id": 8,
87 | "content": "[cell]",
88 | "single_word": false,
89 | "lstrip": false,
90 | "rstrip": false,
91 | "normalized": false,
92 | "special": true
93 | },
94 | {
95 | "id": 9,
96 | "content": "[bbox]",
97 | "single_word": false,
98 | "lstrip": false,
99 | "rstrip": false,
100 | "normalized": false,
101 | "special": true
102 | },
103 | {
104 | "id": 10,
105 | "content": "[cell+bbox]",
106 | "single_word": false,
107 | "lstrip": false,
108 | "rstrip": false,
109 | "normalized": false,
110 | "special": true
111 | },
112 | {
113 | "id": 11,
114 | "content": " | ",
115 | "single_word": false,
116 | "lstrip": false,
117 | "rstrip": false,
118 | "normalized": false,
119 | "special": true
120 | },
121 | {
122 | "id": 12,
123 | "content": "[] | ",
124 | "single_word": false,
125 | "lstrip": false,
126 | "rstrip": false,
127 | "normalized": false,
128 | "special": true
129 | },
130 | {
131 | "id": 13,
132 | "content": " | ",
142 | "single_word": false,
143 | "lstrip": false,
144 | "rstrip": false,
145 | "normalized": false,
146 | "special": true
147 | },
148 | {
149 | "id": 15,
150 | "content": ">[]",
151 | "single_word": false,
152 | "lstrip": false,
153 | "rstrip": false,
154 | "normalized": false,
155 | "special": true
156 | },
157 | {
158 | "id": 16,
159 | "content": "",
160 | "single_word": false,
161 | "lstrip": false,
162 | "rstrip": false,
163 | "normalized": false,
164 | "special": true
165 | },
166 | {
167 | "id": 17,
168 | "content": "
",
169 | "single_word": false,
170 | "lstrip": false,
171 | "rstrip": false,
172 | "normalized": false,
173 | "special": true
174 | },
175 | {
176 | "id": 18,
177 | "content": "",
178 | "single_word": false,
179 | "lstrip": false,
180 | "rstrip": false,
181 | "normalized": false,
182 | "special": true
183 | },
184 | {
185 | "id": 19,
186 | "content": "",
187 | "single_word": false,
188 | "lstrip": false,
189 | "rstrip": false,
190 | "normalized": false,
191 | "special": true
192 | },
193 | {
194 | "id": 20,
195 | "content": "",
196 | "single_word": false,
197 | "lstrip": false,
198 | "rstrip": false,
199 | "normalized": false,
200 | "special": true
201 | },
202 | {
203 | "id": 21,
204 | "content": "",
205 | "single_word": false,
206 | "lstrip": false,
207 | "rstrip": false,
208 | "normalized": false,
209 | "special": true
210 | },
211 | {
212 | "id": 22,
213 | "content": " rowspan=\"2\"",
214 | "single_word": false,
215 | "lstrip": false,
216 | "rstrip": false,
217 | "normalized": false,
218 | "special": true
219 | },
220 | {
221 | "id": 23,
222 | "content": " rowspan=\"3\"",
223 | "single_word": false,
224 | "lstrip": false,
225 | "rstrip": false,
226 | "normalized": false,
227 | "special": true
228 | },
229 | {
230 | "id": 24,
231 | "content": " rowspan=\"4\"",
232 | "single_word": false,
233 | "lstrip": false,
234 | "rstrip": false,
235 | "normalized": false,
236 | "special": true
237 | },
238 | {
239 | "id": 25,
240 | "content": " rowspan=\"5\"",
241 | "single_word": false,
242 | "lstrip": false,
243 | "rstrip": false,
244 | "normalized": false,
245 | "special": true
246 | },
247 | {
248 | "id": 26,
249 | "content": " rowspan=\"6\"",
250 | "single_word": false,
251 | "lstrip": false,
252 | "rstrip": false,
253 | "normalized": false,
254 | "special": true
255 | },
256 | {
257 | "id": 27,
258 | "content": " rowspan=\"7\"",
259 | "single_word": false,
260 | "lstrip": false,
261 | "rstrip": false,
262 | "normalized": false,
263 | "special": true
264 | },
265 | {
266 | "id": 28,
267 | "content": " rowspan=\"8\"",
268 | "single_word": false,
269 | "lstrip": false,
270 | "rstrip": false,
271 | "normalized": false,
272 | "special": true
273 | },
274 | {
275 | "id": 29,
276 | "content": " rowspan=\"9\"",
277 | "single_word": false,
278 | "lstrip": false,
279 | "rstrip": false,
280 | "normalized": false,
281 | "special": true
282 | },
283 | {
284 | "id": 30,
285 | "content": " rowspan=\"10\"",
286 | "single_word": false,
287 | "lstrip": false,
288 | "rstrip": false,
289 | "normalized": false,
290 | "special": true
291 | },
292 | {
293 | "id": 31,
294 | "content": " rowspan=\"11\"",
295 | "single_word": false,
296 | "lstrip": false,
297 | "rstrip": false,
298 | "normalized": false,
299 | "special": true
300 | },
301 | {
302 | "id": 32,
303 | "content": " rowspan=\"12\"",
304 | "single_word": false,
305 | "lstrip": false,
306 | "rstrip": false,
307 | "normalized": false,
308 | "special": true
309 | },
310 | {
311 | "id": 33,
312 | "content": " rowspan=\"13\"",
313 | "single_word": false,
314 | "lstrip": false,
315 | "rstrip": false,
316 | "normalized": false,
317 | "special": true
318 | },
319 | {
320 | "id": 34,
321 | "content": " rowspan=\"14\"",
322 | "single_word": false,
323 | "lstrip": false,
324 | "rstrip": false,
325 | "normalized": false,
326 | "special": true
327 | },
328 | {
329 | "id": 35,
330 | "content": " rowspan=\"15\"",
331 | "single_word": false,
332 | "lstrip": false,
333 | "rstrip": false,
334 | "normalized": false,
335 | "special": true
336 | },
337 | {
338 | "id": 36,
339 | "content": " rowspan=\"16\"",
340 | "single_word": false,
341 | "lstrip": false,
342 | "rstrip": false,
343 | "normalized": false,
344 | "special": true
345 | },
346 | {
347 | "id": 37,
348 | "content": " rowspan=\"17\"",
349 | "single_word": false,
350 | "lstrip": false,
351 | "rstrip": false,
352 | "normalized": false,
353 | "special": true
354 | },
355 | {
356 | "id": 38,
357 | "content": " rowspan=\"18\"",
358 | "single_word": false,
359 | "lstrip": false,
360 | "rstrip": false,
361 | "normalized": false,
362 | "special": true
363 | },
364 | {
365 | "id": 39,
366 | "content": " rowspan=\"19\"",
367 | "single_word": false,
368 | "lstrip": false,
369 | "rstrip": false,
370 | "normalized": false,
371 | "special": true
372 | },
373 | {
374 | "id": 40,
375 | "content": " colspan=\"2\"",
376 | "single_word": false,
377 | "lstrip": false,
378 | "rstrip": false,
379 | "normalized": false,
380 | "special": true
381 | },
382 | {
383 | "id": 41,
384 | "content": " colspan=\"3\"",
385 | "single_word": false,
386 | "lstrip": false,
387 | "rstrip": false,
388 | "normalized": false,
389 | "special": true
390 | },
391 | {
392 | "id": 42,
393 | "content": " colspan=\"4\"",
394 | "single_word": false,
395 | "lstrip": false,
396 | "rstrip": false,
397 | "normalized": false,
398 | "special": true
399 | },
400 | {
401 | "id": 43,
402 | "content": " colspan=\"5\"",
403 | "single_word": false,
404 | "lstrip": false,
405 | "rstrip": false,
406 | "normalized": false,
407 | "special": true
408 | },
409 | {
410 | "id": 44,
411 | "content": " colspan=\"6\"",
412 | "single_word": false,
413 | "lstrip": false,
414 | "rstrip": false,
415 | "normalized": false,
416 | "special": true
417 | },
418 | {
419 | "id": 45,
420 | "content": " colspan=\"7\"",
421 | "single_word": false,
422 | "lstrip": false,
423 | "rstrip": false,
424 | "normalized": false,
425 | "special": true
426 | },
427 | {
428 | "id": 46,
429 | "content": " colspan=\"8\"",
430 | "single_word": false,
431 | "lstrip": false,
432 | "rstrip": false,
433 | "normalized": false,
434 | "special": true
435 | },
436 | {
437 | "id": 47,
438 | "content": " colspan=\"9\"",
439 | "single_word": false,
440 | "lstrip": false,
441 | "rstrip": false,
442 | "normalized": false,
443 | "special": true
444 | },
445 | {
446 | "id": 48,
447 | "content": " colspan=\"10\"",
448 | "single_word": false,
449 | "lstrip": false,
450 | "rstrip": false,
451 | "normalized": false,
452 | "special": true
453 | },
454 | {
455 | "id": 49,
456 | "content": " colspan=\"11\"",
457 | "single_word": false,
458 | "lstrip": false,
459 | "rstrip": false,
460 | "normalized": false,
461 | "special": true
462 | },
463 | {
464 | "id": 50,
465 | "content": " colspan=\"12\"",
466 | "single_word": false,
467 | "lstrip": false,
468 | "rstrip": false,
469 | "normalized": false,
470 | "special": true
471 | },
472 | {
473 | "id": 51,
474 | "content": " colspan=\"13\"",
475 | "single_word": false,
476 | "lstrip": false,
477 | "rstrip": false,
478 | "normalized": false,
479 | "special": true
480 | },
481 | {
482 | "id": 52,
483 | "content": " colspan=\"14\"",
484 | "single_word": false,
485 | "lstrip": false,
486 | "rstrip": false,
487 | "normalized": false,
488 | "special": true
489 | },
490 | {
491 | "id": 53,
492 | "content": " colspan=\"15\"",
493 | "single_word": false,
494 | "lstrip": false,
495 | "rstrip": false,
496 | "normalized": false,
497 | "special": true
498 | },
499 | {
500 | "id": 54,
501 | "content": " colspan=\"16\"",
502 | "single_word": false,
503 | "lstrip": false,
504 | "rstrip": false,
505 | "normalized": false,
506 | "special": true
507 | },
508 | {
509 | "id": 55,
510 | "content": " colspan=\"17\"",
511 | "single_word": false,
512 | "lstrip": false,
513 | "rstrip": false,
514 | "normalized": false,
515 | "special": true
516 | },
517 | {
518 | "id": 56,
519 | "content": " colspan=\"18\"",
520 | "single_word": false,
521 | "lstrip": false,
522 | "rstrip": false,
523 | "normalized": false,
524 | "special": true
525 | },
526 | {
527 | "id": 57,
528 | "content": " colspan=\"19\"",
529 | "single_word": false,
530 | "lstrip": false,
531 | "rstrip": false,
532 | "normalized": false,
533 | "special": true
534 | },
535 | {
536 | "id": 58,
537 | "content": " colspan=\"25\"",
538 | "single_word": false,
539 | "lstrip": false,
540 | "rstrip": false,
541 | "normalized": false,
542 | "special": true
543 | }
544 | ],
545 | "normalizer": {
546 | "type": "Sequence",
547 | "normalizers": [
548 | {
549 | "type": "NFD"
550 | },
551 | {
552 | "type": "Lowercase"
553 | },
554 | {
555 | "type": "StripAccents"
556 | },
557 | {
558 | "type": "Strip",
559 | "strip_left": true,
560 | "strip_right": true
561 | }
562 | ]
563 | },
564 | "pre_tokenizer": {
565 | "type": "Whitespace"
566 | },
567 | "post_processor": null,
568 | "decoder": {
569 | "type": "WordPiece",
570 | "prefix": "##",
571 | "cleanup": true
572 | },
573 | "model": {
574 | "type": "WordPiece",
575 | "unk_token": "",
576 | "continuing_subword_prefix": "##",
577 | "max_input_chars_per_word": 100,
578 | "vocab": {
579 | "": 0,
580 | "": 1,
581 | "": 2,
582 | "": 3,
583 | "": 4,
584 | "": 5,
585 | "[table]": 6,
586 | "[html]": 7,
587 | "[cell]": 8,
588 | "[bbox]": 9,
589 | "[cell+bbox]": 10,
590 | " | ": 11,
591 | "[] | ": 12,
592 | " | ": 14,
594 | ">[]": 15,
595 | "": 16,
596 | "
": 17,
597 | "": 18,
598 | "": 19,
599 | "": 20,
600 | "": 21,
601 | " rowspan=\"2\"": 22,
602 | " rowspan=\"3\"": 23,
603 | " rowspan=\"4\"": 24,
604 | " rowspan=\"5\"": 25,
605 | " rowspan=\"6\"": 26,
606 | " rowspan=\"7\"": 27,
607 | " rowspan=\"8\"": 28,
608 | " rowspan=\"9\"": 29,
609 | " rowspan=\"10\"": 30,
610 | " rowspan=\"11\"": 31,
611 | " rowspan=\"12\"": 32,
612 | " rowspan=\"13\"": 33,
613 | " rowspan=\"14\"": 34,
614 | " rowspan=\"15\"": 35,
615 | " rowspan=\"16\"": 36,
616 | " rowspan=\"17\"": 37,
617 | " rowspan=\"18\"": 38,
618 | " rowspan=\"19\"": 39,
619 | " colspan=\"2\"": 40,
620 | " colspan=\"3\"": 41,
621 | " colspan=\"4\"": 42,
622 | " colspan=\"5\"": 43,
623 | " colspan=\"6\"": 44,
624 | " colspan=\"7\"": 45,
625 | " colspan=\"8\"": 46,
626 | " colspan=\"9\"": 47,
627 | " colspan=\"10\"": 48,
628 | " colspan=\"11\"": 49,
629 | " colspan=\"12\"": 50,
630 | " colspan=\"13\"": 51,
631 | " colspan=\"14\"": 52,
632 | " colspan=\"15\"": 53,
633 | " colspan=\"16\"": 54,
634 | " colspan=\"17\"": 55,
635 | " colspan=\"18\"": 56,
636 | " colspan=\"19\"": 57,
637 | " colspan=\"25\"": 58
638 | }
639 | }
640 | }
--------------------------------------------------------------------------------
/website/unitable-demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/website/unitable-demo.gif
--------------------------------------------------------------------------------
/website/unitable-demo.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/website/unitable-demo.mp4
--------------------------------------------------------------------------------
/website/wandb_screenshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/poloclub/unitable/af1163af653e0364843fd56e1eeeb378160c2a40/website/wandb_screenshot.png
--------------------------------------------------------------------------------