├── .gitignore
├── LICENSE
├── LiLTfinetune
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── data_args.py
│   │   ├── data_collator.py
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── funsd.py
│   │   │   └── xfun.py
│   │   └── utils.py
│   ├── evaluation.py
│   ├── models
│   │   ├── LiLTRobertaLike
│   │   │   ├── __init__.py
│   │   │   ├── configuration_LiLTRobertaLike.py
│   │   │   ├── modeling_LiLTRobertaLike.py
│   │   │   ├── tokenization_LiLTRobertaLike.py
│   │   │   └── tokenization_LiLTRobertaLike_fast.py
│   │   ├── __init__.py
│   │   └── model_args.py
│   ├── modules
│   │   ├── __init__.py
│   │   └── decoders
│   │       ├── __init__.py
│   │       └── re.py
│   ├── trainers
│   │   ├── __init__.py
│   │   ├── funsd_trainer.py
│   │   └── xfun_trainer.py
│   └── utils.py
├── Makefile
├── README.md
├── examples
│   ├── run_funsd.py
│   ├── run_xfun_re.py
│   └── run_xfun_ser.py
├── figs
│   ├── cl_xfund.png
│   ├── framework.png
│   ├── funsd.png
│   ├── ls_xfund.png
│   └── mt_xfund.png
├── gen_weight_roberta_like.py
├── pyproject.toml
├── requirements.txt
├── setup.cfg
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Created by https://www.toptal.com/developers/gitignore/api/python,macos,pycharm+all
3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,macos,pycharm+all
4 |
5 | ### macOS ###
6 | # General
7 | .DS_Store
8 | .AppleDouble
9 | .LSOverride
10 |
11 | # Icon must end with two \r
12 | Icon
13 |
14 |
15 | # Thumbnails
16 | ._*
17 |
18 | # Files that might appear in the root of a volume
19 | .DocumentRevisions-V100
20 | .fseventsd
21 | .Spotlight-V100
22 | .TemporaryItems
23 | .Trashes
24 | .VolumeIcon.icns
25 | .com.apple.timemachine.donotpresent
26 |
27 | # Directories potentially created on remote AFP share
28 | .AppleDB
29 | .AppleDesktop
30 | Network Trash Folder
31 | Temporary Items
32 | .apdisk
33 |
34 | ### PyCharm+all ###
35 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
36 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
37 |
38 | # User-specific stuff
39 | .idea/**/workspace.xml
40 | .idea/**/tasks.xml
41 | .idea/**/usage.statistics.xml
42 | .idea/**/dictionaries
43 | .idea/**/shelf
44 |
45 | # Generated files
46 | .idea/**/contentModel.xml
47 |
48 | # Sensitive or high-churn files
49 | .idea/**/dataSources/
50 | .idea/**/dataSources.ids
51 | .idea/**/dataSources.local.xml
52 | .idea/**/sqlDataSources.xml
53 | .idea/**/dynamic.xml
54 | .idea/**/uiDesigner.xml
55 | .idea/**/dbnavigator.xml
56 |
57 | # Gradle
58 | .idea/**/gradle.xml
59 | .idea/**/libraries
60 |
61 | # Gradle and Maven with auto-import
62 | # When using Gradle or Maven with auto-import, you should exclude module files,
63 | # since they will be recreated, and may cause churn. Uncomment if using
64 | # auto-import.
65 | # .idea/artifacts
66 | # .idea/compiler.xml
67 | # .idea/jarRepositories.xml
68 | # .idea/modules.xml
69 | # .idea/*.iml
70 | # .idea/modules
71 | # *.iml
72 | # *.ipr
73 |
74 | # CMake
75 | cmake-build-*/
76 |
77 | # Mongo Explorer plugin
78 | .idea/**/mongoSettings.xml
79 |
80 | # File-based project format
81 | *.iws
82 |
83 | # IntelliJ
84 | out/
85 |
86 | # mpeltonen/sbt-idea plugin
87 | .idea_modules/
88 |
89 | # JIRA plugin
90 | atlassian-ide-plugin.xml
91 |
92 | # Cursive Clojure plugin
93 | .idea/replstate.xml
94 |
95 | # Crashlytics plugin (for Android Studio and IntelliJ)
96 | com_crashlytics_export_strings.xml
97 | crashlytics.properties
98 | crashlytics-build.properties
99 | fabric.properties
100 |
101 | # Editor-based Rest Client
102 | .idea/httpRequests
103 |
104 | # Android studio 3.1+ serialized cache file
105 | .idea/caches/build_file_checksums.ser
106 |
107 | ### PyCharm+all Patch ###
108 | # Ignores the whole .idea folder and all .iml files
109 | # See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360
110 |
111 | .idea/
112 |
113 | # Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023
114 |
115 | *.iml
116 | modules.xml
117 | .idea/misc.xml
118 | *.ipr
119 |
120 | # Sonarlint plugin
121 | .idea/sonarlint
122 |
123 | ### Python ###
124 | # Byte-compiled / optimized / DLL files
125 | __pycache__/
126 | *.py[cod]
127 | *$py.class
128 |
129 | # C extensions
130 | *.so
131 |
132 | # Distribution / packaging
133 | .Python
134 | build/
135 | develop-eggs/
136 | dist/
137 | downloads/
138 | eggs/
139 | .eggs/
140 | parts/
141 | sdist/
142 | var/
143 | wheels/
144 | pip-wheel-metadata/
145 | share/python-wheels/
146 | *.egg-info/
147 | .installed.cfg
148 | *.egg
149 | MANIFEST
150 |
151 | # PyInstaller
152 | # Usually these files are written by a python script from a template
153 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
154 | *.manifest
155 | *.spec
156 |
157 | # Installer logs
158 | pip-log.txt
159 | pip-delete-this-directory.txt
160 |
161 | # Unit test / coverage reports
162 | htmlcov/
163 | .tox/
164 | .nox/
165 | .coverage
166 | .coverage.*
167 | .cache
168 | nosetests.xml
169 | coverage.xml
170 | *.cover
171 | *.py,cover
172 | .hypothesis/
173 | .pytest_cache/
174 | pytestdebug.log
175 |
176 | # Translations
177 | *.mo
178 | *.pot
179 |
180 | # Django stuff:
181 | *.log
182 | local_settings.py
183 | db.sqlite3
184 | db.sqlite3-journal
185 |
186 | # Flask stuff:
187 | instance/
188 | .webassets-cache
189 |
190 | # Scrapy stuff:
191 | .scrapy
192 |
193 | # Sphinx documentation
194 | docs/_build/
195 | doc/_build/
196 |
197 | # PyBuilder
198 | target/
199 |
200 | # Jupyter Notebook
201 | .ipynb_checkpoints
202 |
203 | # IPython
204 | profile_default/
205 | ipython_config.py
206 |
207 | # pyenv
208 | .python-version
209 |
210 | # pipenv
211 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
212 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
213 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
214 | # install all needed dependencies.
215 | #Pipfile.lock
216 |
217 | # poetry
218 | #poetry.lock
219 |
220 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
221 | __pypackages__/
222 |
223 | # Celery stuff
224 | celerybeat-schedule
225 | celerybeat.pid
226 |
227 | # SageMath parsed files
228 | *.sage.py
229 |
230 | # Environments
231 | # .env
232 | .env/
233 | .venv/
234 | env/
235 | venv/
236 | ENV/
237 | env.bak/
238 | venv.bak/
239 | pythonenv*
240 |
241 | # Spyder project settings
242 | .spyderproject
243 | .spyproject
244 |
245 | # Rope project settings
246 | .ropeproject
247 |
248 | # mkdocs documentation
249 | /site
250 |
251 | # mypy
252 | .mypy_cache/
253 | .dmypy.json
254 | dmypy.json
255 |
256 | # Pyre type checker
257 | .pyre/
258 |
259 | # pytype static type analyzer
260 | .pytype/
261 |
262 | # operating system-related files
263 | # file properties cache/storage on macOS
264 | *.DS_Store
265 | # thumbnail cache on Windows
266 | Thumbs.db
267 |
268 | # profiling data
269 | .prof
270 |
271 | LiLTfinetune/data/datasets/*.lock
272 | # End of https://www.toptal.com/developers/gitignore/api/python,macos,pycharm+all
273 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Jiapeng Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/LiLTfinetune/__init__.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | from transformers import CONFIG_MAPPING, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, MODEL_NAMES_MAPPING, TOKENIZER_MAPPING
4 | from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, BertConverter, RobertaConverter, XLMRobertaConverter
5 | from transformers.models.auto.modeling_auto import auto_class_factory
6 |
7 | from .models.LiLTRobertaLike import (
8 | LiLTRobertaLikeConfig,
9 | LiLTRobertaLikeForRelationExtraction,
10 | LiLTRobertaLikeForTokenClassification,
11 | LiLTRobertaLikeTokenizer,
12 | LiLTRobertaLikeTokenizerFast,
13 | )
14 |
15 | CONFIG_MAPPING.update([("liltrobertalike", LiLTRobertaLikeConfig),])
16 | MODEL_NAMES_MAPPING.update([("liltrobertalike", "LiLTRobertaLike"),])
17 | TOKENIZER_MAPPING.update(
18 | [
19 | (LiLTRobertaLikeConfig, (LiLTRobertaLikeTokenizer, LiLTRobertaLikeTokenizerFast)),
20 | ]
21 | )
22 |
23 | with open('tag.txt', 'r') as tagf:
24 | TAG = tagf.read().lower()
25 | assert TAG == 'monolingual' or TAG == 'multilingual', 'TAG is wrong. It should be monolingual or multilingual.'
26 | if TAG == 'monolingual':
27 | SLOW_TO_FAST_CONVERTERS.update({"LiLTRobertaLikeTokenizer": RobertaConverter,})
28 | elif TAG == 'multilingual':
29 | SLOW_TO_FAST_CONVERTERS.update({"LiLTRobertaLikeTokenizer": XLMRobertaConverter,})
30 |
31 | MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.update(
32 | [(LiLTRobertaLikeConfig, LiLTRobertaLikeForTokenClassification),]
33 | )
34 |
35 | MODEL_FOR_RELATION_EXTRACTION_MAPPING = OrderedDict(
36 | [(LiLTRobertaLikeConfig, LiLTRobertaLikeForRelationExtraction),]
37 | )
38 |
39 | AutoModelForTokenClassification = auto_class_factory(
40 | "AutoModelForTokenClassification", MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, head_doc="token classification"
41 | )
42 |
43 | AutoModelForRelationExtraction = auto_class_factory(
44 | "AutoModelForRelationExtraction", MODEL_FOR_RELATION_EXTRACTION_MAPPING, head_doc="relation extraction"
45 | )
46 |
--------------------------------------------------------------------------------
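
A minimal usage sketch (not from the repository), assuming the pinned transformers version from requirements.txt and a locally converted checkpoint whose path below is a placeholder. Importing LiLTfinetune reads tag.txt from the working directory, so that file must exist first.

# Hypothetical sketch; the checkpoint path is a placeholder, not part of the repo.
with open("tag.txt", "w") as f:          # required by the module-level code above
    f.write("multilingual")

from transformers import AutoTokenizer
from LiLTfinetune import AutoModelForTokenClassification
from LiLTfinetune.models.LiLTRobertaLike import LiLTRobertaLikeConfig

config = LiLTRobertaLikeConfig.from_pretrained("path/to/lilt-checkpoint", num_labels=7)
tokenizer = AutoTokenizer.from_pretrained("path/to/lilt-checkpoint")
model = AutoModelForTokenClassification.from_pretrained("path/to/lilt-checkpoint", config=config)
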
/LiLTfinetune/data/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | from .data_collator import DataCollatorForKeyValueExtraction
3 | from .datasets import *
4 |
--------------------------------------------------------------------------------
/LiLTfinetune/data/data_args.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Optional
3 |
4 |
5 | @dataclass
6 | class DataTrainingArguments:
7 | """
8 | Arguments pertaining to what data we are going to input our model for training and eval.
9 | """
10 |
11 | task_name: Optional[str] = field(default="ner", metadata={"help": "The name of the task (ner, pos...)."})
12 | dataset_name: Optional[str] = field(
13 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
14 | )
15 | dataset_config_name: Optional[str] = field(
16 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
17 | )
18 | train_file: Optional[str] = field(
19 | default=None, metadata={"help": "The input training data file (a csv or JSON file)."}
20 | )
21 | validation_file: Optional[str] = field(
22 | default=None,
23 | metadata={"help": "An optional input evaluation data file to evaluate on (a csv or JSON file)."},
24 | )
25 | test_file: Optional[str] = field(
26 | default=None,
27 | metadata={"help": "An optional input test data file to predict on (a csv or JSON file)."},
28 | )
29 | overwrite_cache: bool = field(
30 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
31 | )
32 | preprocessing_num_workers: Optional[int] = field(
33 | default=None,
34 | metadata={"help": "The number of processes to use for the preprocessing."},
35 | )
36 | pad_to_max_length: bool = field(
37 | default=True,
38 | metadata={
39 | "help": "Whether to pad all samples to model maximum sentence length. "
40 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
41 | "efficient on GPU but very bad for TPU."
42 | },
43 | )
44 | max_train_samples: Optional[int] = field(
45 | default=None,
46 | metadata={
47 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
48 | "value if set."
49 | },
50 | )
51 | max_val_samples: Optional[int] = field(
52 | default=None,
53 | metadata={
54 | "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
55 | "value if set."
56 | },
57 | )
58 | max_test_samples: Optional[int] = field(
59 | default=None,
60 | metadata={
61 | "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
62 | "value if set."
63 | },
64 | )
65 | label_all_tokens: bool = field(
66 | default=False,
67 | metadata={
68 | "help": "Whether to put the label for one word on all tokens of generated by that word or just on the "
69 | "one (in which case the other tokens will have a padding index)."
70 | },
71 | )
72 | return_entity_level_metrics: bool = field(
73 | default=False,
74 | metadata={"help": "Whether to return all the entity levels during evaluation or just the overall ones."},
75 | )
76 |
77 |
78 | @dataclass
79 | class XFUNDataTrainingArguments(DataTrainingArguments):
80 | lang: Optional[str] = field(default="en")
81 | additional_langs: Optional[str] = field(default=None)
82 |
--------------------------------------------------------------------------------
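
A minimal sketch of parsing these dataclasses with transformers' HfArgumentParser (the flag values are illustrative; importing from LiLTfinetune carries the tag.txt requirement noted earlier).

# Illustrative: parse data and training arguments from an explicit argument list.
from transformers import HfArgumentParser, TrainingArguments
from LiLTfinetune.data.data_args import XFUNDataTrainingArguments

parser = HfArgumentParser((XFUNDataTrainingArguments, TrainingArguments))
data_args, training_args = parser.parse_args_into_dataclasses(
    ["--output_dir", "out", "--lang", "zh", "--additional_langs", "en+de"]
)
print(data_args.lang, data_args.additional_langs, data_args.max_train_samples)  # zh en+de None
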
/LiLTfinetune/data/data_collator.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional, Union
3 |
4 | import torch
5 |
6 | from detectron2.structures import ImageList
7 | from transformers import PreTrainedTokenizerBase
8 | from transformers.file_utils import PaddingStrategy
9 |
10 |
11 | @dataclass
12 | class DataCollatorForKeyValueExtraction:
13 | """
14 | Data collator that will dynamically pad the inputs received, as well as the labels.
15 |
16 | Args:
17 | tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
18 | The tokenizer used for encoding the data.
19 | padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):
20 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
21 | among:
22 |
23 | * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
 24 |               sequence is provided).
25 | * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
26 | maximum acceptable input length for the model if that argument is not provided.
 27 |             * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
28 | different lengths).
29 | max_length (:obj:`int`, `optional`):
30 | Maximum length of the returned list and optionally padding length (see above).
31 | pad_to_multiple_of (:obj:`int`, `optional`):
32 | If set will pad the sequence to a multiple of the provided value.
33 |
34 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
35 | 7.5 (Volta).
36 | label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
 37 |             The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
38 | """
39 |
40 | tokenizer: PreTrainedTokenizerBase
41 | padding: Union[bool, str, PaddingStrategy] = True
42 | max_length: Optional[int] = None
43 | pad_to_multiple_of: Optional[int] = None
44 | label_pad_token_id: int = -100
45 |
46 | def __call__(self, features):
47 | label_name = "label" if "label" in features[0].keys() else "labels"
48 | labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
49 |
50 | has_image_input = "image" in features[0]
51 | has_bbox_input = "bbox" in features[0]
52 | if has_image_input:
53 | image = ImageList.from_tensors([torch.tensor(feature["image"]) for feature in features], 32)
54 | for feature in features:
55 | del feature["image"]
56 | batch = self.tokenizer.pad(
57 | features,
58 | padding=self.padding,
59 | max_length=self.max_length,
60 | pad_to_multiple_of=self.pad_to_multiple_of,
61 | # Conversion to tensors will fail if we have labels as they are not of the same length yet.
62 | return_tensors="pt" if labels is None else None,
63 | )
64 |
65 | if labels is None:
66 | return batch
67 |
68 | sequence_length = torch.tensor(batch["input_ids"]).shape[1]
69 | padding_side = self.tokenizer.padding_side
70 | if padding_side == "right":
71 | batch["labels"] = [label + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels]
72 | if has_bbox_input:
73 | batch["bbox"] = [bbox + [[0, 0, 0, 0]] * (sequence_length - len(bbox)) for bbox in batch["bbox"]]
74 | else:
75 | batch["labels"] = [[self.label_pad_token_id] * (sequence_length - len(label)) + label for label in labels]
76 | if has_bbox_input:
77 | batch["bbox"] = [[[0, 0, 0, 0]] * (sequence_length - len(bbox)) + bbox for bbox in batch["bbox"]]
78 |
79 | batch = {k: torch.tensor(v, dtype=torch.int64) if isinstance(v[0], list) else v for k, v in batch.items()}
80 | if has_image_input:
81 | batch["image"] = image
82 | return batch
83 |
--------------------------------------------------------------------------------
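
A minimal sketch of constructing and calling the collator (the feature contents are illustrative; detectron2 and the tag.txt setup above are assumed to be in place).

# Illustrative: pad two token-classification features so labels and bboxes stay aligned.
from transformers import AutoTokenizer
from LiLTfinetune.data.data_collator import DataCollatorForKeyValueExtraction

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
collator = DataCollatorForKeyValueExtraction(tokenizer, padding="max_length", max_length=8)

features = [
    {"input_ids": [0, 100, 2], "labels": [-100, 1, -100], "bbox": [[1, 2, 3, 4]] * 3},
    {"input_ids": [0, 100, 200, 2], "labels": [-100, 1, 2, -100], "bbox": [[1, 2, 3, 4]] * 4},
]
batch = collator(features)
# Expected shapes: input_ids/labels (2, 8), bbox (2, 8, 4); labels padded with -100, bbox with [0, 0, 0, 0].
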
/LiLTfinetune/data/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jpWang/LiLT/a39930b2c5425da7250f0dde04252cf60ec3b1b7/LiLTfinetune/data/datasets/__init__.py
--------------------------------------------------------------------------------
/LiLTfinetune/data/datasets/funsd.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | import json
4 | import os
5 |
6 | import datasets
7 |
8 | from LiLTfinetune.data.utils import load_image, normalize_bbox
9 |
10 |
11 | logger = datasets.logging.get_logger(__name__)
12 |
13 |
14 | _CITATION = """\
15 | @article{Jaume2019FUNSDAD,
16 | title={FUNSD: A Dataset for Form Understanding in Noisy Scanned Documents},
17 | author={Guillaume Jaume and H. K. Ekenel and J. Thiran},
18 | journal={2019 International Conference on Document Analysis and Recognition Workshops (ICDARW)},
19 | year={2019},
20 | volume={2},
21 | pages={1-6}
22 | }
23 | """
24 |
25 | _DESCRIPTION = """\
26 | https://guillaumejaume.github.io/FUNSD/
27 | """
28 |
29 |
30 | class FunsdConfig(datasets.BuilderConfig):
31 | """BuilderConfig for FUNSD"""
32 |
33 | def __init__(self, **kwargs):
34 | """BuilderConfig for FUNSD.
35 |
36 | Args:
37 | **kwargs: keyword arguments forwarded to super.
38 | """
39 | super(FunsdConfig, self).__init__(**kwargs)
40 |
41 |
42 | class Funsd(datasets.GeneratorBasedBuilder):
 43 |     """FUNSD dataset."""
44 |
45 | BUILDER_CONFIGS = [
46 | FunsdConfig(name="funsd", version=datasets.Version("1.0.0"), description="FUNSD dataset"),
47 | ]
48 |
49 | def _info(self):
50 | return datasets.DatasetInfo(
51 | description=_DESCRIPTION,
52 | features=datasets.Features(
53 | {
54 | "id": datasets.Value("string"),
55 | "tokens": datasets.Sequence(datasets.Value("string")),
56 | "bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))),
57 | "ner_tags": datasets.Sequence(
58 | datasets.features.ClassLabel(
59 | names=["O", "B-HEADER", "I-HEADER", "B-QUESTION", "I-QUESTION", "B-ANSWER", "I-ANSWER"]
60 | )
61 | ),
62 | "image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"),
63 | }
64 | ),
65 | supervised_keys=None,
66 | homepage="https://guillaumejaume.github.io/FUNSD/",
67 | citation=_CITATION,
68 | )
69 |
70 | def _split_generators(self, dl_manager):
71 | """Returns SplitGenerators."""
72 | downloaded_file = dl_manager.download_and_extract("https://guillaumejaume.github.io/FUNSD/dataset.zip")
73 | return [
74 | datasets.SplitGenerator(
75 | name=datasets.Split.TRAIN, gen_kwargs={"filepath": f"{downloaded_file}/dataset/training_data/"}
76 | ),
77 | datasets.SplitGenerator(
78 | name=datasets.Split.TEST, gen_kwargs={"filepath": f"{downloaded_file}/dataset/testing_data/"}
79 | ),
80 | ]
81 |
82 | def _generate_examples(self, filepath):
83 | logger.info("⏳ Generating examples from = %s", filepath)
84 | ann_dir = os.path.join(filepath, "annotations")
85 | img_dir = os.path.join(filepath, "images")
86 | for guid, file in enumerate(sorted(os.listdir(ann_dir))):
87 | tokens = []
88 | bboxes = []
89 | ner_tags = []
90 |
91 | file_path = os.path.join(ann_dir, file)
92 | with open(file_path, "r", encoding="utf8") as f:
93 | data = json.load(f)
94 | image_path = os.path.join(img_dir, file)
95 | image_path = image_path.replace("json", "png")
96 | image, size = load_image(image_path)
97 | for item in data["form"]:
98 | words, label = item["words"], item["label"]
99 | words = [w for w in words if w["text"].strip() != ""]
100 | if len(words) == 0:
101 | continue
102 | if label == "other":
103 | for w in words:
104 | tokens.append(w["text"])
105 | ner_tags.append("O")
106 | bboxes.append(normalize_bbox(item["box"], size))
107 | else:
108 | tokens.append(words[0]["text"])
109 | ner_tags.append("B-" + label.upper())
110 | bboxes.append(normalize_bbox(item["box"], size))
111 | for w in words[1:]:
112 | tokens.append(w["text"])
113 | ner_tags.append("I-" + label.upper())
114 | bboxes.append(normalize_bbox(item["box"], size))
115 |
116 | yield guid, {"id": str(guid), "tokens": tokens, "bboxes": bboxes, "ner_tags": ner_tags, "image": image}
117 |
--------------------------------------------------------------------------------
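
A minimal sketch of loading this builder with the datasets library (the script path is relative to the repo root and assumes a datasets version that still supports local loading scripts, plus the dependency/tag.txt setup above).

# Illustrative: build FUNSD from the local script; the archive is downloaded automatically.
import datasets

funsd = datasets.load_dataset("LiLTfinetune/data/datasets/funsd.py")
example = funsd["train"][0]
print(example["tokens"][:5], example["ner_tags"][:5], example["bboxes"][0])
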
/LiLTfinetune/data/datasets/xfun.py:
--------------------------------------------------------------------------------
1 | # Lint as: python3
2 | import json
3 | import logging
4 | import os
5 |
6 | import datasets
7 |
8 | from LiLTfinetune.data.utils import load_image, merge_bbox, normalize_bbox, simplify_bbox
9 | from transformers import AutoTokenizer
10 |
11 |
12 | _URL = "https://github.com/doc-analysis/XFUN/releases/download/v1.0/"
13 |
14 | _LANG = ["zh", "de", "es", "fr", "en", "it", "ja", "pt"]
15 | logger = logging.getLogger(__name__)
16 |
17 |
18 | class XFUNConfig(datasets.BuilderConfig):
19 | """BuilderConfig for XFUN."""
20 |
21 | def __init__(self, lang, additional_langs=None, **kwargs):
22 | """
23 | Args:
24 | lang: string, language for the input text
25 | **kwargs: keyword arguments forwarded to super.
26 | """
27 | super(XFUNConfig, self).__init__(**kwargs)
28 | self.lang = lang
29 | self.additional_langs = additional_langs
30 |
31 |
32 | class XFUN(datasets.GeneratorBasedBuilder):
33 | """XFUN dataset."""
34 |
35 | BUILDER_CONFIGS = [XFUNConfig(name=f"xfun.{lang}", lang=lang) for lang in _LANG]
36 |
37 | tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
38 |
39 | def _info(self):
40 | return datasets.DatasetInfo(
41 | features=datasets.Features(
42 | {
43 | "id": datasets.Value("string"),
44 | "input_ids": datasets.Sequence(datasets.Value("int64")),
45 | "bbox": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))),
46 | "labels": datasets.Sequence(
47 | datasets.ClassLabel(
48 | names=["O", "B-QUESTION", "B-ANSWER", "B-HEADER", "I-ANSWER", "I-QUESTION", "I-HEADER"]
49 | )
50 | ),
51 | "image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"),
52 | "entities": datasets.Sequence(
53 | {
54 | "start": datasets.Value("int64"),
55 | "end": datasets.Value("int64"),
56 | "label": datasets.ClassLabel(names=["HEADER", "QUESTION", "ANSWER"]),
57 | }
58 | ),
59 | "relations": datasets.Sequence(
60 | {
61 | "head": datasets.Value("int64"),
62 | "tail": datasets.Value("int64"),
63 | "start_index": datasets.Value("int64"),
64 | "end_index": datasets.Value("int64"),
65 | }
66 | ),
67 | }
68 | ),
69 | supervised_keys=None,
70 | )
71 |
72 | def _split_generators(self, dl_manager):
73 | """Returns SplitGenerators."""
74 | # urls_to_download = {
75 | # "train": [f"{_URL}{self.config.lang}.train.json", f"{_URL}{self.config.lang}.train.zip"],
76 | # "val": [f"{_URL}{self.config.lang}.val.json", f"{_URL}{self.config.lang}.val.zip"],
77 | # # "test": [f"{_URL}{self.config.lang}.test.json", f"{_URL}{self.config.lang}.test.zip"],
78 | # }
79 | # downloaded_files = dl_manager.download_and_extract(urls_to_download)
80 | # train_files_for_many_langs = [downloaded_files["train"]]
81 | # val_files_for_many_langs = [downloaded_files["val"]]
82 | # # test_files_for_many_langs = [downloaded_files["test"]]
83 | file_dir = 'xfund&funsd/'
84 | train_files_for_many_langs = [[file_dir+f"{self.config.lang}.train.json", file_dir+f"{self.config.lang}"]]
85 | val_files_for_many_langs = [[file_dir+f"{self.config.lang}.val.json", file_dir+f"{self.config.lang}"]]
86 |
87 | if self.config.additional_langs:
88 | additional_langs = self.config.additional_langs.split("+")
89 | if "all" in additional_langs:
90 | additional_langs = [lang for lang in _LANG if lang != self.config.lang]
91 | for lang in additional_langs:
92 | # urls_to_download = {"train": [f"{_URL}{lang}.train.json", f"{_URL}{lang}.train.zip"]}
93 | # additional_downloaded_files = dl_manager.download_and_extract(urls_to_download)
94 | # train_files_for_many_langs.append(additional_downloaded_files["train"])
95 | train_files_for_many_langs.append([file_dir+f"{lang}.train.json", file_dir+f"{lang}"])
96 |
97 |
98 | logger.info(f"Training on {self.config.lang} with additional langs({self.config.additional_langs})")
99 | logger.info(f"Evaluating on {self.config.lang}")
100 | logger.info(f"Testing on {self.config.lang}")
101 | return [
102 | datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": train_files_for_many_langs}),
103 | datasets.SplitGenerator(
104 | name=datasets.Split.VALIDATION, gen_kwargs={"filepaths": val_files_for_many_langs}
105 | ),
106 | # datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepaths": test_files_for_many_langs}),
107 | ]
108 |
109 | def _generate_examples(self, filepaths):
110 | for filepath in filepaths:
111 | logger.info("Generating examples from = %s", filepath)
112 | with open(filepath[0], "r") as f:
113 | data = json.load(f)
114 |
115 | for doc in data["documents"]:
116 | doc["img"]["fpath"] = os.path.join(filepath[1], doc["img"]["fname"])
117 | image, size = load_image(doc["img"]["fpath"])
118 | document = doc["document"]
119 | tokenized_doc = {"input_ids": [], "bbox": [], "labels": []}
120 | entities = []
121 | relations = []
122 | id2label = {}
123 | entity_id_to_index_map = {}
124 | empty_entity = set()
125 | for line in document:
126 | if len(line["text"]) == 0:
127 | empty_entity.add(line["id"])
128 | continue
129 | id2label[line["id"]] = line["label"]
130 | relations.extend([tuple(sorted(l)) for l in line["linking"]])
131 | if '/en' in filepath[0]:
132 | tokenized_inputs = self.tokenizer(
133 | ' '.join([q['text'].replace(u'\uf703','') for q in line['words']]),
134 | add_special_tokens=False,
135 | return_offsets_mapping=True,
136 | return_attention_mask=False,
137 | )
138 | else:
139 | tokenized_inputs = self.tokenizer(
140 | line["text"],
141 | add_special_tokens=False,
142 | return_offsets_mapping=True,
143 | return_attention_mask=False,
144 | )
145 | text_length = 0
146 | ocr_length = 0
147 | bbox = []
148 | last_box = None
149 | for token_id, offset in zip(tokenized_inputs["input_ids"], tokenized_inputs["offset_mapping"]):
150 | if token_id == 6:
151 | bbox.append(None)
152 | continue
153 | text_length += offset[1] - offset[0]
154 | tmp_box = []
155 | while ocr_length < text_length:
156 | ocr_word = line["words"].pop(0)
157 | ocr_length += len(
158 | self.tokenizer._tokenizer.normalizer.normalize_str(ocr_word["text"].strip())
159 | )
160 | tmp_box.append(simplify_bbox(line["box"]))
161 | if len(tmp_box) == 0:
162 | tmp_box = last_box
163 | bbox.append(normalize_bbox(merge_bbox(tmp_box), size))
164 | last_box = tmp_box
165 | bbox = [
166 | [bbox[i + 1][0], bbox[i + 1][1], bbox[i + 1][0], bbox[i + 1][1]] if b is None else b
167 | for i, b in enumerate(bbox)
168 | ]
169 | if line["label"] == "other":
170 | label = ["O"] * len(bbox)
171 | else:
172 | label = [f"I-{line['label'].upper()}"] * len(bbox)
173 | label[0] = f"B-{line['label'].upper()}"
174 | tokenized_inputs.update({"bbox": bbox, "labels": label})
175 | if label[0] != "O":
176 | entity_id_to_index_map[line["id"]] = len(entities)
177 | entities.append(
178 | {
179 | "start": len(tokenized_doc["input_ids"]),
180 | "end": len(tokenized_doc["input_ids"]) + len(tokenized_inputs["input_ids"]),
181 | "label": line["label"].upper(),
182 | }
183 | )
184 | for i in tokenized_doc:
185 | tokenized_doc[i] = tokenized_doc[i] + tokenized_inputs[i]
186 | relations = list(set(relations))
187 | relations = [rel for rel in relations if rel[0] not in empty_entity and rel[1] not in empty_entity]
188 | kvrelations = []
189 | for rel in relations:
190 | pair = [id2label[rel[0]], id2label[rel[1]]]
191 | if pair == ["question", "answer"]:
192 | kvrelations.append(
193 | {"head": entity_id_to_index_map[rel[0]], "tail": entity_id_to_index_map[rel[1]]}
194 | )
195 | elif pair == ["answer", "question"]:
196 | kvrelations.append(
197 | {"head": entity_id_to_index_map[rel[1]], "tail": entity_id_to_index_map[rel[0]]}
198 | )
199 | else:
200 | continue
201 |
202 | def get_relation_span(rel):
203 | bound = []
204 | for entity_index in [rel["head"], rel["tail"]]:
205 | bound.append(entities[entity_index]["start"])
206 | bound.append(entities[entity_index]["end"])
207 | return min(bound), max(bound)
208 |
209 | relations = sorted(
210 | [
211 | {
212 | "head": rel["head"],
213 | "tail": rel["tail"],
214 | "start_index": get_relation_span(rel)[0],
215 | "end_index": get_relation_span(rel)[1],
216 | }
217 | for rel in kvrelations
218 | ],
219 | key=lambda x: x["head"],
220 | )
221 | chunk_size = 512
222 | for chunk_id, index in enumerate(range(0, len(tokenized_doc["input_ids"]), chunk_size)):
223 | item = {}
224 | for k in tokenized_doc:
225 | item[k] = tokenized_doc[k][index : index + chunk_size]
226 | entities_in_this_span = []
227 | global_to_local_map = {}
228 | for entity_id, entity in enumerate(entities):
229 | if (
230 | index <= entity["start"] < index + chunk_size
231 | and index <= entity["end"] < index + chunk_size
232 | ):
233 | entity["start"] = entity["start"] - index
234 | entity["end"] = entity["end"] - index
235 | global_to_local_map[entity_id] = len(entities_in_this_span)
236 | entities_in_this_span.append(entity)
237 | relations_in_this_span = []
238 | for relation in relations:
239 | if (
240 | index <= relation["start_index"] < index + chunk_size
241 | and index <= relation["end_index"] < index + chunk_size
242 | ):
243 | relations_in_this_span.append(
244 | {
245 | "head": global_to_local_map[relation["head"]],
246 | "tail": global_to_local_map[relation["tail"]],
247 | "start_index": relation["start_index"] - index,
248 | "end_index": relation["end_index"] - index,
249 | }
250 | )
251 | item.update(
252 | {
253 | "id": f"{doc['id']}_{chunk_id}",
254 | "image": image,
255 | "entities": entities_in_this_span,
256 | "relations": relations_in_this_span,
257 | }
258 | )
259 | yield f"{doc['id']}_{chunk_id}", item
260 |
--------------------------------------------------------------------------------
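
For reference, a sketch of the shape of each example the generator above yields, following the features declared in _info (all values below are invented).

# Illustrative example structure only; values are made up.
example = {
    "id": "zh_train_0_0",                         # "<doc id>_<chunk id>", one chunk per 512 input ids
    "input_ids": [6, 10133, 2],                   # XLM-RoBERTa token ids for this chunk
    "bbox": [[71, 92, 131, 105]] * 3,             # one 0-1000-normalized box per token
    "labels": ["B-QUESTION", "I-QUESTION", "I-QUESTION"],
    "image": None,                                # a (3, 224, 224) uint8 array in the real data
    "entities": [{"start": 0, "end": 3, "label": "QUESTION"}],
    "relations": [],                              # head/tail index into `entities`, plus chunk-local span indices
}
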
/LiLTfinetune/data/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from detectron2.data.detection_utils import read_image
4 | from detectron2.data.transforms import ResizeTransform, TransformList
5 |
6 |
7 | def normalize_bbox(bbox, size):
8 | return [
9 | int(1000 * bbox[0] / size[0]),
10 | int(1000 * bbox[1] / size[1]),
11 | int(1000 * bbox[2] / size[0]),
12 | int(1000 * bbox[3] / size[1]),
13 | ]
14 |
15 |
16 | def simplify_bbox(bbox):
17 | return [
18 | min(bbox[0::2]),
19 | min(bbox[1::2]),
20 | max(bbox[2::2]),
21 | max(bbox[3::2]),
22 | ]
23 |
24 |
25 | def merge_bbox(bbox_list):
26 | x0, y0, x1, y1 = list(zip(*bbox_list))
27 | return [min(x0), min(y0), max(x1), max(y1)]
28 |
29 |
30 | def load_image(image_path):
31 | image = read_image(image_path, format="BGR")
32 | h = image.shape[0]
33 | w = image.shape[1]
34 | img_trans = TransformList([ResizeTransform(h=h, w=w, new_h=224, new_w=224)])
35 | image = torch.tensor(img_trans.apply_image(image).copy()).permute(2, 0, 1) # copy to make it writeable
36 | return image, (w, h)
37 |
--------------------------------------------------------------------------------
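
A small worked example of the bbox helpers (numbers are illustrative; importing the module requires detectron2 because of load_image).

# Illustrative: map pixel-space boxes on a 762x1000 page into the 0-1000 coordinate space.
from LiLTfinetune.data.utils import merge_bbox, normalize_bbox, simplify_bbox

size = (762, 1000)                                 # (width, height), as returned by load_image
box = simplify_bbox([100, 200, 250, 240])          # already x0 <= x1, y0 <= y1 -> [100, 200, 250, 240]
merged = merge_bbox([box, [260, 195, 300, 245]])   # -> [100, 195, 300, 245]
print(normalize_bbox(merged, size))                # -> [131, 195, 393, 245]
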
/LiLTfinetune/evaluation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 |
4 | import numpy as np
5 |
6 | from transformers.utils import logging
7 |
8 |
9 | logger = logging.get_logger(__name__)
10 |
11 |
12 | PREFIX_CHECKPOINT_DIR = "checkpoint"
13 | _re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")
14 |
15 |
16 | def get_last_checkpoint(folder):
17 | content = os.listdir(folder)
18 | checkpoints = [
19 | path
20 | for path in content
21 | if _re_checkpoint.search(path) is not None and os.path.isdir(os.path.join(folder, path))
22 | ]
23 | if len(checkpoints) == 0:
24 | return
25 | return os.path.join(folder, max(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0])))
26 |
27 |
28 | def re_score(pred_relations, gt_relations, mode="strict"):
29 | """Evaluate RE predictions
30 |
31 | Args:
32 | pred_relations (list) : list of list of predicted relations (several relations in each sentence)
33 | gt_relations (list) : list of list of ground truth relations
34 |
35 | rel = { "head": (start_idx (inclusive), end_idx (exclusive)),
36 | "tail": (start_idx (inclusive), end_idx (exclusive)),
37 | "head_type": ent_type,
38 | "tail_type": ent_type,
39 | "type": rel_type}
40 |
41 | vocab (Vocab) : dataset vocabulary
42 | mode (str) : in 'strict' or 'boundaries'"""
43 |
44 | assert mode in ["strict", "boundaries"]
45 |
46 | relation_types = [v for v in [0, 1] if not v == 0]
47 | scores = {rel: {"tp": 0, "fp": 0, "fn": 0} for rel in relation_types + ["ALL"]}
48 |
49 | # Count GT relations and Predicted relations
50 | n_sents = len(gt_relations)
51 | n_rels = sum([len([rel for rel in sent]) for sent in gt_relations])
52 | n_found = sum([len([rel for rel in sent]) for sent in pred_relations])
53 |
54 | # Count TP, FP and FN per type
55 | for pred_sent, gt_sent in zip(pred_relations, gt_relations):
56 | for rel_type in relation_types:
57 | # strict mode takes argument types into account
58 | if mode == "strict":
59 | pred_rels = {
60 | (rel["head"], rel["head_type"], rel["tail"], rel["tail_type"])
61 | for rel in pred_sent
62 | if rel["type"] == rel_type
63 | }
64 | gt_rels = {
65 | (rel["head"], rel["head_type"], rel["tail"], rel["tail_type"])
66 | for rel in gt_sent
67 | if rel["type"] == rel_type
68 | }
69 |
70 | # boundaries mode only takes argument spans into account
71 | elif mode == "boundaries":
72 | pred_rels = {(rel["head"], rel["tail"]) for rel in pred_sent if rel["type"] == rel_type}
73 | gt_rels = {(rel["head"], rel["tail"]) for rel in gt_sent if rel["type"] == rel_type}
74 |
75 | scores[rel_type]["tp"] += len(pred_rels & gt_rels)
76 | scores[rel_type]["fp"] += len(pred_rels - gt_rels)
77 | scores[rel_type]["fn"] += len(gt_rels - pred_rels)
78 |
79 | # Compute per entity Precision / Recall / F1
80 | for rel_type in scores.keys():
81 | if scores[rel_type]["tp"]:
82 | scores[rel_type]["p"] = scores[rel_type]["tp"] / (scores[rel_type]["fp"] + scores[rel_type]["tp"])
83 | scores[rel_type]["r"] = scores[rel_type]["tp"] / (scores[rel_type]["fn"] + scores[rel_type]["tp"])
84 | else:
85 | scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0
86 |
87 | if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0:
88 | scores[rel_type]["f1"] = (
89 | 2 * scores[rel_type]["p"] * scores[rel_type]["r"] / (scores[rel_type]["p"] + scores[rel_type]["r"])
90 | )
91 | else:
92 | scores[rel_type]["f1"] = 0
93 |
94 | # Compute micro F1 Scores
95 | tp = sum([scores[rel_type]["tp"] for rel_type in relation_types])
96 | fp = sum([scores[rel_type]["fp"] for rel_type in relation_types])
97 | fn = sum([scores[rel_type]["fn"] for rel_type in relation_types])
98 |
99 | if tp:
100 | precision = tp / (tp + fp)
101 | recall = tp / (tp + fn)
102 | f1 = 2 * precision * recall / (precision + recall)
103 |
104 | else:
105 | precision, recall, f1 = 0, 0, 0
106 |
107 | scores["ALL"]["p"] = precision
108 | scores["ALL"]["r"] = recall
109 | scores["ALL"]["f1"] = f1
110 | scores["ALL"]["tp"] = tp
111 | scores["ALL"]["fp"] = fp
112 | scores["ALL"]["fn"] = fn
113 |
114 | # Compute Macro F1 Scores
115 | scores["ALL"]["Macro_f1"] = np.mean([scores[ent_type]["f1"] for ent_type in relation_types])
116 | scores["ALL"]["Macro_p"] = np.mean([scores[ent_type]["p"] for ent_type in relation_types])
117 | scores["ALL"]["Macro_r"] = np.mean([scores[ent_type]["r"] for ent_type in relation_types])
118 |
119 | logger.info(f"RE Evaluation in *** {mode.upper()} *** mode")
120 |
121 | logger.info(
122 | "processed {} sentences with {} relations; found: {} relations; correct: {}.".format(
123 | n_sents, n_rels, n_found, tp
124 | )
125 | )
126 | logger.info(
127 | "\tALL\t TP: {};\tFP: {};\tFN: {}".format(scores["ALL"]["tp"], scores["ALL"]["fp"], scores["ALL"]["fn"])
128 | )
129 | logger.info("\t\t(m avg): precision: {:.2f};\trecall: {:.2f};\tf1: {:.2f} (micro)".format(precision, recall, f1))
130 | logger.info(
131 | "\t\t(M avg): precision: {:.2f};\trecall: {:.2f};\tf1: {:.2f} (Macro)\n".format(
132 | scores["ALL"]["Macro_p"], scores["ALL"]["Macro_r"], scores["ALL"]["Macro_f1"]
133 | )
134 | )
135 |
136 | for rel_type in relation_types:
137 | logger.info(
138 | "\t{}: \tTP: {};\tFP: {};\tFN: {};\tprecision: {:.2f};\trecall: {:.2f};\tf1: {:.2f};\t{}".format(
139 | rel_type,
140 | scores[rel_type]["tp"],
141 | scores[rel_type]["fp"],
142 | scores[rel_type]["fn"],
143 | scores[rel_type]["p"],
144 | scores[rel_type]["r"],
145 | scores[rel_type]["f1"],
146 | scores[rel_type]["tp"] + scores[rel_type]["fp"],
147 | )
148 | )
149 |
150 | return scores
151 |
--------------------------------------------------------------------------------
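
A toy call to re_score in "boundaries" mode (the relations are invented; the "type" field must be 1 to match the single relation type scored above, and importing from LiLTfinetune carries the same tag.txt caveat as earlier).

# Illustrative: one sentence, one correct and one spurious predicted relation.
from LiLTfinetune.evaluation import re_score

gt = [[{"head": (0, 2), "tail": (3, 5), "type": 1}]]
pred = [[{"head": (0, 2), "tail": (3, 5), "type": 1},
         {"head": (6, 8), "tail": (3, 5), "type": 1}]]

scores = re_score(pred, gt, mode="boundaries")
print(scores["ALL"]["p"], scores["ALL"]["r"], scores["ALL"]["f1"])  # 0.5 1.0 0.666...
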
/LiLTfinetune/models/LiLTRobertaLike/__init__.py:
--------------------------------------------------------------------------------
1 | from .configuration_LiLTRobertaLike import LiLTRobertaLikeConfig
2 | from .modeling_LiLTRobertaLike import LiLTRobertaLikeForRelationExtraction, LiLTRobertaLikeForTokenClassification, LiLTRobertaLikeModel
3 | from .tokenization_LiLTRobertaLike import LiLTRobertaLikeTokenizer
4 | from .tokenization_LiLTRobertaLike_fast import LiLTRobertaLikeTokenizerFast
5 |
--------------------------------------------------------------------------------
/LiLTfinetune/models/LiLTRobertaLike/configuration_LiLTRobertaLike.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from collections import OrderedDict
3 | from typing import Any, List, Mapping, Optional
4 |
5 | from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType
6 |
7 | from transformers.utils import logging
8 | from transformers import RobertaConfig, XLMRobertaConfig
9 |
10 | logger = logging.get_logger(__name__)
11 |
12 | with open('tag.txt', 'r') as tagf:
13 | TAG = tagf.read().lower()
14 | assert TAG == 'monolingual' or TAG == 'multilingual', 'TAG is wrong. It should be monolingual or multilingual.'
15 |
16 | if TAG == 'monolingual':
17 | class LiLTRobertaLikeConfig(RobertaConfig):
18 | model_type = "liltrobertalike"
19 |
20 | def __init__(
21 | self,
22 | channel_shrink_ratio=4,
23 | max_2d_position_embeddings=1024,
24 | **kwargs
25 | ):
26 | super().__init__(
27 | **kwargs,
28 | )
29 | self.channel_shrink_ratio = channel_shrink_ratio
30 | self.max_2d_position_embeddings = max_2d_position_embeddings
31 |
32 | elif TAG == 'multilingual':
33 | class LiLTRobertaLikeConfig(XLMRobertaConfig):
34 | model_type = "liltrobertalike"
35 |
36 | def __init__(
37 | self,
38 | channel_shrink_ratio=4,
39 | max_2d_position_embeddings=1024,
40 | **kwargs
41 | ):
42 | super().__init__(
43 | **kwargs,
44 | )
45 | self.channel_shrink_ratio = channel_shrink_ratio
46 | self.max_2d_position_embeddings = max_2d_position_embeddings
47 |
--------------------------------------------------------------------------------
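
A minimal sketch of constructing the config directly (tag.txt must exist before the import, as above; the values shown are the defaults declared in this file).

# Illustrative: the two LiLT-specific fields ride on top of the (XLM-)RoBERTa defaults.
from LiLTfinetune.models.LiLTRobertaLike import LiLTRobertaLikeConfig

config = LiLTRobertaLikeConfig(channel_shrink_ratio=4, max_2d_position_embeddings=1024)
print(config.model_type)                                    # liltrobertalike
print(config.hidden_size // config.channel_shrink_ratio)    # layout-stream width: 768 // 4 = 192
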
/LiLTfinetune/models/LiLTRobertaLike/modeling_LiLTRobertaLike.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import torch.utils.checkpoint
6 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
7 | from transformers.activations import ACT2FN, gelu
8 | from transformers.file_utils import (
9 | add_code_sample_docstrings,
10 | add_start_docstrings,
11 | add_start_docstrings_to_model_forward,
12 | replace_return_docstrings,
13 | )
14 | from transformers.modeling_outputs import (
15 | BaseModelOutputWithPastAndCrossAttentions,
16 | BaseModelOutputWithPoolingAndCrossAttentions,
17 | CausalLMOutputWithCrossAttentions,
18 | MaskedLMOutput,
19 | MultipleChoiceModelOutput,
20 | QuestionAnsweringModelOutput,
21 | SequenceClassifierOutput,
22 | TokenClassifierOutput,
23 | )
24 | from transformers.modeling_utils import (
25 | PreTrainedModel,
26 | apply_chunking_to_forward,
27 | find_pruneable_heads_and_indices,
28 | prune_linear_layer,
29 | )
30 | from transformers.utils import logging
31 | from .configuration_LiLTRobertaLike import LiLTRobertaLikeConfig
32 |
33 | logger = logging.get_logger(__name__)
34 |
35 | class LiLTRobertaLikeTextEmbeddings(nn.Module):
36 | def __init__(self, config):
37 | super().__init__()
38 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
39 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
40 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
41 |
42 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
43 | # any TensorFlow checkpoint file
44 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
45 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
46 |
47 | # position_ids (1, len position emb) is contiguous in memory and exported when serialized
48 | self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
49 | self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
50 |
51 | # End copy
52 | self.padding_idx = config.pad_token_id
53 | self.position_embeddings = nn.Embedding(
54 | config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
55 | )
56 |
57 | def forward(
58 | self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
59 | ):
60 | if position_ids is None:
61 | if input_ids is not None:
62 | # Create the position ids from the input token ids. Any padded tokens remain padded.
63 | position_ids = create_position_ids_from_input_ids(
64 | input_ids, self.padding_idx, past_key_values_length
65 | ).to(input_ids.device)
66 | else:
67 | position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
68 |
69 | if input_ids is not None:
70 | input_shape = input_ids.size()
71 | else:
72 | input_shape = inputs_embeds.size()[:-1]
73 |
74 | if token_type_ids is None:
75 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
76 |
77 | if inputs_embeds is None:
78 | inputs_embeds = self.word_embeddings(input_ids)
79 | token_type_embeddings = self.token_type_embeddings(token_type_ids)
80 |
81 | embeddings = inputs_embeds + token_type_embeddings
82 | if self.position_embedding_type == "absolute":
83 | position_embeddings = self.position_embeddings(position_ids)
84 | embeddings += position_embeddings
85 | embeddings = self.LayerNorm(embeddings)
86 | embeddings = self.dropout(embeddings)
87 | return embeddings, position_ids
88 |
89 | def create_position_ids_from_inputs_embeds(self, inputs_embeds):
90 | """
91 | We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
92 | Args:
93 | inputs_embeds: torch.Tensor
94 | Returns: torch.Tensor
95 | """
96 | input_shape = inputs_embeds.size()[:-1]
97 | sequence_length = input_shape[1]
98 |
99 | position_ids = torch.arange(
100 | self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
101 | )
102 | return position_ids.unsqueeze(0).expand(input_shape)
103 |
104 |
105 | class LiLTRobertaLikeLayoutEmbeddings(nn.Module):
106 | def __init__(self, config):
107 | super(LiLTRobertaLikeLayoutEmbeddings, self).__init__()
108 | self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size // 6)
109 | self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size // 6)
110 | self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size // 6)
111 | self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size // 6)
112 |
113 | self.padding_idx = config.pad_token_id
114 | self.box_position_embeddings = nn.Embedding(
115 | config.max_position_embeddings, config.hidden_size//config.channel_shrink_ratio, padding_idx=self.padding_idx
116 | )
117 | self.box_linear_embeddings = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size//config.channel_shrink_ratio)
118 | self.LayerNorm = nn.LayerNorm(config.hidden_size//config.channel_shrink_ratio, eps=config.layer_norm_eps)
119 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
120 |
121 | def forward(
122 | self,
123 | bbox=None,
124 | position_ids=None,
125 | ):
126 |
127 | try:
128 | left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
129 | upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
130 | right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
131 | lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
132 | except IndexError as e:
133 |             raise IndexError("The :obj:`bbox` coordinate values should be within the 0-1000 range.") from e
134 |
135 | h_position_embeddings = self.h_position_embeddings(bbox[:, :, 3] - bbox[:, :, 1])
136 | w_position_embeddings = self.w_position_embeddings(bbox[:, :, 2] - bbox[:, :, 0])
137 |
138 | spatial_position_embeddings = torch.cat(
139 | [
140 | left_position_embeddings,
141 | upper_position_embeddings,
142 | right_position_embeddings,
143 | lower_position_embeddings,
144 | h_position_embeddings,
145 | w_position_embeddings,
146 | ],
147 | dim=-1,
148 | )
149 | spatial_position_embeddings = self.box_linear_embeddings(spatial_position_embeddings)
150 | box_position_embeddings = self.box_position_embeddings(position_ids)
151 |
152 | spatial_position_embeddings = spatial_position_embeddings + box_position_embeddings
153 |
154 | spatial_position_embeddings = self.LayerNorm(spatial_position_embeddings)
155 | spatial_position_embeddings = self.dropout(spatial_position_embeddings)
156 |
157 | return spatial_position_embeddings
158 |
159 |
160 | class LiLTRobertaLikeSelfAttention(nn.Module):
161 | def __init__(self, config):
162 | super().__init__()
163 | if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
164 | raise ValueError(
165 | f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
166 | f"heads ({config.num_attention_heads})"
167 | )
168 |
169 | self.num_attention_heads = config.num_attention_heads
170 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
171 | self.all_head_size = self.num_attention_heads * self.attention_head_size
172 |
173 | self.query = nn.Linear(config.hidden_size, self.all_head_size)
174 | self.key = nn.Linear(config.hidden_size, self.all_head_size)
175 | self.value = nn.Linear(config.hidden_size, self.all_head_size)
176 |
177 | self.layout_query = nn.Linear(config.hidden_size // config.channel_shrink_ratio, self.all_head_size // config.channel_shrink_ratio)
178 | self.layout_key = nn.Linear(config.hidden_size // config.channel_shrink_ratio, self.all_head_size // config.channel_shrink_ratio)
179 | self.layout_value = nn.Linear(config.hidden_size // config.channel_shrink_ratio, self.all_head_size // config.channel_shrink_ratio)
180 |
181 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
182 | self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
183 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
184 | self.max_position_embeddings = config.max_position_embeddings
185 | self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
186 |
187 | self.is_decoder = config.is_decoder
188 | self.channel_shrink_ratio = config.channel_shrink_ratio
189 |
190 | def transpose_for_scores(self, x, r=1):
191 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size//r)
192 | x = x.view(*new_x_shape)
193 | return x.permute(0, 2, 1, 3)
194 |
195 | def forward(
196 | self,
197 | hidden_states,
198 | layout_inputs,
199 | attention_mask=None,
200 | head_mask=None,
201 | encoder_hidden_states=None,
202 | encoder_attention_mask=None,
203 | past_key_value=None,
204 | output_attentions=False,
205 | ):
206 |
207 | layout_value_layer = self.transpose_for_scores(self.layout_value(layout_inputs), r=self.channel_shrink_ratio)
208 | layout_key_layer = self.transpose_for_scores(self.layout_key(layout_inputs), r=self.channel_shrink_ratio)
209 | layout_query_layer = self.transpose_for_scores(self.layout_query(layout_inputs), r=self.channel_shrink_ratio)
210 |
211 | mixed_query_layer = self.query(hidden_states)
212 |
213 | # If this is instantiated as a cross-attention module, the keys
214 | # and values come from an encoder; the attention mask needs to be
215 | # such that the encoder's padding tokens are not attended to.
216 | is_cross_attention = encoder_hidden_states is not None
217 |
218 | if is_cross_attention and past_key_value is not None:
219 | # reuse k,v, cross_attentions
220 | key_layer = past_key_value[0]
221 | value_layer = past_key_value[1]
222 | attention_mask = encoder_attention_mask
223 | elif is_cross_attention:
224 | key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
225 | value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
226 | attention_mask = encoder_attention_mask
227 | elif past_key_value is not None:
228 | key_layer = self.transpose_for_scores(self.key(hidden_states))
229 | value_layer = self.transpose_for_scores(self.value(hidden_states))
230 | key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
231 | value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
232 | else:
233 | key_layer = self.transpose_for_scores(self.key(hidden_states))
234 | value_layer = self.transpose_for_scores(self.value(hidden_states))
235 |
236 | query_layer = self.transpose_for_scores(mixed_query_layer)
237 |
238 | if self.is_decoder:
239 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
240 | # Further calls to cross_attention layer can then reuse all cross-attention
241 | # key/value_states (first "if" case)
242 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
243 | # all previous decoder key/value_states. Further calls to uni-directional self-attention
244 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
245 | # if encoder bi-directional self-attention `past_key_value` is always `None`
246 | past_key_value = (key_layer, value_layer)
247 |
248 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
249 | layout_attention_scores = torch.matmul(layout_query_layer, layout_key_layer.transpose(-1, -2))
250 |
251 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
252 | seq_length = hidden_states.size()[1]
253 | position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
254 | position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
255 | distance = position_ids_l - position_ids_r
256 | positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
257 | positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
258 |
259 | if self.position_embedding_type == "relative_key":
260 | relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
261 | attention_scores = attention_scores + relative_position_scores
262 | elif self.position_embedding_type == "relative_key_query":
263 | relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
264 | relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
265 | attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
266 |
267 | tmp_attention_scores = attention_scores / math.sqrt(self.attention_head_size)
268 | tmp_layout_attention_scores = layout_attention_scores / math.sqrt(self.attention_head_size//self.channel_shrink_ratio)
269 | attention_scores = tmp_attention_scores + tmp_layout_attention_scores
270 | layout_attention_scores = tmp_layout_attention_scores + tmp_attention_scores
271 |
272 | if attention_mask is not None:
273 |                 # Apply the attention mask (precomputed for all layers in the model's forward() function)
274 | layout_attention_scores = layout_attention_scores + attention_mask
275 |
276 | # Normalize the attention scores to probabilities.
277 | layout_attention_probs = nn.Softmax(dim=-1)(layout_attention_scores)
278 |
279 | # This is actually dropping out entire tokens to attend to, which might
280 | # seem a bit unusual, but is taken from the original Transformer paper.
281 | layout_attention_probs = self.dropout(layout_attention_probs)
282 |
283 | # Mask heads if we want to
284 | if head_mask is not None:
285 | layout_attention_probs = layout_attention_probs * head_mask
286 |
287 | layout_context_layer = torch.matmul(layout_attention_probs, layout_value_layer)
288 |
289 | layout_context_layer = layout_context_layer.permute(0, 2, 1, 3).contiguous()
290 | new_context_layer_shape = layout_context_layer.size()[:-2] + (self.all_head_size//self.channel_shrink_ratio,)
291 | layout_context_layer = layout_context_layer.view(*new_context_layer_shape)
292 |
293 | if attention_mask is not None:
294 |             # Apply the attention mask (precomputed for all layers in the model's forward() function)
295 | attention_scores = attention_scores + attention_mask
296 |
297 | # Normalize the attention scores to probabilities.
298 | attention_probs = nn.Softmax(dim=-1)(attention_scores)
299 |
300 | # This is actually dropping out entire tokens to attend to, which might
301 | # seem a bit unusual, but is taken from the original Transformer paper.
302 | attention_probs = self.dropout(attention_probs)
303 |
304 | # Mask heads if we want to
305 | if head_mask is not None:
306 | attention_probs = attention_probs * head_mask
307 |
308 | context_layer = torch.matmul(attention_probs, value_layer)
309 |
310 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
311 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
312 | context_layer = context_layer.view(*new_context_layer_shape)
313 |
314 | outputs = ((context_layer, layout_context_layer), attention_probs) if output_attentions else ((context_layer, layout_context_layer),)
315 |
316 | if self.is_decoder:
317 | outputs = outputs + (past_key_value,)
318 | return outputs
319 |
320 |
321 | class LiLTRobertaLikeSelfOutput(nn.Module):
322 | def __init__(self, config):
323 | super().__init__()
324 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
325 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
326 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
327 |
328 | def forward(self, hidden_states, input_tensor):
329 | hidden_states = self.dense(hidden_states)
330 | hidden_states = self.dropout(hidden_states)
331 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
332 | return hidden_states
333 |
334 |
335 | class LiLTRobertaLikeAttention(nn.Module):
336 | def __init__(self, config):
337 | super().__init__()
338 | self.self = LiLTRobertaLikeSelfAttention(config)
339 | self.output = LiLTRobertaLikeSelfOutput(config)
340 | self.pruned_heads = set()
341 |
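        # Temporarily shrink config.hidden_size so the layout branch's output projection is
        # built with hidden_size // channel_shrink_ratio, then restore the original value.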
342 | ori_hidden_size = config.hidden_size
343 | config.hidden_size = config.hidden_size // config.channel_shrink_ratio
344 | self.layout_output = LiLTRobertaLikeSelfOutput(config)
345 | config.hidden_size = ori_hidden_size
346 |
347 | def prune_heads(self, heads):
348 | if len(heads) == 0:
349 | return
350 | heads, index = find_pruneable_heads_and_indices(
351 | heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
352 | )
353 |
354 | # Prune linear layers
355 | self.self.query = prune_linear_layer(self.self.query, index)
356 | self.self.key = prune_linear_layer(self.self.key, index)
357 | self.self.value = prune_linear_layer(self.self.value, index)
358 | self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
359 |
360 | # Update hyper params and store pruned heads
361 | self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
362 | self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
363 | self.pruned_heads = self.pruned_heads.union(heads)
364 |
365 | def forward(
366 | self,
367 | hidden_states,
368 | layout_inputs,
369 | attention_mask=None,
370 | head_mask=None,
371 | encoder_hidden_states=None,
372 | encoder_attention_mask=None,
373 | past_key_value=None,
374 | output_attentions=False,
375 | ):
376 | self_outputs = self.self(
377 | hidden_states,
378 | layout_inputs,
379 | attention_mask,
380 | head_mask,
381 | encoder_hidden_states,
382 | encoder_attention_mask,
383 | past_key_value,
384 | output_attentions,
385 | )
386 | attention_output = self.output(self_outputs[0][0], hidden_states)
387 | layout_attention_output = self.layout_output(self_outputs[0][1], layout_inputs)
388 |
389 | outputs = ((attention_output, layout_attention_output),) + self_outputs[1:] # add attentions if we output them
390 | return outputs
391 |
392 |
393 | class LiLTRobertaLikeIntermediate(nn.Module):
394 | def __init__(self, config):
395 | super().__init__()
396 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
397 | if isinstance(config.hidden_act, str):
398 | self.intermediate_act_fn = ACT2FN[config.hidden_act]
399 | else:
400 | self.intermediate_act_fn = config.hidden_act
401 |
402 | def forward(self, hidden_states):
403 | hidden_states = self.dense(hidden_states)
404 | hidden_states = self.intermediate_act_fn(hidden_states)
405 | return hidden_states
406 |
407 |
408 | class LiLTRobertaLikeOutput(nn.Module):
409 | def __init__(self, config):
410 | super().__init__()
411 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
412 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
413 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
414 |
415 | def forward(self, hidden_states, input_tensor):
416 | hidden_states = self.dense(hidden_states)
417 | hidden_states = self.dropout(hidden_states)
418 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
419 | return hidden_states
420 |
421 |
422 | class LiLTRobertaLikeLayer(nn.Module):
423 | def __init__(self, config):
424 | super().__init__()
425 | self.chunk_size_feed_forward = config.chunk_size_feed_forward
426 | self.seq_len_dim = 1
427 | self.attention = LiLTRobertaLikeAttention(config)
428 | self.is_decoder = config.is_decoder
429 | self.add_cross_attention = config.add_cross_attention
430 | if self.add_cross_attention:
431 | assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
432 | self.crossattention = LiLTRobertaLikeAttention(config)
433 | self.intermediate = LiLTRobertaLikeIntermediate(config)
434 | self.output = LiLTRobertaLikeOutput(config)
435 |
436 | ori_hidden_size = config.hidden_size
437 | ori_intermediate_size = config.intermediate_size
438 | config.hidden_size = config.hidden_size // config.channel_shrink_ratio
439 | config.intermediate_size = config.intermediate_size // config.channel_shrink_ratio
440 | self.layout_intermediate = LiLTRobertaLikeIntermediate(config)
441 | self.layout_output = LiLTRobertaLikeOutput(config)
442 | config.hidden_size = ori_hidden_size
443 | config.intermediate_size = ori_intermediate_size
444 |
445 | def forward(
446 | self,
447 | hidden_states,
448 | layout_inputs,
449 | attention_mask=None,
450 | head_mask=None,
451 | encoder_hidden_states=None,
452 | encoder_attention_mask=None,
453 | past_key_value=None,
454 | output_attentions=False,
455 | ):
456 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
457 | self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
458 | self_attention_outputs = self.attention(
459 | hidden_states,
460 | layout_inputs,
461 | attention_mask,
462 | head_mask,
463 | output_attentions=output_attentions,
464 | past_key_value=self_attn_past_key_value,
465 | )
466 | attention_output = self_attention_outputs[0][0]
467 | layout_attention_output = self_attention_outputs[0][1]
468 |
469 | # if decoder, the last output is tuple of self-attn cache
470 | if self.is_decoder:
471 | outputs = self_attention_outputs[1:-1]
472 | present_key_value = self_attention_outputs[-1]
473 | else:
474 | outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
475 |
476 | cross_attn_present_key_value = None
477 | if self.is_decoder and encoder_hidden_states is not None:
478 | assert hasattr(
479 | self, "crossattention"
480 | ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
481 |
482 | # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
483 | cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
484 |             cross_attention_outputs = self.crossattention(
485 |                 attention_output, layout_attention_output,
486 |                 attention_mask,
487 |                 head_mask,
488 |                 encoder_hidden_states,
489 |                 encoder_attention_mask,
490 |                 cross_attn_past_key_value,
491 |                 output_attentions,
492 |             )
493 |             attention_output, layout_attention_output = cross_attention_outputs[0]
494 | outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
495 |
496 | # add cross-attn cache to positions 3,4 of present_key_value tuple
497 | cross_attn_present_key_value = cross_attention_outputs[-1]
498 | present_key_value = present_key_value + cross_attn_present_key_value
499 |
500 | layer_output = apply_chunking_to_forward(
501 | self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
502 | )
503 |
504 | layout_layer_output = apply_chunking_to_forward(
505 | self.layout_feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, layout_attention_output
506 | )
507 |
508 | outputs = ((layer_output, layout_layer_output),) + outputs
509 |
510 | # if decoder, return the attn key/values as the last output
511 | if self.is_decoder:
512 | outputs = outputs + (present_key_value,)
513 |
514 | return outputs
515 |
516 | def feed_forward_chunk(self, attention_output):
517 | intermediate_output = self.intermediate(attention_output)
518 | layer_output = self.output(intermediate_output, attention_output)
519 | return layer_output
520 |
521 | def layout_feed_forward_chunk(self, attention_output):
522 | intermediate_output = self.layout_intermediate(attention_output)
523 | layer_output = self.layout_output(intermediate_output, attention_output)
524 | return layer_output
525 |
526 |
527 | class LiLTRobertaLikeEncoder(nn.Module):
528 | def __init__(self, config):
529 | super().__init__()
530 | self.config = config
531 | self.layer = nn.ModuleList([LiLTRobertaLikeLayer(config) for _ in range(config.num_hidden_layers)])
532 |
533 | def forward(
534 | self,
535 | hidden_states,
536 | layout_inputs,
537 | attention_mask=None,
538 | head_mask=None,
539 | encoder_hidden_states=None,
540 | encoder_attention_mask=None,
541 | past_key_values=None,
542 | use_cache=None,
543 | output_attentions=False,
544 | output_hidden_states=False,
545 | return_dict=True,
546 | ):
547 | all_hidden_states = () if output_hidden_states else None
548 | all_self_attentions = () if output_attentions else None
549 | all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
550 |
551 | next_decoder_cache = () if use_cache else None
552 | for i, layer_module in enumerate(self.layer):
553 | if output_hidden_states:
554 | all_hidden_states = all_hidden_states + (hidden_states,)
555 |
556 | layer_head_mask = head_mask[i] if head_mask is not None else None
557 | past_key_value = past_key_values[i] if past_key_values is not None else None
558 |
559 | if getattr(self.config, "gradient_checkpointing", False) and self.training:
560 |
561 | if use_cache:
562 | logger.warning(
563 | "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
564 | "`use_cache=False`..."
565 | )
566 | use_cache = False
567 |
568 | def create_custom_forward(module):
569 | def custom_forward(*inputs):
570 | return module(*inputs, past_key_value, output_attentions)
571 |
572 | return custom_forward
573 |
574 | layer_outputs = torch.utils.checkpoint.checkpoint(
575 | create_custom_forward(layer_module),
576 | hidden_states,
577 | layout_inputs,
578 | attention_mask,
579 | layer_head_mask,
580 | encoder_hidden_states,
581 | encoder_attention_mask,
582 | )
583 |
584 | else:
585 | layer_outputs = layer_module(
586 | hidden_states,
587 | layout_inputs,
588 | attention_mask,
589 | layer_head_mask,
590 | encoder_hidden_states,
591 | encoder_attention_mask,
592 | past_key_value,
593 | output_attentions,
594 | )
595 |
596 | hidden_states = layer_outputs[0][0]
597 | layout_inputs = layer_outputs[0][1]
598 |
599 | if use_cache:
600 | next_decoder_cache += (layer_outputs[-1],)
601 | if output_attentions:
602 | all_self_attentions = all_self_attentions + (layer_outputs[1],)
603 | if self.config.add_cross_attention:
604 | all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
605 |
606 | if output_hidden_states:
607 | all_hidden_states = all_hidden_states + (hidden_states,)
608 |
609 | if not return_dict:
610 | return tuple(
611 | v
612 | for v in [
613 | hidden_states,
614 | next_decoder_cache,
615 | all_hidden_states,
616 | all_self_attentions,
617 | all_cross_attentions,
618 | ]
619 | if v is not None
620 | ), layout_inputs
621 |
622 | return BaseModelOutputWithPastAndCrossAttentions(
623 | last_hidden_state=hidden_states,
624 | past_key_values=next_decoder_cache,
625 | hidden_states=all_hidden_states,
626 | attentions=all_self_attentions,
627 | cross_attentions=all_cross_attentions,
628 | ), layout_inputs
629 |
630 |
631 | class LiLTRobertaLikePooler(nn.Module):
632 | def __init__(self, config):
633 | super().__init__()
634 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
635 | self.activation = nn.Tanh()
636 |
637 | def forward(self, hidden_states):
638 | # We "pool" the model by simply taking the hidden state corresponding
639 | # to the first token.
640 | first_token_tensor = hidden_states[:, 0]
641 | pooled_output = self.dense(first_token_tensor)
642 | pooled_output = self.activation(pooled_output)
643 | return pooled_output
644 |
645 |
646 | class LiLTRobertaLikePreTrainedModel(PreTrainedModel):
647 | """
648 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
649 | models.
650 | """
651 | config_class = LiLTRobertaLikeConfig
652 | base_model_prefix = "liltrobertalike"
653 | def _init_weights(self, module):
654 | """Initialize the weights"""
655 | if isinstance(module, nn.Linear):
656 | # Slightly different from the TF version which uses truncated_normal for initialization
657 | # cf https://github.com/pytorch/pytorch/pull/5617
658 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
659 | if module.bias is not None:
660 | module.bias.data.zero_()
661 | elif isinstance(module, nn.Embedding):
662 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
663 | if module.padding_idx is not None:
664 | module.weight.data[module.padding_idx].zero_()
665 | elif isinstance(module, nn.LayerNorm):
666 | module.bias.data.zero_()
667 | module.weight.data.fill_(1.0)
668 |
669 |
670 |
671 | class LiLTRobertaLikeModel(LiLTRobertaLikePreTrainedModel):
672 |
673 | _keys_to_ignore_on_load_missing = [r"position_ids"]
674 |
675 | def __init__(self, config, add_pooling_layer=True):
676 | super().__init__(config)
677 | self.config = config
678 |
679 | self.embeddings = LiLTRobertaLikeTextEmbeddings(config)
680 | self.layout_embeddings = LiLTRobertaLikeLayoutEmbeddings(config)
681 |
682 | self.encoder = LiLTRobertaLikeEncoder(config)
683 |
684 | self.pooler = LiLTRobertaLikePooler(config) if add_pooling_layer else None
685 |
686 | self.init_weights()
687 |
688 | def get_input_embeddings(self):
689 | return self.embeddings.word_embeddings
690 |
691 | def set_input_embeddings(self, value):
692 | self.embeddings.word_embeddings = value
693 |
694 | def _prune_heads(self, heads_to_prune):
695 | """
696 | Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
697 | class PreTrainedModel
698 | """
699 | for layer, heads in heads_to_prune.items():
700 | self.encoder.layer[layer].attention.prune_heads(heads)
701 |
702 | def forward(
703 | self,
704 | input_ids=None,
705 | bbox=None,
706 | attention_mask=None,
707 | token_type_ids=None,
708 | position_ids=None,
709 | head_mask=None,
710 | inputs_embeds=None,
711 | encoder_hidden_states=None,
712 | encoder_attention_mask=None,
713 | past_key_values=None,
714 | use_cache=None,
715 | output_attentions=None,
716 | output_hidden_states=None,
717 | return_dict=None,
718 | ):
719 |
720 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
721 | output_hidden_states = (
722 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
723 | )
724 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
725 |
726 | if self.config.is_decoder:
727 | use_cache = use_cache if use_cache is not None else self.config.use_cache
728 | else:
729 | use_cache = False
730 |
731 | if input_ids is not None and inputs_embeds is not None:
732 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
733 | elif input_ids is not None:
734 | input_shape = input_ids.size()
735 | batch_size, seq_length = input_shape
736 | elif inputs_embeds is not None:
737 | input_shape = inputs_embeds.size()[:-1]
738 | batch_size, seq_length = input_shape
739 | else:
740 | raise ValueError("You have to specify either input_ids or inputs_embeds")
741 |
742 | device = input_ids.device if input_ids is not None else inputs_embeds.device
743 |
744 | # past_key_values_length
745 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
746 |
747 | if attention_mask is None:
748 | attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
749 | if token_type_ids is None:
750 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
751 |
752 | if bbox is None:
753 | bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device)
754 |
755 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
756 | # ourselves in which case we just need to make it broadcastable to all heads.
757 | extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
758 |
759 | # If a 2D or 3D attention mask is provided for the cross-attention
760 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
761 | if self.config.is_decoder and encoder_hidden_states is not None:
762 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
763 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
764 | if encoder_attention_mask is None:
765 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
766 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
767 | else:
768 | encoder_extended_attention_mask = None
769 |
770 | # Prepare head mask if needed
771 | # 1.0 in head_mask indicate we keep the head
772 | # attention_probs has shape bsz x n_heads x N x N
773 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
774 | # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
775 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
776 |
777 | embedding_output, position_ids = self.embeddings(
778 | input_ids=input_ids,
779 | position_ids=position_ids,
780 | token_type_ids=token_type_ids,
781 | inputs_embeds=inputs_embeds,
782 | past_key_values_length=past_key_values_length,
783 | )
784 |
785 | layout_embedding_output = self.layout_embeddings(
786 | bbox=bbox,
787 | position_ids=position_ids,
788 | )
789 |
790 | encoder_outputs, layout_encoder_outputs = self.encoder(
791 | embedding_output,
792 | layout_embedding_output,
793 | attention_mask=extended_attention_mask,
794 | head_mask=head_mask,
795 | encoder_hidden_states=encoder_hidden_states,
796 | encoder_attention_mask=encoder_extended_attention_mask,
797 | past_key_values=past_key_values,
798 | use_cache=use_cache,
799 | output_attentions=output_attentions,
800 | output_hidden_states=output_hidden_states,
801 | return_dict=return_dict,
802 | )
803 |
804 | sequence_output = encoder_outputs[0]
805 | pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
806 |
807 | if not return_dict:
808 | return (sequence_output, pooled_output) + encoder_outputs[1:]
809 |
810 | return BaseModelOutputWithPoolingAndCrossAttentions(
811 | last_hidden_state=sequence_output,
812 | pooler_output=pooled_output,
813 | past_key_values=encoder_outputs.past_key_values,
814 | hidden_states=encoder_outputs.hidden_states,
815 | attentions=encoder_outputs.attentions,
816 | cross_attentions=encoder_outputs.cross_attentions,
817 | ), layout_encoder_outputs
818 |
819 |
820 |
821 | class LiLTRobertaLikeForTokenClassification(LiLTRobertaLikePreTrainedModel):
822 | _keys_to_ignore_on_load_unexpected = [r"pooler"]
823 | _keys_to_ignore_on_load_missing = [r"position_ids"]
824 |
825 | def __init__(self, config):
826 | super().__init__(config)
827 | self.num_labels = config.num_labels
828 |
829 | self.lilt = LiLTRobertaLikeModel(config, add_pooling_layer=False)
830 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
831 |
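        # The classifier consumes the concatenation of the text stream (hidden_size) and the
        # layout stream (hidden_size // channel_shrink_ratio) produced by the encoder.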
832 | self.classifier = nn.Linear(config.hidden_size + config.hidden_size//config.channel_shrink_ratio, config.num_labels)
833 |
834 | self.init_weights()
835 |
836 | def forward(
837 | self,
838 | input_ids=None,
839 | bbox=None,
840 | attention_mask=None,
841 | token_type_ids=None,
842 | position_ids=None,
843 | head_mask=None,
844 | inputs_embeds=None,
845 | labels=None,
846 | output_attentions=None,
847 | output_hidden_states=None,
848 | return_dict=None,
849 | ):
850 | r"""
851 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
852 | Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
853 | 1]``.
854 | """
855 |
856 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
857 |
858 | outputs, layout_outputs = self.lilt(
859 | input_ids,
860 | bbox=bbox,
861 | attention_mask=attention_mask,
862 | token_type_ids=token_type_ids,
863 | position_ids=position_ids,
864 | head_mask=head_mask,
865 | inputs_embeds=inputs_embeds,
866 | output_attentions=output_attentions,
867 | output_hidden_states=output_hidden_states,
868 | return_dict=return_dict,
869 | )
870 |
871 | sequence_output = outputs[0]
872 |
873 | sequence_output = torch.cat([sequence_output, layout_outputs], -1)
874 | sequence_output = self.dropout(sequence_output)
875 | logits = self.classifier(sequence_output)
876 |
877 | loss = None
878 | if labels is not None:
879 | loss_fct = CrossEntropyLoss()
880 | # Only keep active parts of the loss
881 | if attention_mask is not None:
882 | active_loss = attention_mask.view(-1) == 1
883 | active_logits = logits.view(-1, self.num_labels)
884 | active_labels = torch.where(
885 | active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
886 | )
887 | loss = loss_fct(active_logits, active_labels)
888 | else:
889 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
890 |
891 | if not return_dict:
892 | output = (logits,) + outputs[2:]
893 | return ((loss,) + output) if loss is not None else output
894 |
895 | return TokenClassifierOutput(
896 | loss=loss,
897 | logits=logits,
898 | hidden_states=outputs.hidden_states,
899 | attentions=outputs.attentions,
900 | )
901 |
902 |
903 | from dataclasses import dataclass
904 | from typing import Dict, Optional, Tuple
905 | from transformers.file_utils import ModelOutput
906 | from ...modules.decoders.re import REDecoder
907 | from ...utils import ReOutput
908 |
909 | class LiLTRobertaLikeForRelationExtraction(LiLTRobertaLikePreTrainedModel):
910 | _keys_to_ignore_on_load_unexpected = [r"pooler"]
911 | _keys_to_ignore_on_load_missing = [r"position_ids"]
912 | def __init__(self, config):
913 | super().__init__(config)
914 |
915 | self.lilt = LiLTRobertaLikeModel(config, add_pooling_layer=False)
916 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
917 | self.extractor = REDecoder(config, config.hidden_size + config.hidden_size // config.channel_shrink_ratio)
918 | self.init_weights()
919 |
920 | def forward(
921 | self,
922 | input_ids=None,
923 | bbox=None,
924 | attention_mask=None,
925 | token_type_ids=None,
926 | position_ids=None,
927 | head_mask=None,
928 | inputs_embeds=None,
929 | labels=None,
930 | output_attentions=None,
931 | output_hidden_states=None,
932 | return_dict=None,
933 | entities=None,
934 | relations=None,
935 | ):
936 |
937 | outputs, layout_outputs = self.lilt(
938 | input_ids,
939 | bbox=bbox,
940 | attention_mask=attention_mask,
941 | token_type_ids=token_type_ids,
942 | position_ids=position_ids,
943 | head_mask=head_mask,
944 | inputs_embeds=inputs_embeds,
945 | output_attentions=output_attentions,
946 | output_hidden_states=output_hidden_states,
947 | return_dict=return_dict,
948 | )
949 |
950 | seq_length = input_ids.size(1)
951 | sequence_output = outputs[0]
952 | sequence_output = torch.cat([sequence_output, layout_outputs], -1)
953 |
954 | sequence_output = self.dropout(sequence_output)
955 | loss, pred_relations = self.extractor(sequence_output, entities, relations)
956 |
957 | return ReOutput(
958 | loss=loss,
959 | entities=entities,
960 | relations=relations,
961 | pred_relations=pred_relations,
962 | )
963 |
964 |
965 | def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
966 | """
967 | Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
968 | are ignored. This is modified from fairseq's `utils.make_positions`.
969 | Args:
970 |         input_ids (torch.Tensor), padding_idx (int), past_key_values_length (int, defaults to 0)
971 | Returns: torch.Tensor
972 | """
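    # Illustrative example (hypothetical values): with padding_idx = 1 and
    # input_ids = [[0, 31414, 232, 2, 1, 1]], the padding mask is [[1, 1, 1, 1, 0, 0]],
    # so the returned position ids are [[2, 3, 4, 5, 1, 1]]: real tokens count up from
    # padding_idx + 1 while padding positions stay at padding_idx.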
973 | # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
974 | mask = input_ids.ne(padding_idx).int()
975 | incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
976 | return incremental_indices.long() + padding_idx
--------------------------------------------------------------------------------
/LiLTfinetune/models/LiLTRobertaLike/tokenization_LiLTRobertaLike.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | from transformers import RobertaTokenizer, XLMRobertaTokenizer
4 | from transformers.utils import logging
5 |
6 | logger = logging.get_logger(__name__)
7 |
8 | SPIECE_UNDERLINE = "▁"
9 |
10 | VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}
11 |
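# `tag.txt` is written by the example scripts before this module is imported
# (e.g. examples/run_funsd.py writes 'monolingual'); it selects whether the tokenizer
# is built on RobertaTokenizer or XLMRobertaTokenizer.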
12 | with open('tag.txt', 'r') as tagf:
13 | TAG = tagf.read().lower()
14 | assert TAG == 'monolingual' or TAG == 'multilingual', 'TAG is wrong. It should be monolingual or multilingual.'
15 |
16 | if TAG == 'monolingual':
17 | class LiLTRobertaLikeTokenizer(RobertaTokenizer):
18 | vocab_files_names = VOCAB_FILES_NAMES
19 | max_model_input_sizes = {"lilt-roberta-base": 512,}
20 | model_input_names = ["input_ids", "attention_mask"]
21 |
22 | def __init__(self, model_max_length=512, **kwargs):
23 | super().__init__(model_max_length=model_max_length, **kwargs)
24 |
25 | elif TAG == 'multilingual':
26 | class LiLTRobertaLikeTokenizer(XLMRobertaTokenizer):
27 | vocab_files_names = VOCAB_FILES_NAMES
28 | max_model_input_sizes = {"lilt-infoxlm-base": 512,}
29 | model_input_names = ["input_ids", "attention_mask"]
30 |
31 | def __init__(self, model_max_length=512, **kwargs):
32 | super().__init__(model_max_length=model_max_length, **kwargs)
33 |
--------------------------------------------------------------------------------
/LiLTfinetune/models/LiLTRobertaLike/tokenization_LiLTRobertaLike_fast.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from transformers import RobertaTokenizerFast, XLMRobertaTokenizerFast
3 | from transformers.file_utils import is_sentencepiece_available
4 | from transformers.utils import logging
5 |
6 |
7 | if is_sentencepiece_available():
8 | from .tokenization_LiLTRobertaLike import LiLTRobertaLikeTokenizer
9 | else:
10 | LiLTRobertaLikeTokenizer = None
11 |
12 |
13 | logger = logging.get_logger(__name__)
14 |
15 | VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"}
16 |
17 | with open('tag.txt', 'r') as tagf:
18 | TAG = tagf.read().lower()
19 | assert TAG == 'monolingual' or TAG == 'multilingual', 'TAG is wrong. It should be monolingual or multilingual.'
20 |
21 | if TAG == 'monolingual':
22 | class LiLTRobertaLikeTokenizerFast(RobertaTokenizerFast):
23 |
24 | vocab_files_names = VOCAB_FILES_NAMES
25 | max_model_input_sizes = {"lilt-roberta-base": 512,}
26 | model_input_names = ["input_ids", "attention_mask"]
27 | slow_tokenizer_class = LiLTRobertaLikeTokenizer
28 |
29 | def __init__(self, model_max_length=512, **kwargs):
30 | super().__init__(model_max_length=model_max_length, **kwargs)
31 |
32 | elif TAG == 'multilingual':
33 | class LiLTRobertaLikeTokenizerFast(XLMRobertaTokenizerFast):
34 |
35 | vocab_files_names = VOCAB_FILES_NAMES
36 | max_model_input_sizes = {"lilt-infoxlm-base": 512,}
37 | model_input_names = ["input_ids", "attention_mask"]
38 | slow_tokenizer_class = LiLTRobertaLikeTokenizer
39 |
40 | def __init__(self, model_max_length=512, **kwargs):
41 | super().__init__(model_max_length=model_max_length, **kwargs)
42 |
--------------------------------------------------------------------------------
/LiLTfinetune/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jpWang/LiLT/a39930b2c5425da7250f0dde04252cf60ec3b1b7/LiLTfinetune/models/__init__.py
--------------------------------------------------------------------------------
/LiLTfinetune/models/model_args.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Optional
3 |
4 |
5 | @dataclass
6 | class ModelArguments:
7 | """
8 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
9 | """
10 |
11 | model_name_or_path: str = field(
12 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
13 | )
14 | config_name: Optional[str] = field(
15 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
16 | )
17 | tokenizer_name: Optional[str] = field(
18 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
19 | )
20 | cache_dir: Optional[str] = field(
21 | default=None,
22 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
23 | )
24 | model_revision: str = field(
25 | default="main",
26 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
27 | )
28 | use_auth_token: bool = field(
29 | default=False,
30 | metadata={
31 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
32 | "with private models)."
33 | },
34 | )
35 |
--------------------------------------------------------------------------------
/LiLTfinetune/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jpWang/LiLT/a39930b2c5425da7250f0dde04252cf60ec3b1b7/LiLTfinetune/modules/__init__.py
--------------------------------------------------------------------------------
/LiLTfinetune/modules/decoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jpWang/LiLT/a39930b2c5425da7250f0dde04252cf60ec3b1b7/LiLTfinetune/modules/decoders/__init__.py
--------------------------------------------------------------------------------
/LiLTfinetune/modules/decoders/re.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import CrossEntropyLoss
6 |
7 |
8 | class BiaffineAttention(torch.nn.Module):
9 | """Implements a biaffine attention operator for binary relation classification.
10 |
11 | PyTorch implementation of the biaffine attention operator from "End-to-end neural relation
12 | extraction using deep biaffine attention" (https://arxiv.org/abs/1812.11275) which can be used
13 | as a classifier for binary relation classification.
14 |
15 | Args:
16 | in_features (int): The size of the feature dimension of the inputs.
17 | out_features (int): The size of the feature dimension of the output.
18 |
19 | Shape:
20 | - x_1: `(N, *, in_features)` where `N` is the batch dimension and `*` means any number of
21 |           additional dimensions.
22 | - x_2: `(N, *, in_features)`, where `N` is the batch dimension and `*` means any number of
23 | additional dimensions.
24 | - Output: `(N, *, out_features)`, where `N` is the batch dimension and `*` means any number
25 | of additional dimensions.
26 |
27 | Examples:
28 | >>> batch_size, in_features, out_features = 32, 100, 4
29 | >>> biaffine_attention = BiaffineAttention(in_features, out_features)
30 | >>> x_1 = torch.randn(batch_size, in_features)
31 | >>> x_2 = torch.randn(batch_size, in_features)
32 | >>> output = biaffine_attention(x_1, x_2)
33 | >>> print(output.size())
34 | torch.Size([32, 4])
35 | """
36 |
37 | def __init__(self, in_features, out_features):
38 | super(BiaffineAttention, self).__init__()
39 |
40 | self.in_features = in_features
41 | self.out_features = out_features
42 |
43 | self.bilinear = torch.nn.Bilinear(in_features, in_features, out_features, bias=False)
44 | self.linear = torch.nn.Linear(2 * in_features, out_features, bias=True)
45 |
46 | self.reset_parameters()
47 |
48 | def forward(self, x_1, x_2):
49 | return self.bilinear(x_1, x_2) + self.linear(torch.cat((x_1, x_2), dim=-1))
50 |
51 | def reset_parameters(self):
52 | self.bilinear.reset_parameters()
53 | self.linear.reset_parameters()
54 |
55 |
56 | class REDecoder(nn.Module):
57 | def __init__(self, config, input_size):
58 | super().__init__()
59 | self.entity_emb = nn.Embedding(3, input_size, scale_grad_by_freq=True)
60 | projection = nn.Sequential(
61 | nn.Linear(input_size * 2, config.hidden_size),
62 | nn.ReLU(),
63 | nn.Dropout(config.hidden_dropout_prob),
64 | nn.Linear(config.hidden_size, config.hidden_size // 2),
65 | nn.ReLU(),
66 | nn.Dropout(config.hidden_dropout_prob),
67 | )
68 | self.ffnn_head = copy.deepcopy(projection)
69 | self.ffnn_tail = copy.deepcopy(projection)
70 | self.rel_classifier = BiaffineAttention(config.hidden_size // 2, 2)
71 | self.loss_fct = CrossEntropyLoss()
72 |
73 | def build_relation(self, relations, entities):
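        # For each document, every (i, j) entity pair whose head has label 1 and whose tail
        # has label 2 is a candidate relation; the gold (head, tail) pairs are kept as
        # positives and all remaining candidates become negatives.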
74 | batch_size = len(relations)
75 | new_relations = []
76 | for b in range(batch_size):
77 | if len(entities[b]["start"]) <= 2:
78 | entities[b] = {"end": [1, 1], "label": [0, 0], "start": [0, 0]}
79 | all_possible_relations = set(
80 | [
81 | (i, j)
82 | for i in range(len(entities[b]["label"]))
83 | for j in range(len(entities[b]["label"]))
84 | if entities[b]["label"][i] == 1 and entities[b]["label"][j] == 2
85 | ]
86 | )
87 | if len(all_possible_relations) == 0:
88 | all_possible_relations = set([(0, 1)])
89 | positive_relations = set(list(zip(relations[b]["head"], relations[b]["tail"])))
90 | negative_relations = all_possible_relations - positive_relations
91 | positive_relations = set([i for i in positive_relations if i in all_possible_relations])
92 | reordered_relations = list(positive_relations) + list(negative_relations)
93 | relation_per_doc = {"head": [], "tail": [], "label": []}
94 | relation_per_doc["head"] = [i[0] for i in reordered_relations]
95 | relation_per_doc["tail"] = [i[1] for i in reordered_relations]
96 | relation_per_doc["label"] = [1] * len(positive_relations) + [0] * (
97 | len(reordered_relations) - len(positive_relations)
98 | )
99 | assert len(relation_per_doc["head"]) != 0
100 | new_relations.append(relation_per_doc)
101 | return new_relations, entities
102 |
103 | def get_predicted_relations(self, logits, relations, entities):
104 | pred_relations = []
105 | for i, pred_label in enumerate(logits.argmax(-1)):
106 | if pred_label != 1:
107 | continue
108 | rel = {}
109 | rel["head_id"] = relations["head"][i]
110 | rel["head"] = (entities["start"][rel["head_id"]], entities["end"][rel["head_id"]])
111 | rel["head_type"] = entities["label"][rel["head_id"]]
112 |
113 | rel["tail_id"] = relations["tail"][i]
114 | rel["tail"] = (entities["start"][rel["tail_id"]], entities["end"][rel["tail_id"]])
115 | rel["tail_type"] = entities["label"][rel["tail_id"]]
116 | rel["type"] = 1
117 | pred_relations.append(rel)
118 | return pred_relations
119 |
120 | def forward(self, hidden_states, entities, relations):
121 | batch_size, max_n_words, context_dim = hidden_states.size()
122 | device = hidden_states.device
123 | relations, entities = self.build_relation(relations, entities)
124 | loss = 0
125 | all_pred_relations = []
126 | all_logits = []
127 | all_labels = []
128 |
129 | for b in range(batch_size):
130 | head_entities = torch.tensor(relations[b]["head"], device=device)
131 | tail_entities = torch.tensor(relations[b]["tail"], device=device)
132 | relation_labels = torch.tensor(relations[b]["label"], device=device)
133 | entities_start_index = torch.tensor(entities[b]["start"], device=device)
134 | entities_labels = torch.tensor(entities[b]["label"], device=device)
135 | head_index = entities_start_index[head_entities]
136 | head_label = entities_labels[head_entities]
137 | head_label_repr = self.entity_emb(head_label)
138 |
139 | tail_index = entities_start_index[tail_entities]
140 | tail_label = entities_labels[tail_entities]
141 | tail_label_repr = self.entity_emb(tail_label)
142 |
143 | head_repr = torch.cat(
144 | (hidden_states[b][head_index], head_label_repr),
145 | dim=-1,
146 | )
147 | tail_repr = torch.cat(
148 | (hidden_states[b][tail_index], tail_label_repr),
149 | dim=-1,
150 | )
151 | heads = self.ffnn_head(head_repr)
152 | tails = self.ffnn_tail(tail_repr)
153 | logits = self.rel_classifier(heads, tails)
154 | pred_relations = self.get_predicted_relations(logits, relations[b], entities[b])
155 | all_pred_relations.append(pred_relations)
156 | all_logits.append(logits)
157 | all_labels.append(relation_labels)
158 | all_logits = torch.cat(all_logits, 0)
159 | all_labels = torch.cat(all_labels, 0)
160 | loss = self.loss_fct(all_logits, all_labels)
161 | return loss, all_pred_relations
--------------------------------------------------------------------------------
/LiLTfinetune/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | from .funsd_trainer import FunsdTrainer
2 | from .xfun_trainer import XfunReTrainer, XfunSerTrainer
3 |
--------------------------------------------------------------------------------
/LiLTfinetune/trainers/funsd_trainer.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Union
2 |
3 | import torch
4 |
5 | from transformers import Trainer
6 |
7 |
8 | class FunsdTrainer(Trainer):
9 | def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
10 | """
11 | Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
12 | handling potential state.
13 | """
14 | for k, v in inputs.items():
15 | if hasattr(v, "to") and hasattr(v, "device"):
16 | inputs[k] = v.to(self.args.device)
17 |
18 | if self.args.past_index >= 0 and self._past is not None:
19 | inputs["mems"] = self._past
20 |
21 | return inputs
22 |
--------------------------------------------------------------------------------
/LiLTfinetune/trainers/xfun_trainer.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import time
3 | from typing import Any, Dict, List, Optional, Tuple, Union
4 |
5 | import torch
6 | from packaging import version
7 | from torch import nn
8 | from torch.utils.data import DataLoader, Dataset
9 |
10 | from transformers.utils import logging
11 | from transformers.file_utils import is_sagemaker_mp_enabled
12 | from transformers.trainer_utils import EvalPrediction, PredictionOutput, speed_metrics, ShardedDDPOption
13 | from transformers.trainer_pt_utils import get_parameter_names
14 | from transformers.optimization import Adafactor, AdamW, get_scheduler
15 |
16 | from .funsd_trainer import FunsdTrainer
17 |
18 |
19 | if version.parse(torch.__version__) >= version.parse("1.6"):
20 | _is_native_amp_available = True
21 | from torch.cuda.amp import autocast
22 |
23 | logger = logging.get_logger(__name__)
24 |
25 |
26 | class XfunSerTrainer(FunsdTrainer):
27 | pass
28 |
29 |
30 | class XfunReTrainer(FunsdTrainer):
31 | def __init__(self, **kwargs):
32 | super().__init__(**kwargs)
33 | self.label_names.append("relations")
34 |
35 | def prediction_step(
36 | self,
37 | model: nn.Module,
38 | inputs: Dict[str, Union[torch.Tensor, Any]],
39 | prediction_loss_only: bool,
40 | ignore_keys: Optional[List[str]] = None,
41 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
42 | inputs = self._prepare_inputs(inputs)
43 |
44 | with torch.no_grad():
45 | if self.use_amp:
46 | with autocast():
47 | outputs = model(**inputs)
48 | else:
49 | outputs = model(**inputs)
50 | labels = tuple(inputs.get(name) for name in self.label_names)
51 | return outputs, labels
52 |
53 | def prediction_loop(
54 | self,
55 | dataloader: DataLoader,
56 | description: str,
57 | prediction_loss_only: Optional[bool] = None,
58 | ignore_keys: Optional[List[str]] = None,
59 | metric_key_prefix: str = "eval",
60 | ) -> PredictionOutput:
61 | """
62 | Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.
63 |
64 | Works both with or without labels.
65 | """
66 | if not isinstance(dataloader.dataset, collections.abc.Sized):
67 | raise ValueError("dataset must implement __len__")
68 | prediction_loss_only = (
69 | prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
70 | )
71 |
72 | if self.args.deepspeed and not self.args.do_train:
73 | # no harm, but flagging to the user that deepspeed config is ignored for eval
74 | # flagging only for when --do_train wasn't passed as only then it's redundant
75 | logger.info("Detected the deepspeed argument but it will not be used for evaluation")
76 |
77 | model = self._wrap_model(self.model, training=False)
78 |
79 | # if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while
80 | # ``train`` is running, half it first and then put on device
81 | if not self.is_in_train and self.args.fp16_full_eval:
82 | model = model.half().to(self.args.device)
83 |
84 | batch_size = dataloader.batch_size
85 | num_examples = self.num_examples(dataloader)
86 | logger.info("***** Running %s *****", description)
87 | logger.info(" Num examples = %d", num_examples)
88 | logger.info(" Batch size = %d", batch_size)
89 |
90 | model.eval()
91 |
92 | self.callback_handler.eval_dataloader = dataloader
93 |
94 | re_labels = None
95 | pred_relations = None
96 | entities = None
97 | for step, inputs in enumerate(dataloader):
98 | outputs, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
99 | re_labels = labels[1] if re_labels is None else re_labels + labels[1]
100 | pred_relations = (
101 | outputs.pred_relations if pred_relations is None else pred_relations + outputs.pred_relations
102 | )
103 | entities = outputs.entities if entities is None else entities + outputs.entities
104 |
105 | self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)
106 |
107 | gt_relations = []
108 | for b in range(len(re_labels)):
109 | rel_sent = []
110 | for head, tail in zip(re_labels[b]["head"], re_labels[b]["tail"]):
111 | rel = {}
112 | rel["head_id"] = head
113 | rel["head"] = (entities[b]["start"][rel["head_id"]], entities[b]["end"][rel["head_id"]])
114 | rel["head_type"] = entities[b]["label"][rel["head_id"]]
115 |
116 | rel["tail_id"] = tail
117 | rel["tail"] = (entities[b]["start"][rel["tail_id"]], entities[b]["end"][rel["tail_id"]])
118 | rel["tail_type"] = entities[b]["label"][rel["tail_id"]]
119 |
120 | rel["type"] = 1
121 |
122 | rel_sent.append(rel)
123 |
124 | gt_relations.append(rel_sent)
125 |
126 | re_metrics = self.compute_metrics(EvalPrediction(predictions=pred_relations, label_ids=gt_relations))
127 |
128 | re_metrics = {
129 | "precision": re_metrics["ALL"]["p"],
130 | "recall": re_metrics["ALL"]["r"],
131 | "f1": re_metrics["ALL"]["f1"],
132 | }
133 | re_metrics[f"{metric_key_prefix}_loss"] = outputs.loss.mean().item()
134 |
135 | metrics = {}
136 |
137 |         # Prefix all keys with metric_key_prefix + '_'
138 | for key in list(re_metrics.keys()):
139 | if not key.startswith(f"{metric_key_prefix}_"):
140 | metrics[f"{metric_key_prefix}_{key}"] = re_metrics.pop(key)
141 | else:
142 | metrics[f"{key}"] = re_metrics.pop(key)
143 |
144 | return metrics
145 |
146 | def evaluate(
147 | self,
148 | eval_dataset: Optional[Dataset] = None,
149 | ignore_keys: Optional[List[str]] = None,
150 | metric_key_prefix: str = "eval",
151 | ) -> Dict[str, float]:
152 | """
153 | Run evaluation and returns metrics.
154 |
155 | The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
156 | (pass it to the init :obj:`compute_metrics` argument).
157 |
158 | You can also subclass and override this method to inject custom behavior.
159 |
160 | Args:
161 | eval_dataset (:obj:`Dataset`, `optional`):
162 | Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`,
163 | columns not accepted by the ``model.forward()`` method are automatically removed. It must implement the
164 | :obj:`__len__` method.
165 |             ignore_keys (:obj:`List[str]`, `optional`):
166 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when
167 | gathering predictions.
168 | metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
169 | An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
170 | "eval_bleu" if the prefix is "eval" (default)
171 |
172 | Returns:
173 | A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
174 | dictionary also contains the epoch number which comes from the training state.
175 | """
176 | if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
177 | raise ValueError("eval_dataset must implement __len__")
178 |
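        # Temporarily set local_rank to -1 so that get_eval_dataloader builds a
        # non-distributed (sequential) sampler and each process evaluates the full
        # eval set, then restore the distributed rank.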
179 | self.args.local_rank = -1
180 | eval_dataloader = self.get_eval_dataloader(eval_dataset)
181 | self.args.local_rank = torch.distributed.get_rank()
182 |
183 | start_time = time.time()
184 |
185 | metrics = self.prediction_loop(
186 | eval_dataloader,
187 | description="Evaluation",
188 | # No point gathering the predictions if there are no metrics, otherwise we defer to
189 | # self.args.prediction_loss_only
190 | prediction_loss_only=True if self.compute_metrics is None else None,
191 | ignore_keys=ignore_keys,
192 | metric_key_prefix=metric_key_prefix,
193 | )
194 |
195 | n_samples = len(eval_dataset if eval_dataset is not None else self.eval_dataset)
196 | metrics.update(speed_metrics(metric_key_prefix, start_time, n_samples))
197 | self.log(metrics)
198 | self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
199 |
200 | return metrics
201 |
202 | def create_optimizer(self, speedup_r=4.):
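        # Four parameter groups: parameters of the RE extractor head (except the biaffine
        # rel_classifier) use a learning rate scaled by `speedup_r`; weight decay is applied
        # only to non-bias, non-LayerNorm parameters.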
203 | if self.optimizer is None:
204 | decay_parameters = get_parameter_names(self.model, [torch.nn.LayerNorm])
205 | decay_parameters = [name for name in decay_parameters if "bias" not in name]
206 | speedup_parameters = [name for name in get_parameter_names(self.model, []) if 'extractor' in name and 'rel_classifier' not in name]
207 | optimizer_grouped_parameters = [
208 | {
209 | "params": [p for n, p in self.model.named_parameters() if n in decay_parameters and n in speedup_parameters],
210 | "weight_decay": self.args.weight_decay,
211 |                     "lr": self.args.learning_rate * speedup_r,
212 | },
213 | {
214 | "params": [p for n, p in self.model.named_parameters() if n not in decay_parameters and n in speedup_parameters],
215 | "weight_decay": 0.0,
216 |                     "lr": self.args.learning_rate * speedup_r,
217 | },
218 | {
219 | "params": [p for n, p in self.model.named_parameters() if n in decay_parameters and n not in speedup_parameters],
220 | "weight_decay": self.args.weight_decay,
221 | "lr": self.args.learning_rate,
222 | },
223 | {
224 | "params": [p for n, p in self.model.named_parameters() if n not in decay_parameters and n not in speedup_parameters],
225 | "weight_decay": 0.0,
226 | "lr": self.args.learning_rate,
227 | },
228 | ]
229 | optimizer_cls = Adafactor if self.args.adafactor else AdamW
230 | if self.args.adafactor:
231 | optimizer_cls = Adafactor
232 | optimizer_kwargs = {"scale_parameter": False, "relative_step": False}
233 | else:
234 | optimizer_cls = AdamW
235 | optimizer_kwargs = {
236 | "betas": (self.args.adam_beta1, self.args.adam_beta2),
237 | "eps": self.args.adam_epsilon,
238 | }
239 |
240 |         if self.sharded_ddp == ShardedDDPOption.SIMPLE:
241 |             from fairscale.optim import OSS  # fairscale is only needed when sharded DDP is enabled
242 |             self.optimizer = OSS(
243 |                 params=optimizer_grouped_parameters,
244 |                 optim=optimizer_cls,
245 |                 **optimizer_kwargs,
246 |             )
247 |         else:
248 |             self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
249 | 
250 |         if is_sagemaker_mp_enabled():
251 |             import smdistributed.modelparallel.torch as smp
252 |             self.optimizer = smp.DistributedOptimizer(self.optimizer)
253 | 
--------------------------------------------------------------------------------
/LiLTfinetune/utils.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Dict, Optional, Tuple
3 |
4 | import torch
5 |
6 | from transformers.file_utils import ModelOutput
7 |
8 |
9 | @dataclass
10 | class ReOutput(ModelOutput):
11 | loss: Optional[torch.FloatTensor] = None
12 | logits: torch.FloatTensor = None
13 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None
14 | attentions: Optional[Tuple[torch.FloatTensor]] = None
15 | entities: Optional[Dict] = None
16 | relations: Optional[Dict] = None
17 | pred_relations: Optional[Dict] = None
18 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: quality style
2 | check_dirs := LiLTfinetune examples
3 | # Check that source code meets quality standards
4 |
5 | quality:
6 | black --check $(check_dirs)
7 | isort --check-only $(check_dirs)
8 | flake8 $(check_dirs)
9 |
10 | # Format source code automatically
11 |
12 | style:
13 | black $(check_dirs)
14 | isort $(check_dirs)
15 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # What's New
2 |
3 | [2022/10] LiLT has been added to 🤗[huggingface/transformers](https://github.com/huggingface/transformers); see the model documentation [HERE](https://huggingface.co/docs/transformers/main/model_doc/lilt).
4 |
5 | [2022/03] Initial model and code release.
6 |
7 | # LiLT (ACL 2022)
8 |
9 | This is the official PyTorch implementation of the ACL 2022 paper: "LiLT: A Simple yet Effective Language-Independent Layout Transformer for Structured Document Understanding". [[official](https://aclanthology.org/2022.acl-long.534/)] [[arXiv](https://arxiv.org/abs/2202.13669)]
10 |
11 |
12 |
13 | LiLT is pre-trained on visually-rich documents in a single language (English) and can be directly fine-tuned on other languages with the corresponding off-the-shelf monolingual/multilingual pre-trained textual models. We hope the public availability of this work helps document intelligence research.
14 |
15 | ## Installation
16 |
17 | For CUDA 11.X:
18 |
19 | ~~~bash
20 | conda create -n liltfinetune python=3.7
21 | conda activate liltfinetune
22 | conda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=11.0 -c pytorch
23 | python -m pip install detectron2==0.5 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu110/torch1.7/index.html
24 | git clone https://github.com/jpWang/LiLT
25 | cd LiLT
26 | pip install -r requirements.txt
27 | pip install -e .
28 | ~~~
29 |
30 | Or check [Detectron2](https://github.com/facebookresearch/detectron2/releases)/[PyTorch](https://pytorch.org/get-started/previous-versions/) versions and modify the command lines accordingly.
31 |
32 | ## Datasets
33 |
34 | In this repository, we provide the fine-tuning codes for [FUNSD](https://guillaumejaume.github.io/FUNSD/) and [XFUND](https://github.com/doc-analysis/XFUND).
35 |
36 | You can download our **pre-processed data (~1.2GB)** from [**HERE**](https://1drv.ms/u/s!Ahd-h7H5akVZeZQvKieg8g5THV8?e=mBRnxw), and put the unzipped `xfund&funsd/` under `LiLT/`.
37 |
38 | ## Available Checkpoints
39 |
40 | | Model | Language | Size | Download |
41 | | ----------------------------- | --------- | ----- | ------------ |
42 | | `lilt-roberta-en-base` | EN | 293MB | [OneDrive](https://1drv.ms/u/s!Ahd-h7H5akVZfhPVHQQ1tOypA48?e=nraHn3) |
43 | | `lilt-infoxlm-base` | MUL | 846MB | [OneDrive](https://1drv.ms/u/s!Ahd-h7H5akVZfeIhAQ8KHELRvcc?e=WS1P82) |
44 | | `lilt-only-base` | None | 21MB | [OneDrive](https://1drv.ms/u/s!Ahd-h7H5akVZfEIRbCmcWKjhoSM?e=6tMGbe) |
45 |
46 | ## Or Generate Your Own Checkpoint (Optional)
47 |
48 | If you want to combine the pre-trained LiLT with **another language's *RoBERTa***, please download `lilt-only-base` and use `gen_weight_roberta_like.py` to generate your own pre-trained checkpoint.
49 |
50 | **For example,** combine `lilt-only-base` with English `roberta-base`:
51 |
52 | ~~~bash
53 | mkdir roberta-en-base
54 | wget https://huggingface.co/roberta-base/resolve/main/config.json -O roberta-en-base/config.json
55 | wget https://huggingface.co/roberta-base/resolve/main/pytorch_model.bin -O roberta-en-base/pytorch_model.bin
56 | python gen_weight_roberta_like.py \
57 | --lilt lilt-only-base/pytorch_model.bin \
58 | --text roberta-en-base/pytorch_model.bin \
59 | --config roberta-en-base/config.json \
60 | --out lilt-roberta-en-base
61 | ~~~
62 |
63 | Or combine `lilt-only-base` with `microsoft/infoxlm-base`:
64 |
65 | ~~~bash
66 | mkdir infoxlm-base
67 | wget https://huggingface.co/microsoft/infoxlm-base/resolve/main/config.json -O infoxlm-base/config.json
68 | wget https://huggingface.co/microsoft/infoxlm-base/resolve/main/pytorch_model.bin -O infoxlm-base/pytorch_model.bin
69 | python gen_weight_roberta_like.py \
70 | --lilt lilt-only-base/pytorch_model.bin \
71 | --text infoxlm-base/pytorch_model.bin \
72 | --config infoxlm-base/config.json \
73 | --out lilt-infoxlm-base
74 | ~~~
75 |
76 |
77 | ## Fine-tuning
78 |
79 |
80 | ### Semantic Entity Recognition on FUNSD
81 |
82 | ```
83 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 examples/run_funsd.py \
84 | --model_name_or_path lilt-roberta-en-base \
85 | --tokenizer_name roberta-base \
86 | --output_dir ser_funsd_lilt-roberta-en-base \
87 | --do_train \
88 | --do_predict \
89 | --max_steps 2000 \
90 | --per_device_train_batch_size 8 \
91 | --warmup_ratio 0.1 \
92 | --fp16
93 | ```
94 |
95 | ### Language-specific (e.g., ZH) Semantic Entity Recognition on XFUND
96 |
97 | ```
98 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 examples/run_xfun_ser.py \
99 | --model_name_or_path lilt-infoxlm-base \
100 | --tokenizer_name xlm-roberta-base \
101 | --output_dir ls_ser_xfund_zh_lilt-infoxlm-base \
102 | --do_train \
103 | --do_eval \
104 | --lang zh \
105 | --max_steps 2000 \
106 | --per_device_train_batch_size 16 \
107 | --warmup_ratio 0.1 \
108 | --fp16
109 | ```
110 |
111 | ### Language-specific (e.g., ZH) Relation Extraction on XFUND
112 |
113 | ```
114 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 examples/run_xfun_re.py \
115 | --model_name_or_path lilt-infoxlm-base \
116 | --tokenizer_name xlm-roberta-base \
117 | --output_dir ls_re_xfund_zh_lilt-infoxlm-base \
118 | --do_train \
119 | --do_eval \
120 | --lang zh \
121 | --max_steps 5000 \
122 | --per_device_train_batch_size 8 \
123 | --learning_rate 6.25e-6 \
124 | --warmup_ratio 0.1 \
125 | --fp16
126 | ```
127 |
128 | ### Multi-task Semantic Entity Recognition on XFUND
129 |
130 | ```
131 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 examples/run_xfun_ser.py \
132 | --model_name_or_path lilt-infoxlm-base \
133 | --tokenizer_name xlm-roberta-base \
134 | --output_dir mt_ser_xfund_all_lilt-infoxlm-base \
135 | --do_train \
136 | --additional_langs all \
137 | --max_steps 16000 \
138 | --per_device_train_batch_size 16 \
139 | --warmup_ratio 0.1 \
140 | --fp16
141 | ```
142 |
143 | ### Multi-task Relation Extraction on XFUND
144 |
145 | ```
146 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 examples/run_xfun_re.py \
147 | --model_name_or_path lilt-infoxlm-base \
148 | --tokenizer_name xlm-roberta-base \
149 | --output_dir mt_re_xfund_all_lilt-infoxlm-base \
150 | --do_train \
151 | --additional_langs all \
152 | --max_steps 40000 \
153 | --per_device_train_batch_size 8 \
154 | --learning_rate 6.25e-6 \
155 | --warmup_ratio 0.1 \
156 | --fp16
157 | ```
158 |
159 |
160 | ## Results
161 |
162 | ### Semantic Entity Recognition on FUNSD
163 |
164 |
165 | ### Language-specific Fine-tuning on XFUND
166 |
167 |
168 | ### Cross-lingual Zero-shot Transfer on XFUND
169 |
170 |
171 | ### Multitask Fine-tuning on XFUND
172 |
173 |
174 |
175 |
176 | ## Acknowledgements
177 |
178 | The repository benefits greatly from [unilm/layoutlmft](https://github.com/microsoft/unilm/tree/master/layoutlmft). Thanks a lot for their excellent work.
179 |
180 | ## Citation
181 | If our paper helps your research, please cite it in your publication(s):
182 | ```
183 | @inproceedings{wang-etal-2022-lilt,
184 | title = "{L}i{LT}: A Simple yet Effective Language-Independent Layout Transformer for Structured Document Understanding",
185 | author={Wang, Jiapeng and Jin, Lianwen and Ding, Kai},
186 | booktitle = "Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
187 | month = may,
188 | year = "2022",
189 | publisher = "Association for Computational Linguistics",
190 | url = "https://aclanthology.org/2022.acl-long.534",
191 | doi = "10.18653/v1/2022.acl-long.534",
192 | pages = "7747--7757",
193 | }
194 | ```
195 |
196 | ## Feedback
197 | Suggestions and discussions are warmly welcomed. Please contact the authors by email at `eejpwang@mail.scut.edu.cn`.
198 |
--------------------------------------------------------------------------------
/examples/run_funsd.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | with open('tag.txt', 'w') as tagf:  # mark this run as monolingual in tag.txt
4 | tagf.write('monolingual')
5 | import logging
6 | import os
7 | import sys
8 | from dataclasses import dataclass, field
9 | from typing import Optional
10 |
11 | import numpy as np
12 | from datasets import ClassLabel, load_dataset, load_metric
13 |
14 | import LiLTfinetune.data.datasets.funsd
15 | import transformers
16 | from LiLTfinetune.data import DataCollatorForKeyValueExtraction
17 | from LiLTfinetune.data.data_args import DataTrainingArguments
18 | from LiLTfinetune.models.model_args import ModelArguments
19 | from LiLTfinetune.trainers import FunsdTrainer as Trainer
20 | from transformers import (
21 | AutoConfig,
22 | AutoModelForTokenClassification,
23 | AutoTokenizer,
24 | HfArgumentParser,
25 | PreTrainedTokenizerFast,
26 | TrainingArguments,
27 | set_seed,
28 | )
29 | from transformers.trainer_utils import get_last_checkpoint, is_main_process
30 | from transformers.utils import check_min_version
31 |
32 |
33 | # Will error if the minimal version of Transformers is not installed. Remove at your own risk.
34 | check_min_version("4.5.0")
35 |
36 | logger = logging.getLogger(__name__)
37 |
38 |
39 | def main():
40 |     # See all possible arguments in src/transformers/training_args.py
41 | # or by passing the --help flag to this script.
42 | # We now keep distinct sets of args, for a cleaner separation of concerns.
43 |
44 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
45 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
46 | # If we pass only one argument to the script and it's the path to a json file,
47 | # let's parse it to get our arguments.
48 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
49 | else:
50 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
51 |
52 | # Detecting last checkpoint.
53 | last_checkpoint = None
54 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
55 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
56 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
57 | raise ValueError(
58 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
59 | "Use --overwrite_output_dir to overcome."
60 | )
61 | elif last_checkpoint is not None:
62 | logger.info(
63 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
64 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
65 | )
66 |
67 | # Setup logging
68 | logging.basicConfig(
69 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
70 | datefmt="%m/%d/%Y %H:%M:%S",
71 | handlers=[logging.StreamHandler(sys.stdout)],
72 | )
73 | logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
74 |
75 | # Log on each process the small summary:
76 | logger.warning(
77 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
78 |         + f", distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
79 | )
80 | # Set the verbosity to info of the Transformers logger (on main process only):
81 | if is_main_process(training_args.local_rank):
82 | transformers.utils.logging.set_verbosity_info()
83 | transformers.utils.logging.enable_default_handler()
84 | transformers.utils.logging.enable_explicit_format()
85 | logger.info(f"Training/evaluation parameters {training_args}")
86 |
87 | # Set seed before initializing model.
88 | set_seed(training_args.seed)
89 |
90 | datasets = load_dataset(os.path.abspath(LiLTfinetune.data.datasets.funsd.__file__))
91 |
92 | if training_args.do_train:
93 | column_names = datasets["train"].column_names
94 | features = datasets["train"].features
95 | else:
96 | column_names = datasets["validation"].column_names
97 | features = datasets["validation"].features
98 | text_column_name = "tokens" if "tokens" in column_names else column_names[0]
99 | label_column_name = (
100 | f"{data_args.task_name}_tags" if f"{data_args.task_name}_tags" in column_names else column_names[1]
101 | )
102 |
103 | remove_columns = column_names
104 |
105 | # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the
106 | # unique labels.
107 | def get_label_list(labels):
108 | unique_labels = set()
109 | for label in labels:
110 | unique_labels = unique_labels | set(label)
111 | label_list = list(unique_labels)
112 | label_list.sort()
113 | return label_list
114 |
115 | if isinstance(features[label_column_name].feature, ClassLabel):
116 | label_list = features[label_column_name].feature.names
117 | # No need to convert the labels since they are already ints.
118 | label_to_id = {i: i for i in range(len(label_list))}
119 | else:
120 | label_list = get_label_list(datasets["train"][label_column_name])
121 | label_to_id = {l: i for i, l in enumerate(label_list)}
122 | num_labels = len(label_list)
123 |
124 | # Load pretrained model and tokenizer
125 | #
126 | # Distributed training:
127 | # The .from_pretrained methods guarantee that only one local process can concurrently
128 | # download model & vocab.
129 | config = AutoConfig.from_pretrained(
130 | model_args.config_name if model_args.config_name else model_args.model_name_or_path,
131 | num_labels=num_labels,
132 | finetuning_task=data_args.task_name,
133 | cache_dir=model_args.cache_dir,
134 | revision=model_args.model_revision,
135 | use_auth_token=True if model_args.use_auth_token else None,
136 | )
137 | tokenizer = AutoTokenizer.from_pretrained(
138 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
139 | cache_dir=model_args.cache_dir,
140 | use_fast=True,
141 | revision=model_args.model_revision,
142 | use_auth_token=True if model_args.use_auth_token else None,
143 | add_prefix_space=True,
144 | )
145 | model = AutoModelForTokenClassification.from_pretrained(
146 | model_args.model_name_or_path,
147 | from_tf=bool(".ckpt" in model_args.model_name_or_path),
148 | config=config,
149 | cache_dir=model_args.cache_dir,
150 | revision=model_args.model_revision,
151 | use_auth_token=True if model_args.use_auth_token else None,
152 | )
153 |
154 | # Tokenizer check: this script requires a fast tokenizer.
155 | if not isinstance(tokenizer, PreTrainedTokenizerFast):
156 | raise ValueError(
157 | "This example script only works for models that have a fast tokenizer. Checkout the big table of models "
158 | "at https://huggingface.co/transformers/index.html#bigtable to find the model types that meet this "
159 | "requirement"
160 | )
161 |
162 | # Preprocessing the dataset
163 | # Padding strategy
164 | padding = "max_length" if data_args.pad_to_max_length else False
165 |
166 | # Tokenize all texts and align the labels with them.
167 | def tokenize_and_align_labels(examples):
168 | tokenized_inputs = tokenizer(
169 | examples[text_column_name],
170 | padding=padding,
171 | truncation=True,
172 | return_overflowing_tokens=True,
173 | # We use this argument because the texts in our dataset are lists of words (with a label for each word).
174 | is_split_into_words=True,
175 | )
176 |
177 | labels = []
178 | bboxes = []
179 | images = []
180 | for batch_index in range(len(tokenized_inputs["input_ids"])):
181 | word_ids = tokenized_inputs.word_ids(batch_index=batch_index)
182 | org_batch_index = tokenized_inputs["overflow_to_sample_mapping"][batch_index]
183 |
184 | label = examples[label_column_name][org_batch_index]
185 | bbox = examples["bboxes"][org_batch_index]
186 | image = examples["image"][org_batch_index]
187 | previous_word_idx = None
188 | label_ids = []
189 | bbox_inputs = []
190 | for word_idx in word_ids:
191 | # Special tokens have a word id that is None. We set the label to -100 so they are automatically
192 | # ignored in the loss function.
193 | if word_idx is None:
194 | label_ids.append(-100)
195 | bbox_inputs.append([0, 0, 0, 0])
196 | # We set the label for the first token of each word.
197 | elif word_idx != previous_word_idx:
198 | label_ids.append(label_to_id[label[word_idx]])
199 | bbox_inputs.append(bbox[word_idx])
200 | # For the other tokens in a word, we set the label to either the current label or -100, depending on
201 | # the label_all_tokens flag.
202 | else:
203 | label_ids.append(label_to_id[label[word_idx]] if data_args.label_all_tokens else -100)
204 | bbox_inputs.append(bbox[word_idx])
205 | previous_word_idx = word_idx
206 | labels.append(label_ids)
207 | bboxes.append(bbox_inputs)
208 | images.append(image)
209 | tokenized_inputs["labels"] = labels
210 | tokenized_inputs["bbox"] = bboxes
211 | tokenized_inputs["image"] = images
212 | return tokenized_inputs
213 |
214 | if training_args.do_train:
215 | if "train" not in datasets:
216 | raise ValueError("--do_train requires a train dataset")
217 | train_dataset = datasets["train"]
218 | if data_args.max_train_samples is not None:
219 | train_dataset = train_dataset.select(range(data_args.max_train_samples))
220 | train_dataset = train_dataset.map(
221 | tokenize_and_align_labels,
222 | batched=True,
223 | remove_columns=remove_columns,
224 | num_proc=data_args.preprocessing_num_workers,
225 | load_from_cache_file=not data_args.overwrite_cache,
226 | )
227 |
228 | if training_args.do_eval:
229 | if "validation" not in datasets:
230 | raise ValueError("--do_eval requires a validation dataset")
231 | eval_dataset = datasets["validation"]
232 | if data_args.max_val_samples is not None:
233 | eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
234 | eval_dataset = eval_dataset.map(
235 | tokenize_and_align_labels,
236 | batched=True,
237 | remove_columns=remove_columns,
238 | num_proc=data_args.preprocessing_num_workers,
239 | load_from_cache_file=not data_args.overwrite_cache,
240 | )
241 |
242 | if training_args.do_predict:
243 | if "test" not in datasets:
244 | raise ValueError("--do_predict requires a test dataset")
245 | test_dataset = datasets["test"]
246 | if data_args.max_test_samples is not None:
247 | test_dataset = test_dataset.select(range(data_args.max_test_samples))
248 | test_dataset = test_dataset.map(
249 | tokenize_and_align_labels,
250 | batched=True,
251 | remove_columns=remove_columns,
252 | num_proc=data_args.preprocessing_num_workers,
253 | load_from_cache_file=not data_args.overwrite_cache,
254 | )
255 |
256 | # Data collator
257 | data_collator = DataCollatorForKeyValueExtraction(
258 | tokenizer,
259 | pad_to_multiple_of=8 if training_args.fp16 else None,
260 | padding=padding,
261 | max_length=512,
262 | )
263 |
264 | # Metrics
265 | metric = load_metric("seqeval")
266 |
267 | def compute_metrics(p):
268 | predictions, labels = p
269 | predictions = np.argmax(predictions, axis=2)
270 |
271 | # Remove ignored index (special tokens)
272 | true_predictions = [
273 | [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
274 | for prediction, label in zip(predictions, labels)
275 | ]
276 | true_labels = [
277 | [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
278 | for prediction, label in zip(predictions, labels)
279 | ]
280 |
281 | results = metric.compute(predictions=true_predictions, references=true_labels)
282 | if data_args.return_entity_level_metrics:
283 | # Unpack nested dictionaries
284 | final_results = {}
285 | for key, value in results.items():
286 | if isinstance(value, dict):
287 | for n, v in value.items():
288 | final_results[f"{key}_{n}"] = v
289 | else:
290 | final_results[key] = value
291 | return final_results
292 | else:
293 | return {
294 | "precision": results["overall_precision"],
295 | "recall": results["overall_recall"],
296 | "f1": results["overall_f1"],
297 | "accuracy": results["overall_accuracy"],
298 | }
299 |
300 | # Initialize our Trainer
301 | trainer = Trainer(
302 | model=model,
303 | args=training_args,
304 | train_dataset=train_dataset if training_args.do_train else None,
305 | eval_dataset=eval_dataset if training_args.do_eval else None,
306 | tokenizer=tokenizer,
307 | data_collator=data_collator,
308 | compute_metrics=compute_metrics,
309 | )
310 |
311 | # Training
312 | if training_args.do_train:
313 | checkpoint = last_checkpoint if last_checkpoint else None
314 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
315 | metrics = train_result.metrics
316 | trainer.save_model() # Saves the tokenizer too for easy upload
317 |
318 | max_train_samples = (
319 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
320 | )
321 | metrics["train_samples"] = min(max_train_samples, len(train_dataset))
322 |
323 | trainer.log_metrics("train", metrics)
324 | trainer.save_metrics("train", metrics)
325 | trainer.save_state()
326 |
327 | # Evaluation
328 | if training_args.do_eval:
329 | logger.info("*** Evaluate ***")
330 |
331 | metrics = trainer.evaluate()
332 |
333 | max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
334 | metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
335 |
336 | trainer.log_metrics("eval", metrics)
337 | trainer.save_metrics("eval", metrics)
338 |
339 | # Predict
340 | if training_args.do_predict:
341 | logger.info("*** Predict ***")
342 |
343 | predictions, labels, metrics = trainer.predict(test_dataset)
344 | predictions = np.argmax(predictions, axis=2)
345 |
346 | # Remove ignored index (special tokens)
347 | true_predictions = [
348 | [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
349 | for prediction, label in zip(predictions, labels)
350 | ]
351 |
352 | trainer.log_metrics("test", metrics)
353 | trainer.save_metrics("test", metrics)
354 |
355 | # Save predictions
356 | output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt")
357 | if trainer.is_world_process_zero():
358 | with open(output_test_predictions_file, "w") as writer:
359 | for prediction in true_predictions:
360 | writer.write(" ".join(prediction) + "\n")
361 |
362 |
363 | def _mp_fn(index):
364 | # For xla_spawn (TPUs)
365 | main()
366 |
367 |
368 | if __name__ == "__main__":
369 | main()
370 |
--------------------------------------------------------------------------------
/examples/run_xfun_re.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | with open('tag.txt', 'w') as tagf:  # mark this run as multilingual in tag.txt
4 | tagf.write('multilingual')
5 | import logging
6 | import os
7 | import sys
8 |
9 | import numpy as np
10 | from datasets import ClassLabel, load_dataset
11 |
12 | import LiLTfinetune.data.datasets.xfun
13 | import transformers
14 | from LiLTfinetune import AutoModelForRelationExtraction
15 | from LiLTfinetune.data.data_args import XFUNDataTrainingArguments
16 | from LiLTfinetune.data.data_collator import DataCollatorForKeyValueExtraction
17 | from LiLTfinetune.evaluation import re_score
18 | from LiLTfinetune.models.model_args import ModelArguments
19 | from LiLTfinetune.trainers import XfunReTrainer
20 | from transformers import (
21 | AutoConfig,
22 | AutoTokenizer,
23 | HfArgumentParser,
24 | PreTrainedTokenizerFast,
25 | TrainingArguments,
26 | set_seed,
27 | )
28 | from transformers.trainer_utils import get_last_checkpoint, is_main_process
29 |
30 |
31 | logger = logging.getLogger(__name__)
32 |
33 |
34 | def main():
35 | # See all possible arguments in src/transformers/training_args.py
36 | # or by passing the --help flag to this script.
37 | # We now keep distinct sets of args, for a cleaner separation of concerns.
38 |
39 | parser = HfArgumentParser((ModelArguments, XFUNDataTrainingArguments, TrainingArguments))
40 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
41 | # If we pass only one argument to the script and it's the path to a json file,
42 | # let's parse it to get our arguments.
43 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
44 | else:
45 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
46 |
47 | # Detecting last checkpoint.
48 | last_checkpoint = None
49 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
50 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
51 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
52 | raise ValueError(
53 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
54 | "Use --overwrite_output_dir to overcome."
55 | )
56 | elif last_checkpoint is not None:
57 | logger.info(
58 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
59 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
60 | )
61 |
62 | # Setup logging
63 | logging.basicConfig(
64 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
65 | datefmt="%m/%d/%Y %H:%M:%S",
66 | handlers=[logging.StreamHandler(sys.stdout)],
67 | )
68 | logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
69 |
70 | # Log on each process the small summary:
71 | logger.warning(
72 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
73 |         + f", distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
74 | )
75 | # Set the verbosity to info of the Transformers logger (on main process only):
76 | if is_main_process(training_args.local_rank):
77 | transformers.utils.logging.set_verbosity_info()
78 | transformers.utils.logging.enable_default_handler()
79 | transformers.utils.logging.enable_explicit_format()
80 | logger.info(f"Training/evaluation parameters {training_args}")
81 |
82 | # Set seed before initializing model.
83 | set_seed(training_args.seed)
84 | datasets = load_dataset(
85 | os.path.abspath(LiLTfinetune.data.datasets.xfun.__file__),
86 | f"xfun.{data_args.lang}",
87 | additional_langs=data_args.additional_langs,
88 | keep_in_memory=True,
89 | )
90 | if training_args.do_train:
91 | column_names = datasets["train"].column_names
92 | features = datasets["train"].features
93 | else:
94 | column_names = datasets["validation"].column_names
95 | features = datasets["validation"].features
96 | text_column_name = "input_ids"
97 | label_column_name = "labels"
98 |
99 | remove_columns = column_names
100 |
101 | # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the
102 | # unique labels.
103 | def get_label_list(labels):
104 | unique_labels = set()
105 | for label in labels:
106 | unique_labels = unique_labels | set(label)
107 | label_list = list(unique_labels)
108 | label_list.sort()
109 | return label_list
110 |
111 | if isinstance(features[label_column_name].feature, ClassLabel):
112 | label_list = features[label_column_name].feature.names
113 | # No need to convert the labels since they are already ints.
114 | label_to_id = {i: i for i in range(len(label_list))}
115 | else:
116 | label_list = get_label_list(datasets["train"][label_column_name])
117 | label_to_id = {l: i for i, l in enumerate(label_list)}
118 | num_labels = len(label_list)
119 |
120 | # Load pretrained model and tokenizer
121 | #
122 | # Distributed training:
123 | # The .from_pretrained methods guarantee that only one local process can concurrently
124 | # download model & vocab.
125 | config = AutoConfig.from_pretrained(
126 | model_args.config_name if model_args.config_name else model_args.model_name_or_path,
127 | num_labels=num_labels,
128 | finetuning_task=data_args.task_name,
129 | cache_dir=model_args.cache_dir,
130 | revision=model_args.model_revision,
131 | use_auth_token=True if model_args.use_auth_token else None,
132 | )
133 | tokenizer = AutoTokenizer.from_pretrained(
134 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
135 | cache_dir=model_args.cache_dir,
136 | use_fast=True,
137 | revision=model_args.model_revision,
138 | use_auth_token=True if model_args.use_auth_token else None,
139 | )
140 | model = AutoModelForRelationExtraction.from_pretrained(
141 | model_args.model_name_or_path,
142 | from_tf=bool(".ckpt" in model_args.model_name_or_path),
143 | config=config,
144 | cache_dir=model_args.cache_dir,
145 | revision=model_args.model_revision,
146 | use_auth_token=True if model_args.use_auth_token else None,
147 | )
148 |
149 | # Tokenizer check: this script requires a fast tokenizer.
150 | if not isinstance(tokenizer, PreTrainedTokenizerFast):
151 | raise ValueError(
152 | "This example script only works for models that have a fast tokenizer. Checkout the big table of models "
153 | "at https://huggingface.co/transformers/index.html#bigtable to find the model types that meet this "
154 | "requirement"
155 | )
156 |
157 | # Preprocessing the dataset
158 | # Padding strategy
159 | padding = "max_length" if data_args.pad_to_max_length else False
160 |
161 | if training_args.do_train:
162 | if "train" not in datasets:
163 | raise ValueError("--do_train requires a train dataset")
164 | train_dataset = datasets["train"]
165 | if data_args.max_train_samples is not None:
166 | train_dataset = train_dataset.select(range(data_args.max_train_samples))
167 |
168 | if training_args.do_eval:
169 | if "validation" not in datasets:
170 | raise ValueError("--do_eval requires a validation dataset")
171 | eval_dataset = datasets["validation"]
172 | if data_args.max_val_samples is not None:
173 | eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
174 |
175 | if training_args.do_predict:
176 | if "test" not in datasets:
177 | raise ValueError("--do_predict requires a test dataset")
178 | test_dataset = datasets["test"]
179 | if data_args.max_test_samples is not None:
180 | test_dataset = test_dataset.select(range(data_args.max_test_samples))
181 |
182 | # Data collator
183 | data_collator = DataCollatorForKeyValueExtraction(
184 | tokenizer,
185 | pad_to_multiple_of=8 if training_args.fp16 else None,
186 | padding=padding,
187 | max_length=512,
188 | )
189 |
190 | def compute_metrics(p):
191 | pred_relations, gt_relations = p
192 | score = re_score(pred_relations, gt_relations, mode="boundaries")
193 | return score
194 |
195 | # Initialize our Trainer
196 | trainer = XfunReTrainer(
197 | model=model,
198 | args=training_args,
199 | train_dataset=train_dataset if training_args.do_train else None,
200 | eval_dataset=eval_dataset if training_args.do_eval else None,
201 | tokenizer=tokenizer,
202 | data_collator=data_collator,
203 | compute_metrics=compute_metrics,
204 | )
205 |
206 | # Training
207 | if training_args.do_train:
208 | checkpoint = last_checkpoint if last_checkpoint else None
209 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
210 | metrics = train_result.metrics
211 | trainer.save_model() # Saves the tokenizer too for easy upload
212 |
213 | max_train_samples = (
214 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
215 | )
216 | metrics["train_samples"] = min(max_train_samples, len(train_dataset))
217 |
218 | trainer.log_metrics("train", metrics)
219 | trainer.save_metrics("train", metrics)
220 | trainer.save_state()
221 |
222 | # Evaluation
223 | if training_args.do_eval:
224 | logger.info("*** Evaluate ***")
225 |
226 | metrics = trainer.evaluate()
227 |
228 | max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
229 | metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
230 |
231 | trainer.log_metrics("eval", metrics)
232 | trainer.save_metrics("eval", metrics)
233 |
234 |
235 | def _mp_fn(index):
236 | # For xla_spawn (TPUs)
237 | main()
238 |
239 |
240 | if __name__ == "__main__":
241 | main()
242 |
--------------------------------------------------------------------------------
/examples/run_xfun_ser.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | with open('tag.txt', 'w') as tagf:  # mark this run as multilingual in tag.txt
4 | tagf.write('multilingual')
5 | import logging
6 | import os
7 | import sys
8 | from dataclasses import dataclass, field
9 | from typing import Optional
10 |
11 | import numpy as np
12 | from datasets import ClassLabel, load_dataset, load_metric
13 |
14 | import LiLTfinetune.data.datasets.xfun
15 | import transformers
16 | from LiLTfinetune.data import DataCollatorForKeyValueExtraction
17 | from LiLTfinetune.data.data_args import XFUNDataTrainingArguments
18 | from LiLTfinetune.models.model_args import ModelArguments
19 | from LiLTfinetune.trainers import XfunSerTrainer
20 | from transformers import (
21 | AutoConfig,
22 | AutoModelForTokenClassification,
23 | AutoTokenizer,
24 | HfArgumentParser,
25 | PreTrainedTokenizerFast,
26 | TrainingArguments,
27 | set_seed,
28 | )
29 | from transformers.trainer_utils import get_last_checkpoint, is_main_process
30 | from transformers.utils import check_min_version
31 |
32 |
33 | # Will error if the minimal version of Transformers is not installed. Remove at your own risk.
34 | check_min_version("4.5.0")
35 |
36 | logger = logging.getLogger(__name__)
37 |
38 |
39 | def main():
40 |
41 | parser = HfArgumentParser((ModelArguments, XFUNDataTrainingArguments, TrainingArguments))
42 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
43 | # If we pass only one argument to the script and it's the path to a json file,
44 | # let's parse it to get our arguments.
45 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
46 | else:
47 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
48 |
49 | # Detecting last checkpoint.
50 | last_checkpoint = None
51 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
52 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
53 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
54 | raise ValueError(
55 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
56 | "Use --overwrite_output_dir to overcome."
57 | )
58 | elif last_checkpoint is not None:
59 | logger.info(
60 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
61 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
62 | )
63 |
64 | # Setup logging
65 | logging.basicConfig(
66 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
67 | datefmt="%m/%d/%Y %H:%M:%S",
68 | handlers=[logging.StreamHandler(sys.stdout)],
69 | )
70 | logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
71 |
72 | # Log on each process the small summary:
73 | logger.warning(
74 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
75 |         + f", distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
76 | )
77 | # Set the verbosity to info of the Transformers logger (on main process only):
78 | if is_main_process(training_args.local_rank):
79 | transformers.utils.logging.set_verbosity_info()
80 | transformers.utils.logging.enable_default_handler()
81 | transformers.utils.logging.enable_explicit_format()
82 | logger.info(f"Training/evaluation parameters {training_args}")
83 |
84 | # Set seed before initializing model.
85 | set_seed(training_args.seed)
86 | datasets = load_dataset(
87 | os.path.abspath(LiLTfinetune.data.datasets.xfun.__file__),
88 | f"xfun.{data_args.lang}",
89 | additional_langs=data_args.additional_langs,
90 | keep_in_memory=True,
91 | )
92 | if training_args.do_train:
93 | column_names = datasets["train"].column_names
94 | features = datasets["train"].features
95 | else:
96 | column_names = datasets["validation"].column_names
97 | features = datasets["validation"].features
98 | text_column_name = "input_ids"
99 | label_column_name = "labels"
100 |
101 | remove_columns = column_names
102 |
103 | # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the
104 | # unique labels.
105 | def get_label_list(labels):
106 | unique_labels = set()
107 | for label in labels:
108 | unique_labels = unique_labels | set(label)
109 | label_list = list(unique_labels)
110 | label_list.sort()
111 | return label_list
112 |
113 | if isinstance(features[label_column_name].feature, ClassLabel):
114 | label_list = features[label_column_name].feature.names
115 | # No need to convert the labels since they are already ints.
116 | label_to_id = {i: i for i in range(len(label_list))}
117 | else:
118 | label_list = get_label_list(datasets["train"][label_column_name])
119 | label_to_id = {l: i for i, l in enumerate(label_list)}
120 | num_labels = len(label_list)
121 |
122 | # Load pretrained model and tokenizer
123 | #
124 | # Distributed training:
125 | # The .from_pretrained methods guarantee that only one local process can concurrently
126 | # download model & vocab.
127 | config = AutoConfig.from_pretrained(
128 | model_args.config_name if model_args.config_name else model_args.model_name_or_path,
129 | num_labels=num_labels,
130 | finetuning_task=data_args.task_name,
131 | cache_dir=model_args.cache_dir,
132 | revision=model_args.model_revision,
133 | use_auth_token=True if model_args.use_auth_token else None,
134 | )
135 | tokenizer = AutoTokenizer.from_pretrained(
136 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
137 | cache_dir=model_args.cache_dir,
138 | use_fast=True,
139 | revision=model_args.model_revision,
140 | use_auth_token=True if model_args.use_auth_token else None,
141 | )
142 | model = AutoModelForTokenClassification.from_pretrained(
143 | model_args.model_name_or_path,
144 | from_tf=bool(".ckpt" in model_args.model_name_or_path),
145 | config=config,
146 | cache_dir=model_args.cache_dir,
147 | revision=model_args.model_revision,
148 | use_auth_token=True if model_args.use_auth_token else None,
149 | )
150 |
151 | # Tokenizer check: this script requires a fast tokenizer.
152 | if not isinstance(tokenizer, PreTrainedTokenizerFast):
153 | raise ValueError(
154 | "This example script only works for models that have a fast tokenizer. Checkout the big table of models "
155 | "at https://huggingface.co/transformers/index.html#bigtable to find the model types that meet this "
156 | "requirement"
157 | )
158 |
159 | # Preprocessing the dataset
160 | # Padding strategy
161 | padding = "max_length" if data_args.pad_to_max_length else False
162 |
163 | if training_args.do_train:
164 | if "train" not in datasets:
165 | raise ValueError("--do_train requires a train dataset")
166 | train_dataset = datasets["train"]
167 | if data_args.max_train_samples is not None:
168 | train_dataset = train_dataset.select(range(data_args.max_train_samples))
169 |
170 | if training_args.do_eval:
171 | if "validation" not in datasets:
172 | raise ValueError("--do_eval requires a validation dataset")
173 | eval_dataset = datasets["validation"]
174 | if data_args.max_val_samples is not None:
175 | eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
176 |
177 | if training_args.do_predict:
178 | if "test" not in datasets:
179 | raise ValueError("--do_predict requires a test dataset")
180 | test_dataset = datasets["test"]
181 | if data_args.max_test_samples is not None:
182 | test_dataset = test_dataset.select(range(data_args.max_test_samples))
183 |
184 | # Data collator
185 | data_collator = DataCollatorForKeyValueExtraction(
186 | tokenizer,
187 | pad_to_multiple_of=8 if training_args.fp16 else None,
188 | padding=padding,
189 | max_length=512,
190 | )
191 |
192 | # Metrics
193 | metric = load_metric("seqeval")
194 |
195 | def compute_metrics(p):
196 | predictions, labels = p
197 | predictions = np.argmax(predictions, axis=2)
198 |
199 | # Remove ignored index (special tokens)
200 | true_predictions = [
201 | [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
202 | for prediction, label in zip(predictions, labels)
203 | ]
204 | true_labels = [
205 | [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
206 | for prediction, label in zip(predictions, labels)
207 | ]
208 |
209 | results = metric.compute(predictions=true_predictions, references=true_labels)
210 | if data_args.return_entity_level_metrics:
211 | # Unpack nested dictionaries
212 | final_results = {}
213 | for key, value in results.items():
214 | if isinstance(value, dict):
215 | for n, v in value.items():
216 | final_results[f"{key}_{n}"] = v
217 | else:
218 | final_results[key] = value
219 | return final_results
220 | else:
221 | return {
222 | "precision": results["overall_precision"],
223 | "recall": results["overall_recall"],
224 | "f1": results["overall_f1"],
225 | "accuracy": results["overall_accuracy"],
226 | }
227 |
228 | # Initialize our Trainer
229 | trainer = XfunSerTrainer(
230 | model=model,
231 | args=training_args,
232 | train_dataset=train_dataset if training_args.do_train else None,
233 | eval_dataset=eval_dataset if training_args.do_eval else None,
234 | tokenizer=tokenizer,
235 | data_collator=data_collator,
236 | compute_metrics=compute_metrics,
237 | )
238 |
239 | # Training
240 | if training_args.do_train:
241 | checkpoint = last_checkpoint if last_checkpoint else None
242 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
243 | metrics = train_result.metrics
244 | trainer.save_model() # Saves the tokenizer too for easy upload
245 |
246 | max_train_samples = (
247 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
248 | )
249 | metrics["train_samples"] = min(max_train_samples, len(train_dataset))
250 |
251 | trainer.log_metrics("train", metrics)
252 | trainer.save_metrics("train", metrics)
253 | trainer.save_state()
254 |
255 | # Evaluation
256 | if training_args.do_eval:
257 | logger.info("*** Evaluate ***")
258 |
259 | metrics = trainer.evaluate()
260 |
261 | max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
262 | metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
263 |
264 | trainer.log_metrics("eval", metrics)
265 | trainer.save_metrics("eval", metrics)
266 |
267 | # Predict
268 | if training_args.do_predict:
269 | logger.info("*** Predict ***")
270 |
271 | predictions, labels, metrics = trainer.predict(test_dataset)
272 | predictions = np.argmax(predictions, axis=2)
273 |
274 | # Remove ignored index (special tokens)
275 | true_predictions = [
276 | [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
277 | for prediction, label in zip(predictions, labels)
278 | ]
279 |
280 | trainer.log_metrics("test", metrics)
281 | trainer.save_metrics("test", metrics)
282 |
283 | # Save predictions
284 | output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt")
285 | if trainer.is_world_process_zero():
286 | with open(output_test_predictions_file, "w") as writer:
287 | for prediction in true_predictions:
288 | writer.write(" ".join(prediction) + "\n")
289 |
290 |
291 | def _mp_fn(index):
292 | # For xla_spawn (TPUs)
293 | main()
294 |
295 |
296 | if __name__ == "__main__":
297 | main()
298 |
--------------------------------------------------------------------------------
/figs/cl_xfund.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jpWang/LiLT/a39930b2c5425da7250f0dde04252cf60ec3b1b7/figs/cl_xfund.png
--------------------------------------------------------------------------------
/figs/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jpWang/LiLT/a39930b2c5425da7250f0dde04252cf60ec3b1b7/figs/framework.png
--------------------------------------------------------------------------------
/figs/funsd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jpWang/LiLT/a39930b2c5425da7250f0dde04252cf60ec3b1b7/figs/funsd.png
--------------------------------------------------------------------------------
/figs/ls_xfund.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jpWang/LiLT/a39930b2c5425da7250f0dde04252cf60ec3b1b7/figs/ls_xfund.png
--------------------------------------------------------------------------------
/figs/mt_xfund.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jpWang/LiLT/a39930b2c5425da7250f0dde04252cf60ec3b1b7/figs/mt_xfund.png
--------------------------------------------------------------------------------
/gen_weight_roberta_like.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os, json
3 | import argparse
4 | from transformers import AutoConfig, AutoModel
5 |
6 | if __name__ == '__main__':
7 |
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument('--lilt', type=str, required=True, help='Path to LiLT model.')
10 | parser.add_argument('--text', type=str, required=True, help='Path to text model.')
11 | parser.add_argument('--config', type=str, required=True, help='Path to text config.')
12 | parser.add_argument('--out', type=str, required=True, help='Path to output.')
13 | opt = parser.parse_args()
14 |
15 | with open(opt.config, 'r') as jf:
16 | config = json.load(jf)
17 |     config['channel_shrink_ratio'] = 4  # LiLT's layout flow uses hidden_size // 4 channels
18 |     config['max_2d_position_embeddings'] = 1024  # box coordinates are normalized to [0, 1000]
19 |     config['model_type'] = 'liltrobertalike'  # resolve to the LiLTRobertaLike config/model classes
20 |
21 | if not os.path.isdir(opt.out):
22 | os.makedirs(opt.out)
23 | with open(os.path.join(opt.out, 'config.json'), 'w') as jf:
24 | json.dump(config, jf, sort_keys=True, indent=2, separators=(',', ': '),)
25 |
26 | text_model = torch.load(opt.text)
27 | text_model = {k.replace('roberta.', 'lilt.'): v for (k, v) in text_model.items()}
28 | lilt_model = torch.load(opt.lilt)
29 | total_model = {**text_model, **lilt_model}
30 | torch.save(total_model, os.path.join(opt.out, 'pytorch_model.bin'))
31 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 119
3 | target-version = ['py35']
4 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | datasets==1.6.2
2 | transformers==4.5.1
3 | seqeval==1.2.2
4 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [isort]
2 | default_section = FIRSTPARTY
3 | ensure_newline_before_comments = True
4 | force_grid_wrap = 0
5 | include_trailing_comma = True
6 | known_first_party = LiLT
7 | known_third_party =
8 | datasets
9 | git
10 | h5py
11 | numpy
12 | packaging
13 | PIL
14 | seqeval
15 | torch
16 | torchvision
17 | tqdm
18 |
19 | line_length = 119
20 | lines_after_imports = 2
21 | multi_line_output = 3
22 | use_parentheses = True
23 |
24 | [flake8]
25 | ignore = E203, E501, E741, W503, W605
26 | max-line-length = 119
27 | per-file-ignores = __init__.py:F401
28 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | from setuptools import find_packages, setup
3 | setup(
4 | name="LiLTfinetune",
5 | version="1.0",
6 | author="Deep Learning and Vision Computing Lab, SCUT",
7 | url="https://github.com/jpWang/LiLT",
8 | packages=find_packages(),
9 | python_requires=">=3.7",
10 | extras_require={"dev": ["flake8", "isort", "black"]},
11 | )
--------------------------------------------------------------------------------