├── .gitignore ├── LICENSE ├── LiLTfinetune ├── __init__.py ├── data │ ├── __init__.py │ ├── data_args.py │ ├── data_collator.py │ ├── datasets │ │ ├── __init__.py │ │ ├── funsd.py │ │ └── xfun.py │ └── utils.py ├── evaluation.py ├── models │ ├── LiLTRobertaLike │ │ ├── __init__.py │ │ ├── configuration_LiLTRobertaLike.py │ │ ├── modeling_LiLTRobertaLike.py │ │ ├── tokenization_LiLTRobertaLike.py │ │ └── tokenization_LiLTRobertaLike_fast.py │ ├── __init__.py │ └── model_args.py ├── modules │ ├── __init__.py │ └── decoders │ │ ├── __init__.py │ │ └── re.py ├── trainers │ ├── __init__.py │ ├── funsd_trainer.py │ └── xfun_trainer.py └── utils.py ├── Makefile ├── README.md ├── examples ├── run_funsd.py ├── run_xfun_re.py └── run_xfun_ser.py ├── figs ├── cl_xfund.png ├── framework.png ├── funsd.png ├── ls_xfund.png └── mt_xfund.png ├── gen_weight_roberta_like.py ├── pyproject.toml ├── requirements.txt ├── setup.cfg └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.toptal.com/developers/gitignore/api/python,macos,pycharm+all 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,macos,pycharm+all 4 | 5 | ### macOS ### 6 | # General 7 | .DS_Store 8 | .AppleDouble 9 | .LSOverride 10 | 11 | # Icon must end with two \r 12 | Icon 13 | 14 | 15 | # Thumbnails 16 | ._* 17 | 18 | # Files that might appear in the root of a volume 19 | .DocumentRevisions-V100 20 | .fseventsd 21 | .Spotlight-V100 22 | .TemporaryItems 23 | .Trashes 24 | .VolumeIcon.icns 25 | .com.apple.timemachine.donotpresent 26 | 27 | # Directories potentially created on remote AFP share 28 | .AppleDB 29 | .AppleDesktop 30 | Network Trash Folder 31 | Temporary Items 32 | .apdisk 33 | 34 | ### PyCharm+all ### 35 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 36 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 37 | 38 | # User-specific stuff 39 | .idea/**/workspace.xml 40 | .idea/**/tasks.xml 41 | .idea/**/usage.statistics.xml 42 | .idea/**/dictionaries 43 | .idea/**/shelf 44 | 45 | # Generated files 46 | .idea/**/contentModel.xml 47 | 48 | # Sensitive or high-churn files 49 | .idea/**/dataSources/ 50 | .idea/**/dataSources.ids 51 | .idea/**/dataSources.local.xml 52 | .idea/**/sqlDataSources.xml 53 | .idea/**/dynamic.xml 54 | .idea/**/uiDesigner.xml 55 | .idea/**/dbnavigator.xml 56 | 57 | # Gradle 58 | .idea/**/gradle.xml 59 | .idea/**/libraries 60 | 61 | # Gradle and Maven with auto-import 62 | # When using Gradle or Maven with auto-import, you should exclude module files, 63 | # since they will be recreated, and may cause churn. Uncomment if using 64 | # auto-import. 
65 | # .idea/artifacts 66 | # .idea/compiler.xml 67 | # .idea/jarRepositories.xml 68 | # .idea/modules.xml 69 | # .idea/*.iml 70 | # .idea/modules 71 | # *.iml 72 | # *.ipr 73 | 74 | # CMake 75 | cmake-build-*/ 76 | 77 | # Mongo Explorer plugin 78 | .idea/**/mongoSettings.xml 79 | 80 | # File-based project format 81 | *.iws 82 | 83 | # IntelliJ 84 | out/ 85 | 86 | # mpeltonen/sbt-idea plugin 87 | .idea_modules/ 88 | 89 | # JIRA plugin 90 | atlassian-ide-plugin.xml 91 | 92 | # Cursive Clojure plugin 93 | .idea/replstate.xml 94 | 95 | # Crashlytics plugin (for Android Studio and IntelliJ) 96 | com_crashlytics_export_strings.xml 97 | crashlytics.properties 98 | crashlytics-build.properties 99 | fabric.properties 100 | 101 | # Editor-based Rest Client 102 | .idea/httpRequests 103 | 104 | # Android studio 3.1+ serialized cache file 105 | .idea/caches/build_file_checksums.ser 106 | 107 | ### PyCharm+all Patch ### 108 | # Ignores the whole .idea folder and all .iml files 109 | # See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360 110 | 111 | .idea/ 112 | 113 | # Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023 114 | 115 | *.iml 116 | modules.xml 117 | .idea/misc.xml 118 | *.ipr 119 | 120 | # Sonarlint plugin 121 | .idea/sonarlint 122 | 123 | ### Python ### 124 | # Byte-compiled / optimized / DLL files 125 | __pycache__/ 126 | *.py[cod] 127 | *$py.class 128 | 129 | # C extensions 130 | *.so 131 | 132 | # Distribution / packaging 133 | .Python 134 | build/ 135 | develop-eggs/ 136 | dist/ 137 | downloads/ 138 | eggs/ 139 | .eggs/ 140 | parts/ 141 | sdist/ 142 | var/ 143 | wheels/ 144 | pip-wheel-metadata/ 145 | share/python-wheels/ 146 | *.egg-info/ 147 | .installed.cfg 148 | *.egg 149 | MANIFEST 150 | 151 | # PyInstaller 152 | # Usually these files are written by a python script from a template 153 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 154 | *.manifest 155 | *.spec 156 | 157 | # Installer logs 158 | pip-log.txt 159 | pip-delete-this-directory.txt 160 | 161 | # Unit test / coverage reports 162 | htmlcov/ 163 | .tox/ 164 | .nox/ 165 | .coverage 166 | .coverage.* 167 | .cache 168 | nosetests.xml 169 | coverage.xml 170 | *.cover 171 | *.py,cover 172 | .hypothesis/ 173 | .pytest_cache/ 174 | pytestdebug.log 175 | 176 | # Translations 177 | *.mo 178 | *.pot 179 | 180 | # Django stuff: 181 | *.log 182 | local_settings.py 183 | db.sqlite3 184 | db.sqlite3-journal 185 | 186 | # Flask stuff: 187 | instance/ 188 | .webassets-cache 189 | 190 | # Scrapy stuff: 191 | .scrapy 192 | 193 | # Sphinx documentation 194 | docs/_build/ 195 | doc/_build/ 196 | 197 | # PyBuilder 198 | target/ 199 | 200 | # Jupyter Notebook 201 | .ipynb_checkpoints 202 | 203 | # IPython 204 | profile_default/ 205 | ipython_config.py 206 | 207 | # pyenv 208 | .python-version 209 | 210 | # pipenv 211 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 212 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 213 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 214 | # install all needed dependencies. 215 | #Pipfile.lock 216 | 217 | # poetry 218 | #poetry.lock 219 | 220 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 221 | __pypackages__/ 222 | 223 | # Celery stuff 224 | celerybeat-schedule 225 | celerybeat.pid 226 | 227 | # SageMath parsed files 228 | *.sage.py 229 | 230 | # Environments 231 | # .env 232 | .env/ 233 | .venv/ 234 | env/ 235 | venv/ 236 | ENV/ 237 | env.bak/ 238 | venv.bak/ 239 | pythonenv* 240 | 241 | # Spyder project settings 242 | .spyderproject 243 | .spyproject 244 | 245 | # Rope project settings 246 | .ropeproject 247 | 248 | # mkdocs documentation 249 | /site 250 | 251 | # mypy 252 | .mypy_cache/ 253 | .dmypy.json 254 | dmypy.json 255 | 256 | # Pyre type checker 257 | .pyre/ 258 | 259 | # pytype static type analyzer 260 | .pytype/ 261 | 262 | # operating system-related files 263 | # file properties cache/storage on macOS 264 | *.DS_Store 265 | # thumbnail cache on Windows 266 | Thumbs.db 267 | 268 | # profiling data 269 | .prof 270 | 271 | LiLTfinetune/data/datasets/*.lock 272 | # End of https://www.toptal.com/developers/gitignore/api/python,macos,pycharm+all 273 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Jiapeng Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /LiLTfinetune/__init__.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | from transformers import CONFIG_MAPPING, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, MODEL_NAMES_MAPPING, TOKENIZER_MAPPING 4 | from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, BertConverter, RobertaConverter, XLMRobertaConverter 5 | from transformers.models.auto.modeling_auto import auto_class_factory 6 | 7 | from .models.LiLTRobertaLike import ( 8 | LiLTRobertaLikeConfig, 9 | LiLTRobertaLikeForRelationExtraction, 10 | LiLTRobertaLikeForTokenClassification, 11 | LiLTRobertaLikeTokenizer, 12 | LiLTRobertaLikeTokenizerFast, 13 | ) 14 | 15 | CONFIG_MAPPING.update([("liltrobertalike", LiLTRobertaLikeConfig),]) 16 | MODEL_NAMES_MAPPING.update([("liltrobertalike", "LiLTRobertaLike"),]) 17 | TOKENIZER_MAPPING.update( 18 | [ 19 | (LiLTRobertaLikeConfig, (LiLTRobertaLikeTokenizer, LiLTRobertaLikeTokenizerFast)), 20 | ] 21 | ) 22 | 23 | with open('tag.txt', 'r') as tagf: 24 | TAG = tagf.read().lower() 25 | assert TAG == 'monolingual' or TAG == 'multilingual', 'TAG is wrong. It should be monolingual or multilingual.' 26 | if TAG == 'monolingual': 27 | SLOW_TO_FAST_CONVERTERS.update({"LiLTRobertaLikeTokenizer": RobertaConverter,}) 28 | elif TAG == 'multilingual': 29 | SLOW_TO_FAST_CONVERTERS.update({"LiLTRobertaLikeTokenizer": XLMRobertaConverter,}) 30 | 31 | MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.update( 32 | [(LiLTRobertaLikeConfig, LiLTRobertaLikeForTokenClassification),] 33 | ) 34 | 35 | MODEL_FOR_RELATION_EXTRACTION_MAPPING = OrderedDict( 36 | [(LiLTRobertaLikeConfig, LiLTRobertaLikeForRelationExtraction),] 37 | ) 38 | 39 | AutoModelForTokenClassification = auto_class_factory( 40 | "AutoModelForTokenClassification", MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, head_doc="token classification" 41 | ) 42 | 43 | AutoModelForRelationExtraction = auto_class_factory( 44 | "AutoModelForRelationExtraction", MODEL_FOR_RELATION_EXTRACTION_MAPPING, head_doc="relation extraction" 45 | ) 46 | -------------------------------------------------------------------------------- /LiLTfinetune/data/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from .data_collator import DataCollatorForKeyValueExtraction 3 | from .datasets import * 4 | -------------------------------------------------------------------------------- /LiLTfinetune/data/data_args.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | 4 | 5 | @dataclass 6 | class DataTrainingArguments: 7 | """ 8 | Arguments pertaining to what data we are going to input our model for training and eval. 
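LiLTfinetune/__init__.py above registers the "liltrobertalike" model type with the transformers auto-class machinery (CONFIG_MAPPING, TOKENIZER_MAPPING, SLOW_TO_FAST_CONVERTERS) and builds AutoModelForTokenClassification / AutoModelForRelationExtraction through auto_class_factory. A minimal usage sketch under two assumptions: a tag.txt file containing "monolingual" or "multilingual" sits in the working directory (the import asserts on it), and "path/to/lilt-checkpoint" is a hypothetical local directory whose config.json declares model_type "liltrobertalike".

from transformers import AutoConfig, AutoTokenizer

# Importing the package applies the mapping updates shown above.
from LiLTfinetune import AutoModelForTokenClassification

ckpt = "path/to/lilt-checkpoint"  # hypothetical checkpoint directory

config = AutoConfig.from_pretrained(ckpt, num_labels=7)         # resolves to LiLTRobertaLikeConfig
tokenizer = AutoTokenizer.from_pretrained(ckpt, use_fast=True)  # LiLTRobertaLikeTokenizerFast
model = AutoModelForTokenClassification.from_pretrained(ckpt, config=config)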
9 | """ 10 | 11 | task_name: Optional[str] = field(default="ner", metadata={"help": "The name of the task (ner, pos...)."}) 12 | dataset_name: Optional[str] = field( 13 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 14 | ) 15 | dataset_config_name: Optional[str] = field( 16 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 17 | ) 18 | train_file: Optional[str] = field( 19 | default=None, metadata={"help": "The input training data file (a csv or JSON file)."} 20 | ) 21 | validation_file: Optional[str] = field( 22 | default=None, 23 | metadata={"help": "An optional input evaluation data file to evaluate on (a csv or JSON file)."}, 24 | ) 25 | test_file: Optional[str] = field( 26 | default=None, 27 | metadata={"help": "An optional input test data file to predict on (a csv or JSON file)."}, 28 | ) 29 | overwrite_cache: bool = field( 30 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 31 | ) 32 | preprocessing_num_workers: Optional[int] = field( 33 | default=None, 34 | metadata={"help": "The number of processes to use for the preprocessing."}, 35 | ) 36 | pad_to_max_length: bool = field( 37 | default=True, 38 | metadata={ 39 | "help": "Whether to pad all samples to model maximum sentence length. " 40 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 41 | "efficient on GPU but very bad for TPU." 42 | }, 43 | ) 44 | max_train_samples: Optional[int] = field( 45 | default=None, 46 | metadata={ 47 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 48 | "value if set." 49 | }, 50 | ) 51 | max_val_samples: Optional[int] = field( 52 | default=None, 53 | metadata={ 54 | "help": "For debugging purposes or quicker training, truncate the number of validation examples to this " 55 | "value if set." 56 | }, 57 | ) 58 | max_test_samples: Optional[int] = field( 59 | default=None, 60 | metadata={ 61 | "help": "For debugging purposes or quicker training, truncate the number of test examples to this " 62 | "value if set." 63 | }, 64 | ) 65 | label_all_tokens: bool = field( 66 | default=False, 67 | metadata={ 68 | "help": "Whether to put the label for one word on all tokens of generated by that word or just on the " 69 | "one (in which case the other tokens will have a padding index)." 70 | }, 71 | ) 72 | return_entity_level_metrics: bool = field( 73 | default=False, 74 | metadata={"help": "Whether to return all the entity levels during evaluation or just the overall ones."}, 75 | ) 76 | 77 | 78 | @dataclass 79 | class XFUNDataTrainingArguments(DataTrainingArguments): 80 | lang: Optional[str] = field(default="en") 81 | additional_langs: Optional[str] = field(default=None) 82 | -------------------------------------------------------------------------------- /LiLTfinetune/data/data_collator.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Union 3 | 4 | import torch 5 | 6 | from detectron2.structures import ImageList 7 | from transformers import PreTrainedTokenizerBase 8 | from transformers.file_utils import PaddingStrategy 9 | 10 | 11 | @dataclass 12 | class DataCollatorForKeyValueExtraction: 13 | """ 14 | Data collator that will dynamically pad the inputs received, as well as the labels. 
15 | 16 | Args: 17 | tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`): 18 | The tokenizer used for encoding the data. 19 | padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`): 20 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index) 21 | among: 22 | 23 | * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single 24 | sequence if provided). 25 | * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the 26 | maximum acceptable input length for the model if that argument is not provided. 27 | * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of 28 | different lengths). 29 | max_length (:obj:`int`, `optional`): 30 | Maximum length of the returned list and optionally padding length (see above). 31 | pad_to_multiple_of (:obj:`int`, `optional`): 32 | If set will pad the sequence to a multiple of the provided value. 33 | 34 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 35 | 7.5 (Volta). 36 | label_pad_token_id (:obj:`int`, `optional`, defaults to -100): 37 | The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions). 38 | """ 39 | 40 | tokenizer: PreTrainedTokenizerBase 41 | padding: Union[bool, str, PaddingStrategy] = True 42 | max_length: Optional[int] = None 43 | pad_to_multiple_of: Optional[int] = None 44 | label_pad_token_id: int = -100 45 | 46 | def __call__(self, features): 47 | label_name = "label" if "label" in features[0].keys() else "labels" 48 | labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None 49 | 50 | has_image_input = "image" in features[0] 51 | has_bbox_input = "bbox" in features[0] 52 | if has_image_input: 53 | image = ImageList.from_tensors([torch.tensor(feature["image"]) for feature in features], 32) 54 | for feature in features: 55 | del feature["image"] 56 | batch = self.tokenizer.pad( 57 | features, 58 | padding=self.padding, 59 | max_length=self.max_length, 60 | pad_to_multiple_of=self.pad_to_multiple_of, 61 | # Conversion to tensors will fail if we have labels as they are not of the same length yet. 
62 | return_tensors="pt" if labels is None else None, 63 | ) 64 | 65 | if labels is None: 66 | return batch 67 | 68 | sequence_length = torch.tensor(batch["input_ids"]).shape[1] 69 | padding_side = self.tokenizer.padding_side 70 | if padding_side == "right": 71 | batch["labels"] = [label + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels] 72 | if has_bbox_input: 73 | batch["bbox"] = [bbox + [[0, 0, 0, 0]] * (sequence_length - len(bbox)) for bbox in batch["bbox"]] 74 | else: 75 | batch["labels"] = [[self.label_pad_token_id] * (sequence_length - len(label)) + label for label in labels] 76 | if has_bbox_input: 77 | batch["bbox"] = [[[0, 0, 0, 0]] * (sequence_length - len(bbox)) + bbox for bbox in batch["bbox"]] 78 | 79 | batch = {k: torch.tensor(v, dtype=torch.int64) if isinstance(v[0], list) else v for k, v in batch.items()} 80 | if has_image_input: 81 | batch["image"] = image 82 | return batch 83 | -------------------------------------------------------------------------------- /LiLTfinetune/data/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpWang/LiLT/a39930b2c5425da7250f0dde04252cf60ec3b1b7/LiLTfinetune/data/datasets/__init__.py -------------------------------------------------------------------------------- /LiLTfinetune/data/datasets/funsd.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import json 4 | import os 5 | 6 | import datasets 7 | 8 | from LiLTfinetune.data.utils import load_image, normalize_bbox 9 | 10 | 11 | logger = datasets.logging.get_logger(__name__) 12 | 13 | 14 | _CITATION = """\ 15 | @article{Jaume2019FUNSDAD, 16 | title={FUNSD: A Dataset for Form Understanding in Noisy Scanned Documents}, 17 | author={Guillaume Jaume and H. K. Ekenel and J. Thiran}, 18 | journal={2019 International Conference on Document Analysis and Recognition Workshops (ICDARW)}, 19 | year={2019}, 20 | volume={2}, 21 | pages={1-6} 22 | } 23 | """ 24 | 25 | _DESCRIPTION = """\ 26 | https://guillaumejaume.github.io/FUNSD/ 27 | """ 28 | 29 | 30 | class FunsdConfig(datasets.BuilderConfig): 31 | """BuilderConfig for FUNSD""" 32 | 33 | def __init__(self, **kwargs): 34 | """BuilderConfig for FUNSD. 35 | 36 | Args: 37 | **kwargs: keyword arguments forwarded to super. 
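To make the padding rules of DataCollatorForKeyValueExtraction above concrete, here is a toy call with two features of different lengths and no image field (assuming an xlm-roberta-base tokenizer; detectron2 must still be importable because the module imports ImageList at load time). Labels are right-padded with -100 and boxes with [0, 0, 0, 0] up to the longest sequence in the batch, then converted to int64 tensors.

from transformers import AutoTokenizer

from LiLTfinetune.data.data_collator import DataCollatorForKeyValueExtraction

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
collator = DataCollatorForKeyValueExtraction(tokenizer=tokenizer, padding="longest")

features = [
    {"input_ids": [0, 850, 2], "bbox": [[1, 2, 3, 4]] * 3, "labels": [0, 1, 0]},
    {"input_ids": [0, 2],      "bbox": [[5, 6, 7, 8]] * 2, "labels": [0, 0]},
]
batch = collator(features)
print(batch["labels"])       # second row padded to [0, 0, -100]
print(batch["bbox"].shape)   # torch.Size([2, 3, 4])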
38 | """ 39 | super(FunsdConfig, self).__init__(**kwargs) 40 | 41 | 42 | class Funsd(datasets.GeneratorBasedBuilder): 43 | """Conll2003 dataset.""" 44 | 45 | BUILDER_CONFIGS = [ 46 | FunsdConfig(name="funsd", version=datasets.Version("1.0.0"), description="FUNSD dataset"), 47 | ] 48 | 49 | def _info(self): 50 | return datasets.DatasetInfo( 51 | description=_DESCRIPTION, 52 | features=datasets.Features( 53 | { 54 | "id": datasets.Value("string"), 55 | "tokens": datasets.Sequence(datasets.Value("string")), 56 | "bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))), 57 | "ner_tags": datasets.Sequence( 58 | datasets.features.ClassLabel( 59 | names=["O", "B-HEADER", "I-HEADER", "B-QUESTION", "I-QUESTION", "B-ANSWER", "I-ANSWER"] 60 | ) 61 | ), 62 | "image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"), 63 | } 64 | ), 65 | supervised_keys=None, 66 | homepage="https://guillaumejaume.github.io/FUNSD/", 67 | citation=_CITATION, 68 | ) 69 | 70 | def _split_generators(self, dl_manager): 71 | """Returns SplitGenerators.""" 72 | downloaded_file = dl_manager.download_and_extract("https://guillaumejaume.github.io/FUNSD/dataset.zip") 73 | return [ 74 | datasets.SplitGenerator( 75 | name=datasets.Split.TRAIN, gen_kwargs={"filepath": f"{downloaded_file}/dataset/training_data/"} 76 | ), 77 | datasets.SplitGenerator( 78 | name=datasets.Split.TEST, gen_kwargs={"filepath": f"{downloaded_file}/dataset/testing_data/"} 79 | ), 80 | ] 81 | 82 | def _generate_examples(self, filepath): 83 | logger.info("⏳ Generating examples from = %s", filepath) 84 | ann_dir = os.path.join(filepath, "annotations") 85 | img_dir = os.path.join(filepath, "images") 86 | for guid, file in enumerate(sorted(os.listdir(ann_dir))): 87 | tokens = [] 88 | bboxes = [] 89 | ner_tags = [] 90 | 91 | file_path = os.path.join(ann_dir, file) 92 | with open(file_path, "r", encoding="utf8") as f: 93 | data = json.load(f) 94 | image_path = os.path.join(img_dir, file) 95 | image_path = image_path.replace("json", "png") 96 | image, size = load_image(image_path) 97 | for item in data["form"]: 98 | words, label = item["words"], item["label"] 99 | words = [w for w in words if w["text"].strip() != ""] 100 | if len(words) == 0: 101 | continue 102 | if label == "other": 103 | for w in words: 104 | tokens.append(w["text"]) 105 | ner_tags.append("O") 106 | bboxes.append(normalize_bbox(item["box"], size)) 107 | else: 108 | tokens.append(words[0]["text"]) 109 | ner_tags.append("B-" + label.upper()) 110 | bboxes.append(normalize_bbox(item["box"], size)) 111 | for w in words[1:]: 112 | tokens.append(w["text"]) 113 | ner_tags.append("I-" + label.upper()) 114 | bboxes.append(normalize_bbox(item["box"], size)) 115 | 116 | yield guid, {"id": str(guid), "tokens": tokens, "bboxes": bboxes, "ner_tags": ner_tags, "image": image} 117 | -------------------------------------------------------------------------------- /LiLTfinetune/data/datasets/xfun.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | import json 3 | import logging 4 | import os 5 | 6 | import datasets 7 | 8 | from LiLTfinetune.data.utils import load_image, merge_bbox, normalize_bbox, simplify_bbox 9 | from transformers import AutoTokenizer 10 | 11 | 12 | _URL = "https://github.com/doc-analysis/XFUN/releases/download/v1.0/" 13 | 14 | _LANG = ["zh", "de", "es", "fr", "en", "it", "ja", "pt"] 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class XFUNConfig(datasets.BuilderConfig): 19 | """BuilderConfig for XFUN.""" 20 | 
21 | def __init__(self, lang, additional_langs=None, **kwargs): 22 | """ 23 | Args: 24 | lang: string, language for the input text 25 | **kwargs: keyword arguments forwarded to super. 26 | """ 27 | super(XFUNConfig, self).__init__(**kwargs) 28 | self.lang = lang 29 | self.additional_langs = additional_langs 30 | 31 | 32 | class XFUN(datasets.GeneratorBasedBuilder): 33 | """XFUN dataset.""" 34 | 35 | BUILDER_CONFIGS = [XFUNConfig(name=f"xfun.{lang}", lang=lang) for lang in _LANG] 36 | 37 | tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base") 38 | 39 | def _info(self): 40 | return datasets.DatasetInfo( 41 | features=datasets.Features( 42 | { 43 | "id": datasets.Value("string"), 44 | "input_ids": datasets.Sequence(datasets.Value("int64")), 45 | "bbox": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))), 46 | "labels": datasets.Sequence( 47 | datasets.ClassLabel( 48 | names=["O", "B-QUESTION", "B-ANSWER", "B-HEADER", "I-ANSWER", "I-QUESTION", "I-HEADER"] 49 | ) 50 | ), 51 | "image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"), 52 | "entities": datasets.Sequence( 53 | { 54 | "start": datasets.Value("int64"), 55 | "end": datasets.Value("int64"), 56 | "label": datasets.ClassLabel(names=["HEADER", "QUESTION", "ANSWER"]), 57 | } 58 | ), 59 | "relations": datasets.Sequence( 60 | { 61 | "head": datasets.Value("int64"), 62 | "tail": datasets.Value("int64"), 63 | "start_index": datasets.Value("int64"), 64 | "end_index": datasets.Value("int64"), 65 | } 66 | ), 67 | } 68 | ), 69 | supervised_keys=None, 70 | ) 71 | 72 | def _split_generators(self, dl_manager): 73 | """Returns SplitGenerators.""" 74 | # urls_to_download = { 75 | # "train": [f"{_URL}{self.config.lang}.train.json", f"{_URL}{self.config.lang}.train.zip"], 76 | # "val": [f"{_URL}{self.config.lang}.val.json", f"{_URL}{self.config.lang}.val.zip"], 77 | # # "test": [f"{_URL}{self.config.lang}.test.json", f"{_URL}{self.config.lang}.test.zip"], 78 | # } 79 | # downloaded_files = dl_manager.download_and_extract(urls_to_download) 80 | # train_files_for_many_langs = [downloaded_files["train"]] 81 | # val_files_for_many_langs = [downloaded_files["val"]] 82 | # # test_files_for_many_langs = [downloaded_files["test"]] 83 | file_dir = 'xfund&funsd/' 84 | train_files_for_many_langs = [[file_dir+f"{self.config.lang}.train.json", file_dir+f"{self.config.lang}"]] 85 | val_files_for_many_langs = [[file_dir+f"{self.config.lang}.val.json", file_dir+f"{self.config.lang}"]] 86 | 87 | if self.config.additional_langs: 88 | additional_langs = self.config.additional_langs.split("+") 89 | if "all" in additional_langs: 90 | additional_langs = [lang for lang in _LANG if lang != self.config.lang] 91 | for lang in additional_langs: 92 | # urls_to_download = {"train": [f"{_URL}{lang}.train.json", f"{_URL}{lang}.train.zip"]} 93 | # additional_downloaded_files = dl_manager.download_and_extract(urls_to_download) 94 | # train_files_for_many_langs.append(additional_downloaded_files["train"]) 95 | train_files_for_many_langs.append([file_dir+f"{lang}.train.json", file_dir+f"{lang}"]) 96 | 97 | 98 | logger.info(f"Training on {self.config.lang} with additional langs({self.config.additional_langs})") 99 | logger.info(f"Evaluating on {self.config.lang}") 100 | logger.info(f"Testing on {self.config.lang}") 101 | return [ 102 | datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": train_files_for_many_langs}), 103 | datasets.SplitGenerator( 104 | name=datasets.Split.VALIDATION, gen_kwargs={"filepaths": val_files_for_many_langs} 105 | ), 
106 | # datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepaths": test_files_for_many_langs}), 107 | ] 108 | 109 | def _generate_examples(self, filepaths): 110 | for filepath in filepaths: 111 | logger.info("Generating examples from = %s", filepath) 112 | with open(filepath[0], "r") as f: 113 | data = json.load(f) 114 | 115 | for doc in data["documents"]: 116 | doc["img"]["fpath"] = os.path.join(filepath[1], doc["img"]["fname"]) 117 | image, size = load_image(doc["img"]["fpath"]) 118 | document = doc["document"] 119 | tokenized_doc = {"input_ids": [], "bbox": [], "labels": []} 120 | entities = [] 121 | relations = [] 122 | id2label = {} 123 | entity_id_to_index_map = {} 124 | empty_entity = set() 125 | for line in document: 126 | if len(line["text"]) == 0: 127 | empty_entity.add(line["id"]) 128 | continue 129 | id2label[line["id"]] = line["label"] 130 | relations.extend([tuple(sorted(l)) for l in line["linking"]]) 131 | if '/en' in filepath[0]: 132 | tokenized_inputs = self.tokenizer( 133 | ' '.join([q['text'].replace(u'\uf703','') for q in line['words']]), 134 | add_special_tokens=False, 135 | return_offsets_mapping=True, 136 | return_attention_mask=False, 137 | ) 138 | else: 139 | tokenized_inputs = self.tokenizer( 140 | line["text"], 141 | add_special_tokens=False, 142 | return_offsets_mapping=True, 143 | return_attention_mask=False, 144 | ) 145 | text_length = 0 146 | ocr_length = 0 147 | bbox = [] 148 | last_box = None 149 | for token_id, offset in zip(tokenized_inputs["input_ids"], tokenized_inputs["offset_mapping"]): 150 | if token_id == 6: 151 | bbox.append(None) 152 | continue 153 | text_length += offset[1] - offset[0] 154 | tmp_box = [] 155 | while ocr_length < text_length: 156 | ocr_word = line["words"].pop(0) 157 | ocr_length += len( 158 | self.tokenizer._tokenizer.normalizer.normalize_str(ocr_word["text"].strip()) 159 | ) 160 | tmp_box.append(simplify_bbox(line["box"])) 161 | if len(tmp_box) == 0: 162 | tmp_box = last_box 163 | bbox.append(normalize_bbox(merge_bbox(tmp_box), size)) 164 | last_box = tmp_box 165 | bbox = [ 166 | [bbox[i + 1][0], bbox[i + 1][1], bbox[i + 1][0], bbox[i + 1][1]] if b is None else b 167 | for i, b in enumerate(bbox) 168 | ] 169 | if line["label"] == "other": 170 | label = ["O"] * len(bbox) 171 | else: 172 | label = [f"I-{line['label'].upper()}"] * len(bbox) 173 | label[0] = f"B-{line['label'].upper()}" 174 | tokenized_inputs.update({"bbox": bbox, "labels": label}) 175 | if label[0] != "O": 176 | entity_id_to_index_map[line["id"]] = len(entities) 177 | entities.append( 178 | { 179 | "start": len(tokenized_doc["input_ids"]), 180 | "end": len(tokenized_doc["input_ids"]) + len(tokenized_inputs["input_ids"]), 181 | "label": line["label"].upper(), 182 | } 183 | ) 184 | for i in tokenized_doc: 185 | tokenized_doc[i] = tokenized_doc[i] + tokenized_inputs[i] 186 | relations = list(set(relations)) 187 | relations = [rel for rel in relations if rel[0] not in empty_entity and rel[1] not in empty_entity] 188 | kvrelations = [] 189 | for rel in relations: 190 | pair = [id2label[rel[0]], id2label[rel[1]]] 191 | if pair == ["question", "answer"]: 192 | kvrelations.append( 193 | {"head": entity_id_to_index_map[rel[0]], "tail": entity_id_to_index_map[rel[1]]} 194 | ) 195 | elif pair == ["answer", "question"]: 196 | kvrelations.append( 197 | {"head": entity_id_to_index_map[rel[1]], "tail": entity_id_to_index_map[rel[0]]} 198 | ) 199 | else: 200 | continue 201 | 202 | def get_relation_span(rel): 203 | bound = [] 204 | for entity_index in [rel["head"], 
rel["tail"]]: 205 | bound.append(entities[entity_index]["start"]) 206 | bound.append(entities[entity_index]["end"]) 207 | return min(bound), max(bound) 208 | 209 | relations = sorted( 210 | [ 211 | { 212 | "head": rel["head"], 213 | "tail": rel["tail"], 214 | "start_index": get_relation_span(rel)[0], 215 | "end_index": get_relation_span(rel)[1], 216 | } 217 | for rel in kvrelations 218 | ], 219 | key=lambda x: x["head"], 220 | ) 221 | chunk_size = 512 222 | for chunk_id, index in enumerate(range(0, len(tokenized_doc["input_ids"]), chunk_size)): 223 | item = {} 224 | for k in tokenized_doc: 225 | item[k] = tokenized_doc[k][index : index + chunk_size] 226 | entities_in_this_span = [] 227 | global_to_local_map = {} 228 | for entity_id, entity in enumerate(entities): 229 | if ( 230 | index <= entity["start"] < index + chunk_size 231 | and index <= entity["end"] < index + chunk_size 232 | ): 233 | entity["start"] = entity["start"] - index 234 | entity["end"] = entity["end"] - index 235 | global_to_local_map[entity_id] = len(entities_in_this_span) 236 | entities_in_this_span.append(entity) 237 | relations_in_this_span = [] 238 | for relation in relations: 239 | if ( 240 | index <= relation["start_index"] < index + chunk_size 241 | and index <= relation["end_index"] < index + chunk_size 242 | ): 243 | relations_in_this_span.append( 244 | { 245 | "head": global_to_local_map[relation["head"]], 246 | "tail": global_to_local_map[relation["tail"]], 247 | "start_index": relation["start_index"] - index, 248 | "end_index": relation["end_index"] - index, 249 | } 250 | ) 251 | item.update( 252 | { 253 | "id": f"{doc['id']}_{chunk_id}", 254 | "image": image, 255 | "entities": entities_in_this_span, 256 | "relations": relations_in_this_span, 257 | } 258 | ) 259 | yield f"{doc['id']}_{chunk_id}", item 260 | -------------------------------------------------------------------------------- /LiLTfinetune/data/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from detectron2.data.detection_utils import read_image 4 | from detectron2.data.transforms import ResizeTransform, TransformList 5 | 6 | 7 | def normalize_bbox(bbox, size): 8 | return [ 9 | int(1000 * bbox[0] / size[0]), 10 | int(1000 * bbox[1] / size[1]), 11 | int(1000 * bbox[2] / size[0]), 12 | int(1000 * bbox[3] / size[1]), 13 | ] 14 | 15 | 16 | def simplify_bbox(bbox): 17 | return [ 18 | min(bbox[0::2]), 19 | min(bbox[1::2]), 20 | max(bbox[2::2]), 21 | max(bbox[3::2]), 22 | ] 23 | 24 | 25 | def merge_bbox(bbox_list): 26 | x0, y0, x1, y1 = list(zip(*bbox_list)) 27 | return [min(x0), min(y0), max(x1), max(y1)] 28 | 29 | 30 | def load_image(image_path): 31 | image = read_image(image_path, format="BGR") 32 | h = image.shape[0] 33 | w = image.shape[1] 34 | img_trans = TransformList([ResizeTransform(h=h, w=w, new_h=224, new_w=224)]) 35 | image = torch.tensor(img_trans.apply_image(image).copy()).permute(2, 0, 1) # copy to make it writeable 36 | return image, (w, h) 37 | -------------------------------------------------------------------------------- /LiLTfinetune/evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | import numpy as np 5 | 6 | from transformers.utils import logging 7 | 8 | 9 | logger = logging.get_logger(__name__) 10 | 11 | 12 | PREFIX_CHECKPOINT_DIR = "checkpoint" 13 | _re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$") 14 | 15 | 16 | def get_last_checkpoint(folder): 17 | content = 
os.listdir(folder) 18 | checkpoints = [ 19 | path 20 | for path in content 21 | if _re_checkpoint.search(path) is not None and os.path.isdir(os.path.join(folder, path)) 22 | ] 23 | if len(checkpoints) == 0: 24 | return 25 | return os.path.join(folder, max(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0]))) 26 | 27 | 28 | def re_score(pred_relations, gt_relations, mode="strict"): 29 | """Evaluate RE predictions 30 | 31 | Args: 32 | pred_relations (list) : list of list of predicted relations (several relations in each sentence) 33 | gt_relations (list) : list of list of ground truth relations 34 | 35 | rel = { "head": (start_idx (inclusive), end_idx (exclusive)), 36 | "tail": (start_idx (inclusive), end_idx (exclusive)), 37 | "head_type": ent_type, 38 | "tail_type": ent_type, 39 | "type": rel_type} 40 | 41 | vocab (Vocab) : dataset vocabulary 42 | mode (str) : in 'strict' or 'boundaries'""" 43 | 44 | assert mode in ["strict", "boundaries"] 45 | 46 | relation_types = [v for v in [0, 1] if not v == 0] 47 | scores = {rel: {"tp": 0, "fp": 0, "fn": 0} for rel in relation_types + ["ALL"]} 48 | 49 | # Count GT relations and Predicted relations 50 | n_sents = len(gt_relations) 51 | n_rels = sum([len([rel for rel in sent]) for sent in gt_relations]) 52 | n_found = sum([len([rel for rel in sent]) for sent in pred_relations]) 53 | 54 | # Count TP, FP and FN per type 55 | for pred_sent, gt_sent in zip(pred_relations, gt_relations): 56 | for rel_type in relation_types: 57 | # strict mode takes argument types into account 58 | if mode == "strict": 59 | pred_rels = { 60 | (rel["head"], rel["head_type"], rel["tail"], rel["tail_type"]) 61 | for rel in pred_sent 62 | if rel["type"] == rel_type 63 | } 64 | gt_rels = { 65 | (rel["head"], rel["head_type"], rel["tail"], rel["tail_type"]) 66 | for rel in gt_sent 67 | if rel["type"] == rel_type 68 | } 69 | 70 | # boundaries mode only takes argument spans into account 71 | elif mode == "boundaries": 72 | pred_rels = {(rel["head"], rel["tail"]) for rel in pred_sent if rel["type"] == rel_type} 73 | gt_rels = {(rel["head"], rel["tail"]) for rel in gt_sent if rel["type"] == rel_type} 74 | 75 | scores[rel_type]["tp"] += len(pred_rels & gt_rels) 76 | scores[rel_type]["fp"] += len(pred_rels - gt_rels) 77 | scores[rel_type]["fn"] += len(gt_rels - pred_rels) 78 | 79 | # Compute per entity Precision / Recall / F1 80 | for rel_type in scores.keys(): 81 | if scores[rel_type]["tp"]: 82 | scores[rel_type]["p"] = scores[rel_type]["tp"] / (scores[rel_type]["fp"] + scores[rel_type]["tp"]) 83 | scores[rel_type]["r"] = scores[rel_type]["tp"] / (scores[rel_type]["fn"] + scores[rel_type]["tp"]) 84 | else: 85 | scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0 86 | 87 | if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0: 88 | scores[rel_type]["f1"] = ( 89 | 2 * scores[rel_type]["p"] * scores[rel_type]["r"] / (scores[rel_type]["p"] + scores[rel_type]["r"]) 90 | ) 91 | else: 92 | scores[rel_type]["f1"] = 0 93 | 94 | # Compute micro F1 Scores 95 | tp = sum([scores[rel_type]["tp"] for rel_type in relation_types]) 96 | fp = sum([scores[rel_type]["fp"] for rel_type in relation_types]) 97 | fn = sum([scores[rel_type]["fn"] for rel_type in relation_types]) 98 | 99 | if tp: 100 | precision = tp / (tp + fp) 101 | recall = tp / (tp + fn) 102 | f1 = 2 * precision * recall / (precision + recall) 103 | 104 | else: 105 | precision, recall, f1 = 0, 0, 0 106 | 107 | scores["ALL"]["p"] = precision 108 | scores["ALL"]["r"] = recall 109 | scores["ALL"]["f1"] = f1 110 | 
scores["ALL"]["tp"] = tp 111 | scores["ALL"]["fp"] = fp 112 | scores["ALL"]["fn"] = fn 113 | 114 | # Compute Macro F1 Scores 115 | scores["ALL"]["Macro_f1"] = np.mean([scores[ent_type]["f1"] for ent_type in relation_types]) 116 | scores["ALL"]["Macro_p"] = np.mean([scores[ent_type]["p"] for ent_type in relation_types]) 117 | scores["ALL"]["Macro_r"] = np.mean([scores[ent_type]["r"] for ent_type in relation_types]) 118 | 119 | logger.info(f"RE Evaluation in *** {mode.upper()} *** mode") 120 | 121 | logger.info( 122 | "processed {} sentences with {} relations; found: {} relations; correct: {}.".format( 123 | n_sents, n_rels, n_found, tp 124 | ) 125 | ) 126 | logger.info( 127 | "\tALL\t TP: {};\tFP: {};\tFN: {}".format(scores["ALL"]["tp"], scores["ALL"]["fp"], scores["ALL"]["fn"]) 128 | ) 129 | logger.info("\t\t(m avg): precision: {:.2f};\trecall: {:.2f};\tf1: {:.2f} (micro)".format(precision, recall, f1)) 130 | logger.info( 131 | "\t\t(M avg): precision: {:.2f};\trecall: {:.2f};\tf1: {:.2f} (Macro)\n".format( 132 | scores["ALL"]["Macro_p"], scores["ALL"]["Macro_r"], scores["ALL"]["Macro_f1"] 133 | ) 134 | ) 135 | 136 | for rel_type in relation_types: 137 | logger.info( 138 | "\t{}: \tTP: {};\tFP: {};\tFN: {};\tprecision: {:.2f};\trecall: {:.2f};\tf1: {:.2f};\t{}".format( 139 | rel_type, 140 | scores[rel_type]["tp"], 141 | scores[rel_type]["fp"], 142 | scores[rel_type]["fn"], 143 | scores[rel_type]["p"], 144 | scores[rel_type]["r"], 145 | scores[rel_type]["f1"], 146 | scores[rel_type]["tp"] + scores[rel_type]["fp"], 147 | ) 148 | ) 149 | 150 | return scores 151 | -------------------------------------------------------------------------------- /LiLTfinetune/models/LiLTRobertaLike/__init__.py: -------------------------------------------------------------------------------- 1 | from .configuration_LiLTRobertaLike import LiLTRobertaLikeConfig 2 | from .modeling_LiLTRobertaLike import LiLTRobertaLikeForRelationExtraction, LiLTRobertaLikeForTokenClassification, LiLTRobertaLikeModel 3 | from .tokenization_LiLTRobertaLike import LiLTRobertaLikeTokenizer 4 | from .tokenization_LiLTRobertaLike_fast import LiLTRobertaLikeTokenizerFast 5 | -------------------------------------------------------------------------------- /LiLTfinetune/models/LiLTRobertaLike/configuration_LiLTRobertaLike.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from collections import OrderedDict 3 | from typing import Any, List, Mapping, Optional 4 | 5 | from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType 6 | 7 | from transformers.utils import logging 8 | from transformers import RobertaConfig, XLMRobertaConfig 9 | 10 | logger = logging.get_logger(__name__) 11 | 12 | with open('tag.txt', 'r') as tagf: 13 | TAG = tagf.read().lower() 14 | assert TAG == 'monolingual' or TAG == 'multilingual', 'TAG is wrong. It should be monolingual or multilingual.' 
15 | 16 | if TAG == 'monolingual': 17 | class LiLTRobertaLikeConfig(RobertaConfig): 18 | model_type = "liltrobertalike" 19 | 20 | def __init__( 21 | self, 22 | channel_shrink_ratio=4, 23 | max_2d_position_embeddings=1024, 24 | **kwargs 25 | ): 26 | super().__init__( 27 | **kwargs, 28 | ) 29 | self.channel_shrink_ratio = channel_shrink_ratio 30 | self.max_2d_position_embeddings = max_2d_position_embeddings 31 | 32 | elif TAG == 'multilingual': 33 | class LiLTRobertaLikeConfig(XLMRobertaConfig): 34 | model_type = "liltrobertalike" 35 | 36 | def __init__( 37 | self, 38 | channel_shrink_ratio=4, 39 | max_2d_position_embeddings=1024, 40 | **kwargs 41 | ): 42 | super().__init__( 43 | **kwargs, 44 | ) 45 | self.channel_shrink_ratio = channel_shrink_ratio 46 | self.max_2d_position_embeddings = max_2d_position_embeddings 47 | -------------------------------------------------------------------------------- /LiLTfinetune/models/LiLTRobertaLike/modeling_LiLTRobertaLike.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.utils.checkpoint 6 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 7 | from transformers.activations import ACT2FN, gelu 8 | from transformers.file_utils import ( 9 | add_code_sample_docstrings, 10 | add_start_docstrings, 11 | add_start_docstrings_to_model_forward, 12 | replace_return_docstrings, 13 | ) 14 | from transformers.modeling_outputs import ( 15 | BaseModelOutputWithPastAndCrossAttentions, 16 | BaseModelOutputWithPoolingAndCrossAttentions, 17 | CausalLMOutputWithCrossAttentions, 18 | MaskedLMOutput, 19 | MultipleChoiceModelOutput, 20 | QuestionAnsweringModelOutput, 21 | SequenceClassifierOutput, 22 | TokenClassifierOutput, 23 | ) 24 | from transformers.modeling_utils import ( 25 | PreTrainedModel, 26 | apply_chunking_to_forward, 27 | find_pruneable_heads_and_indices, 28 | prune_linear_layer, 29 | ) 30 | from transformers.utils import logging 31 | from .configuration_LiLTRobertaLike import LiLTRobertaLikeConfig 32 | 33 | logger = logging.get_logger(__name__) 34 | 35 | class LiLTRobertaLikeTextEmbeddings(nn.Module): 36 | def __init__(self, config): 37 | super().__init__() 38 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) 39 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 40 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 41 | 42 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 43 | # any TensorFlow checkpoint file 44 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 45 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 46 | 47 | # position_ids (1, len position emb) is contiguous in memory and exported when serialized 48 | self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) 49 | self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") 50 | 51 | # End copy 52 | self.padding_idx = config.pad_token_id 53 | self.position_embeddings = nn.Embedding( 54 | config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx 55 | ) 56 | 57 | def forward( 58 | self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 59 | ): 60 | if position_ids is None: 61 | if input_ids is 
not None: 62 | # Create the position ids from the input token ids. Any padded tokens remain padded. 63 | position_ids = create_position_ids_from_input_ids( 64 | input_ids, self.padding_idx, past_key_values_length 65 | ).to(input_ids.device) 66 | else: 67 | position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) 68 | 69 | if input_ids is not None: 70 | input_shape = input_ids.size() 71 | else: 72 | input_shape = inputs_embeds.size()[:-1] 73 | 74 | if token_type_ids is None: 75 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) 76 | 77 | if inputs_embeds is None: 78 | inputs_embeds = self.word_embeddings(input_ids) 79 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 80 | 81 | embeddings = inputs_embeds + token_type_embeddings 82 | if self.position_embedding_type == "absolute": 83 | position_embeddings = self.position_embeddings(position_ids) 84 | embeddings += position_embeddings 85 | embeddings = self.LayerNorm(embeddings) 86 | embeddings = self.dropout(embeddings) 87 | return embeddings, position_ids 88 | 89 | def create_position_ids_from_inputs_embeds(self, inputs_embeds): 90 | """ 91 | We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. 92 | Args: 93 | inputs_embeds: torch.Tensor 94 | Returns: torch.Tensor 95 | """ 96 | input_shape = inputs_embeds.size()[:-1] 97 | sequence_length = input_shape[1] 98 | 99 | position_ids = torch.arange( 100 | self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device 101 | ) 102 | return position_ids.unsqueeze(0).expand(input_shape) 103 | 104 | 105 | class LiLTRobertaLikeLayoutEmbeddings(nn.Module): 106 | def __init__(self, config): 107 | super(LiLTRobertaLikeLayoutEmbeddings, self).__init__() 108 | self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size // 6) 109 | self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size // 6) 110 | self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size // 6) 111 | self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size // 6) 112 | 113 | self.padding_idx = config.pad_token_id 114 | self.box_position_embeddings = nn.Embedding( 115 | config.max_position_embeddings, config.hidden_size//config.channel_shrink_ratio, padding_idx=self.padding_idx 116 | ) 117 | self.box_linear_embeddings = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size//config.channel_shrink_ratio) 118 | self.LayerNorm = nn.LayerNorm(config.hidden_size//config.channel_shrink_ratio, eps=config.layer_norm_eps) 119 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 120 | 121 | def forward( 122 | self, 123 | bbox=None, 124 | position_ids=None, 125 | ): 126 | 127 | try: 128 | left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0]) 129 | upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1]) 130 | right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2]) 131 | lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3]) 132 | except IndexError as e: 133 | raise IndexError("The :obj:`bbox`coordinate values should be within 0-1000 range.") from e 134 | 135 | h_position_embeddings = self.h_position_embeddings(bbox[:, :, 3] - bbox[:, :, 1]) 136 | w_position_embeddings = self.w_position_embeddings(bbox[:, :, 2] - bbox[:, :, 0]) 137 | 138 | 
spatial_position_embeddings = torch.cat( 139 | [ 140 | left_position_embeddings, 141 | upper_position_embeddings, 142 | right_position_embeddings, 143 | lower_position_embeddings, 144 | h_position_embeddings, 145 | w_position_embeddings, 146 | ], 147 | dim=-1, 148 | ) 149 | spatial_position_embeddings = self.box_linear_embeddings(spatial_position_embeddings) 150 | box_position_embeddings = self.box_position_embeddings(position_ids) 151 | 152 | spatial_position_embeddings = spatial_position_embeddings + box_position_embeddings 153 | 154 | spatial_position_embeddings = self.LayerNorm(spatial_position_embeddings) 155 | spatial_position_embeddings = self.dropout(spatial_position_embeddings) 156 | 157 | return spatial_position_embeddings 158 | 159 | 160 | class LiLTRobertaLikeSelfAttention(nn.Module): 161 | def __init__(self, config): 162 | super().__init__() 163 | if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): 164 | raise ValueError( 165 | f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " 166 | f"heads ({config.num_attention_heads})" 167 | ) 168 | 169 | self.num_attention_heads = config.num_attention_heads 170 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 171 | self.all_head_size = self.num_attention_heads * self.attention_head_size 172 | 173 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 174 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 175 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 176 | 177 | self.layout_query = nn.Linear(config.hidden_size // config.channel_shrink_ratio, self.all_head_size // config.channel_shrink_ratio) 178 | self.layout_key = nn.Linear(config.hidden_size // config.channel_shrink_ratio, self.all_head_size // config.channel_shrink_ratio) 179 | self.layout_value = nn.Linear(config.hidden_size // config.channel_shrink_ratio, self.all_head_size // config.channel_shrink_ratio) 180 | 181 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 182 | self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") 183 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": 184 | self.max_position_embeddings = config.max_position_embeddings 185 | self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) 186 | 187 | self.is_decoder = config.is_decoder 188 | self.channel_shrink_ratio = config.channel_shrink_ratio 189 | 190 | def transpose_for_scores(self, x, r=1): 191 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size//r) 192 | x = x.view(*new_x_shape) 193 | return x.permute(0, 2, 1, 3) 194 | 195 | def forward( 196 | self, 197 | hidden_states, 198 | layout_inputs, 199 | attention_mask=None, 200 | head_mask=None, 201 | encoder_hidden_states=None, 202 | encoder_attention_mask=None, 203 | past_key_value=None, 204 | output_attentions=False, 205 | ): 206 | 207 | layout_value_layer = self.transpose_for_scores(self.layout_value(layout_inputs), r=self.channel_shrink_ratio) 208 | layout_key_layer = self.transpose_for_scores(self.layout_key(layout_inputs), r=self.channel_shrink_ratio) 209 | layout_query_layer = self.transpose_for_scores(self.layout_query(layout_inputs), r=self.channel_shrink_ratio) 210 | 211 | mixed_query_layer = self.query(hidden_states) 212 | 213 | # If this is instantiated as a cross-attention module, the keys 214 | # and values come 
from an encoder; the attention mask needs to be 215 | # such that the encoder's padding tokens are not attended to. 216 | is_cross_attention = encoder_hidden_states is not None 217 | 218 | if is_cross_attention and past_key_value is not None: 219 | # reuse k,v, cross_attentions 220 | key_layer = past_key_value[0] 221 | value_layer = past_key_value[1] 222 | attention_mask = encoder_attention_mask 223 | elif is_cross_attention: 224 | key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) 225 | value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) 226 | attention_mask = encoder_attention_mask 227 | elif past_key_value is not None: 228 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 229 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 230 | key_layer = torch.cat([past_key_value[0], key_layer], dim=2) 231 | value_layer = torch.cat([past_key_value[1], value_layer], dim=2) 232 | else: 233 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 234 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 235 | 236 | query_layer = self.transpose_for_scores(mixed_query_layer) 237 | 238 | if self.is_decoder: 239 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 240 | # Further calls to cross_attention layer can then reuse all cross-attention 241 | # key/value_states (first "if" case) 242 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 243 | # all previous decoder key/value_states. Further calls to uni-directional self-attention 244 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 245 | # if encoder bi-directional self-attention `past_key_value` is always `None` 246 | past_key_value = (key_layer, value_layer) 247 | 248 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 249 | layout_attention_scores = torch.matmul(layout_query_layer, layout_key_layer.transpose(-1, -2)) 250 | 251 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": 252 | seq_length = hidden_states.size()[1] 253 | position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) 254 | position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) 255 | distance = position_ids_l - position_ids_r 256 | positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) 257 | positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility 258 | 259 | if self.position_embedding_type == "relative_key": 260 | relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) 261 | attention_scores = attention_scores + relative_position_scores 262 | elif self.position_embedding_type == "relative_key_query": 263 | relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) 264 | relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) 265 | attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key 266 | 267 | tmp_attention_scores = attention_scores / math.sqrt(self.attention_head_size) 268 | tmp_layout_attention_scores = layout_attention_scores / math.sqrt(self.attention_head_size//self.channel_shrink_ratio) 269 | attention_scores = tmp_attention_scores + 
tmp_layout_attention_scores 270 | layout_attention_scores = tmp_layout_attention_scores + tmp_attention_scores 271 | 272 | if attention_mask is not None: 273 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 274 | layout_attention_scores = layout_attention_scores + attention_mask 275 | 276 | # Normalize the attention scores to probabilities. 277 | layout_attention_probs = nn.Softmax(dim=-1)(layout_attention_scores) 278 | 279 | # This is actually dropping out entire tokens to attend to, which might 280 | # seem a bit unusual, but is taken from the original Transformer paper. 281 | layout_attention_probs = self.dropout(layout_attention_probs) 282 | 283 | # Mask heads if we want to 284 | if head_mask is not None: 285 | layout_attention_probs = layout_attention_probs * head_mask 286 | 287 | layout_context_layer = torch.matmul(layout_attention_probs, layout_value_layer) 288 | 289 | layout_context_layer = layout_context_layer.permute(0, 2, 1, 3).contiguous() 290 | new_context_layer_shape = layout_context_layer.size()[:-2] + (self.all_head_size//self.channel_shrink_ratio,) 291 | layout_context_layer = layout_context_layer.view(*new_context_layer_shape) 292 | 293 | if attention_mask is not None: 294 | # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function) 295 | attention_scores = attention_scores + attention_mask 296 | 297 | # Normalize the attention scores to probabilities. 298 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 299 | 300 | # This is actually dropping out entire tokens to attend to, which might 301 | # seem a bit unusual, but is taken from the original Transformer paper. 302 | attention_probs = self.dropout(attention_probs) 303 | 304 | # Mask heads if we want to 305 | if head_mask is not None: 306 | attention_probs = attention_probs * head_mask 307 | 308 | context_layer = torch.matmul(attention_probs, value_layer) 309 | 310 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 311 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 312 | context_layer = context_layer.view(*new_context_layer_shape) 313 | 314 | outputs = ((context_layer, layout_context_layer), attention_probs) if output_attentions else ((context_layer, layout_context_layer),) 315 | 316 | if self.is_decoder: 317 | outputs = outputs + (past_key_value,) 318 | return outputs 319 | 320 | 321 | class LiLTRobertaLikeSelfOutput(nn.Module): 322 | def __init__(self, config): 323 | super().__init__() 324 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 325 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 326 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 327 | 328 | def forward(self, hidden_states, input_tensor): 329 | hidden_states = self.dense(hidden_states) 330 | hidden_states = self.dropout(hidden_states) 331 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 332 | return hidden_states 333 | 334 | 335 | class LiLTRobertaLikeAttention(nn.Module): 336 | def __init__(self, config): 337 | super().__init__() 338 | self.self = LiLTRobertaLikeSelfAttention(config) 339 | self.output = LiLTRobertaLikeSelfOutput(config) 340 | self.pruned_heads = set() 341 | 342 | ori_hidden_size = config.hidden_size 343 | config.hidden_size = config.hidden_size // config.channel_shrink_ratio 344 | self.layout_output = LiLTRobertaLikeSelfOutput(config) 345 | config.hidden_size = ori_hidden_size 346 | 347 | def prune_heads(self, heads): 348 | if len(heads) == 0: 
349 | return 350 | heads, index = find_pruneable_heads_and_indices( 351 | heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads 352 | ) 353 | 354 | # Prune linear layers 355 | self.self.query = prune_linear_layer(self.self.query, index) 356 | self.self.key = prune_linear_layer(self.self.key, index) 357 | self.self.value = prune_linear_layer(self.self.value, index) 358 | self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) 359 | 360 | # Update hyper params and store pruned heads 361 | self.self.num_attention_heads = self.self.num_attention_heads - len(heads) 362 | self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads 363 | self.pruned_heads = self.pruned_heads.union(heads) 364 | 365 | def forward( 366 | self, 367 | hidden_states, 368 | layout_inputs, 369 | attention_mask=None, 370 | head_mask=None, 371 | encoder_hidden_states=None, 372 | encoder_attention_mask=None, 373 | past_key_value=None, 374 | output_attentions=False, 375 | ): 376 | self_outputs = self.self( 377 | hidden_states, 378 | layout_inputs, 379 | attention_mask, 380 | head_mask, 381 | encoder_hidden_states, 382 | encoder_attention_mask, 383 | past_key_value, 384 | output_attentions, 385 | ) 386 | attention_output = self.output(self_outputs[0][0], hidden_states) 387 | layout_attention_output = self.layout_output(self_outputs[0][1], layout_inputs) 388 | 389 | outputs = ((attention_output, layout_attention_output),) + self_outputs[1:] # add attentions if we output them 390 | return outputs 391 | 392 | 393 | class LiLTRobertaLikeIntermediate(nn.Module): 394 | def __init__(self, config): 395 | super().__init__() 396 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 397 | if isinstance(config.hidden_act, str): 398 | self.intermediate_act_fn = ACT2FN[config.hidden_act] 399 | else: 400 | self.intermediate_act_fn = config.hidden_act 401 | 402 | def forward(self, hidden_states): 403 | hidden_states = self.dense(hidden_states) 404 | hidden_states = self.intermediate_act_fn(hidden_states) 405 | return hidden_states 406 | 407 | 408 | class LiLTRobertaLikeOutput(nn.Module): 409 | def __init__(self, config): 410 | super().__init__() 411 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 412 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 413 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 414 | 415 | def forward(self, hidden_states, input_tensor): 416 | hidden_states = self.dense(hidden_states) 417 | hidden_states = self.dropout(hidden_states) 418 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 419 | return hidden_states 420 | 421 | 422 | class LiLTRobertaLikeLayer(nn.Module): 423 | def __init__(self, config): 424 | super().__init__() 425 | self.chunk_size_feed_forward = config.chunk_size_feed_forward 426 | self.seq_len_dim = 1 427 | self.attention = LiLTRobertaLikeAttention(config) 428 | self.is_decoder = config.is_decoder 429 | self.add_cross_attention = config.add_cross_attention 430 | if self.add_cross_attention: 431 | assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added" 432 | self.crossattention = LiLTRobertaLikeAttention(config) 433 | self.intermediate = LiLTRobertaLikeIntermediate(config) 434 | self.output = LiLTRobertaLikeOutput(config) 435 | 436 | ori_hidden_size = config.hidden_size 437 | ori_intermediate_size = config.intermediate_size 438 | config.hidden_size = config.hidden_size // config.channel_shrink_ratio 
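# NOTE: `channel_shrink_ratio` temporarily narrows `config.hidden_size`/`config.intermediate_size` while the layout-flow sub-layers are built, so the layout stream runs at a fraction of the text stream's width; the original sizes are restored right after.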
439 | config.intermediate_size = config.intermediate_size // config.channel_shrink_ratio 440 | self.layout_intermediate = LiLTRobertaLikeIntermediate(config) 441 | self.layout_output = LiLTRobertaLikeOutput(config) 442 | config.hidden_size = ori_hidden_size 443 | config.intermediate_size = ori_intermediate_size 444 | 445 | def forward( 446 | self, 447 | hidden_states, 448 | layout_inputs, 449 | attention_mask=None, 450 | head_mask=None, 451 | encoder_hidden_states=None, 452 | encoder_attention_mask=None, 453 | past_key_value=None, 454 | output_attentions=False, 455 | ): 456 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 457 | self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None 458 | self_attention_outputs = self.attention( 459 | hidden_states, 460 | layout_inputs, 461 | attention_mask, 462 | head_mask, 463 | output_attentions=output_attentions, 464 | past_key_value=self_attn_past_key_value, 465 | ) 466 | attention_output = self_attention_outputs[0][0] 467 | layout_attention_output = self_attention_outputs[0][1] 468 | 469 | # if decoder, the last output is tuple of self-attn cache 470 | if self.is_decoder: 471 | outputs = self_attention_outputs[1:-1] 472 | present_key_value = self_attention_outputs[-1] 473 | else: 474 | outputs = self_attention_outputs[1:] # add self attentions if we output attention weights 475 | 476 | cross_attn_present_key_value = None 477 | if self.is_decoder and encoder_hidden_states is not None: 478 | assert hasattr( 479 | self, "crossattention" 480 | ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`" 481 | 482 | # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple 483 | cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None 484 | cross_attention_outputs = self.crossattention( 485 | attention_output, 486 | attention_mask, 487 | head_mask, 488 | encoder_hidden_states, 489 | encoder_attention_mask, 490 | cross_attn_past_key_value, 491 | output_attentions, 492 | ) 493 | attention_output = cross_attention_outputs[0] 494 | outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights 495 | 496 | # add cross-attn cache to positions 3,4 of present_key_value tuple 497 | cross_attn_present_key_value = cross_attention_outputs[-1] 498 | present_key_value = present_key_value + cross_attn_present_key_value 499 | 500 | layer_output = apply_chunking_to_forward( 501 | self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output 502 | ) 503 | 504 | layout_layer_output = apply_chunking_to_forward( 505 | self.layout_feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, layout_attention_output 506 | ) 507 | 508 | outputs = ((layer_output, layout_layer_output),) + outputs 509 | 510 | # if decoder, return the attn key/values as the last output 511 | if self.is_decoder: 512 | outputs = outputs + (present_key_value,) 513 | 514 | return outputs 515 | 516 | def feed_forward_chunk(self, attention_output): 517 | intermediate_output = self.intermediate(attention_output) 518 | layer_output = self.output(intermediate_output, attention_output) 519 | return layer_output 520 | 521 | def layout_feed_forward_chunk(self, attention_output): 522 | intermediate_output = self.layout_intermediate(attention_output) 523 | layer_output = self.layout_output(intermediate_output, attention_output) 524 | 
return layer_output 525 | 526 | 527 | class LiLTRobertaLikeEncoder(nn.Module): 528 | def __init__(self, config): 529 | super().__init__() 530 | self.config = config 531 | self.layer = nn.ModuleList([LiLTRobertaLikeLayer(config) for _ in range(config.num_hidden_layers)]) 532 | 533 | def forward( 534 | self, 535 | hidden_states, 536 | layout_inputs, 537 | attention_mask=None, 538 | head_mask=None, 539 | encoder_hidden_states=None, 540 | encoder_attention_mask=None, 541 | past_key_values=None, 542 | use_cache=None, 543 | output_attentions=False, 544 | output_hidden_states=False, 545 | return_dict=True, 546 | ): 547 | all_hidden_states = () if output_hidden_states else None 548 | all_self_attentions = () if output_attentions else None 549 | all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None 550 | 551 | next_decoder_cache = () if use_cache else None 552 | for i, layer_module in enumerate(self.layer): 553 | if output_hidden_states: 554 | all_hidden_states = all_hidden_states + (hidden_states,) 555 | 556 | layer_head_mask = head_mask[i] if head_mask is not None else None 557 | past_key_value = past_key_values[i] if past_key_values is not None else None 558 | 559 | if getattr(self.config, "gradient_checkpointing", False) and self.training: 560 | 561 | if use_cache: 562 | logger.warning( 563 | "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " 564 | "`use_cache=False`..." 565 | ) 566 | use_cache = False 567 | 568 | def create_custom_forward(module): 569 | def custom_forward(*inputs): 570 | return module(*inputs, past_key_value, output_attentions) 571 | 572 | return custom_forward 573 | 574 | layer_outputs = torch.utils.checkpoint.checkpoint( 575 | create_custom_forward(layer_module), 576 | hidden_states, 577 | layout_inputs, 578 | attention_mask, 579 | layer_head_mask, 580 | encoder_hidden_states, 581 | encoder_attention_mask, 582 | ) 583 | 584 | else: 585 | layer_outputs = layer_module( 586 | hidden_states, 587 | layout_inputs, 588 | attention_mask, 589 | layer_head_mask, 590 | encoder_hidden_states, 591 | encoder_attention_mask, 592 | past_key_value, 593 | output_attentions, 594 | ) 595 | 596 | hidden_states = layer_outputs[0][0] 597 | layout_inputs = layer_outputs[0][1] 598 | 599 | if use_cache: 600 | next_decoder_cache += (layer_outputs[-1],) 601 | if output_attentions: 602 | all_self_attentions = all_self_attentions + (layer_outputs[1],) 603 | if self.config.add_cross_attention: 604 | all_cross_attentions = all_cross_attentions + (layer_outputs[2],) 605 | 606 | if output_hidden_states: 607 | all_hidden_states = all_hidden_states + (hidden_states,) 608 | 609 | if not return_dict: 610 | return tuple( 611 | v 612 | for v in [ 613 | hidden_states, 614 | next_decoder_cache, 615 | all_hidden_states, 616 | all_self_attentions, 617 | all_cross_attentions, 618 | ] 619 | if v is not None 620 | ), layout_inputs 621 | 622 | return BaseModelOutputWithPastAndCrossAttentions( 623 | last_hidden_state=hidden_states, 624 | past_key_values=next_decoder_cache, 625 | hidden_states=all_hidden_states, 626 | attentions=all_self_attentions, 627 | cross_attentions=all_cross_attentions, 628 | ), layout_inputs 629 | 630 | 631 | class LiLTRobertaLikePooler(nn.Module): 632 | def __init__(self, config): 633 | super().__init__() 634 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 635 | self.activation = nn.Tanh() 636 | 637 | def forward(self, hidden_states): 638 | # We "pool" the model by simply taking the hidden state 
corresponding 639 | # to the first token. 640 | first_token_tensor = hidden_states[:, 0] 641 | pooled_output = self.dense(first_token_tensor) 642 | pooled_output = self.activation(pooled_output) 643 | return pooled_output 644 | 645 | 646 | class LiLTRobertaLikePreTrainedModel(PreTrainedModel): 647 | """ 648 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 649 | models. 650 | """ 651 | config_class = LiLTRobertaLikeConfig 652 | base_model_prefix = "liltrobertalike" 653 | def _init_weights(self, module): 654 | """Initialize the weights""" 655 | if isinstance(module, nn.Linear): 656 | # Slightly different from the TF version which uses truncated_normal for initialization 657 | # cf https://github.com/pytorch/pytorch/pull/5617 658 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 659 | if module.bias is not None: 660 | module.bias.data.zero_() 661 | elif isinstance(module, nn.Embedding): 662 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 663 | if module.padding_idx is not None: 664 | module.weight.data[module.padding_idx].zero_() 665 | elif isinstance(module, nn.LayerNorm): 666 | module.bias.data.zero_() 667 | module.weight.data.fill_(1.0) 668 | 669 | 670 | 671 | class LiLTRobertaLikeModel(LiLTRobertaLikePreTrainedModel): 672 | 673 | _keys_to_ignore_on_load_missing = [r"position_ids"] 674 | 675 | def __init__(self, config, add_pooling_layer=True): 676 | super().__init__(config) 677 | self.config = config 678 | 679 | self.embeddings = LiLTRobertaLikeTextEmbeddings(config) 680 | self.layout_embeddings = LiLTRobertaLikeLayoutEmbeddings(config) 681 | 682 | self.encoder = LiLTRobertaLikeEncoder(config) 683 | 684 | self.pooler = LiLTRobertaLikePooler(config) if add_pooling_layer else None 685 | 686 | self.init_weights() 687 | 688 | def get_input_embeddings(self): 689 | return self.embeddings.word_embeddings 690 | 691 | def set_input_embeddings(self, value): 692 | self.embeddings.word_embeddings = value 693 | 694 | def _prune_heads(self, heads_to_prune): 695 | """ 696 | Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base 697 | class PreTrainedModel 698 | """ 699 | for layer, heads in heads_to_prune.items(): 700 | self.encoder.layer[layer].attention.prune_heads(heads) 701 | 702 | def forward( 703 | self, 704 | input_ids=None, 705 | bbox=None, 706 | attention_mask=None, 707 | token_type_ids=None, 708 | position_ids=None, 709 | head_mask=None, 710 | inputs_embeds=None, 711 | encoder_hidden_states=None, 712 | encoder_attention_mask=None, 713 | past_key_values=None, 714 | use_cache=None, 715 | output_attentions=None, 716 | output_hidden_states=None, 717 | return_dict=None, 718 | ): 719 | 720 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 721 | output_hidden_states = ( 722 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 723 | ) 724 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 725 | 726 | if self.config.is_decoder: 727 | use_cache = use_cache if use_cache is not None else self.config.use_cache 728 | else: 729 | use_cache = False 730 | 731 | if input_ids is not None and inputs_embeds is not None: 732 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 733 | elif input_ids is not None: 734 | input_shape = input_ids.size() 735 | batch_size, seq_length = input_shape 736 | elif inputs_embeds is not None: 737 | input_shape = inputs_embeds.size()[:-1] 738 | batch_size, seq_length = input_shape 739 | else: 740 | raise ValueError("You have to specify either input_ids or inputs_embeds") 741 | 742 | device = input_ids.device if input_ids is not None else inputs_embeds.device 743 | 744 | # past_key_values_length 745 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 746 | 747 | if attention_mask is None: 748 | attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) 749 | if token_type_ids is None: 750 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) 751 | 752 | if bbox is None: 753 | bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device) 754 | 755 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] 756 | # ourselves in which case we just need to make it broadcastable to all heads. 
757 | extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) 758 | 759 | # If a 2D or 3D attention mask is provided for the cross-attention 760 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 761 | if self.config.is_decoder and encoder_hidden_states is not None: 762 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 763 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 764 | if encoder_attention_mask is None: 765 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) 766 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) 767 | else: 768 | encoder_extended_attention_mask = None 769 | 770 | # Prepare head mask if needed 771 | # 1.0 in head_mask indicate we keep the head 772 | # attention_probs has shape bsz x n_heads x N x N 773 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] 774 | # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] 775 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) 776 | 777 | embedding_output, position_ids = self.embeddings( 778 | input_ids=input_ids, 779 | position_ids=position_ids, 780 | token_type_ids=token_type_ids, 781 | inputs_embeds=inputs_embeds, 782 | past_key_values_length=past_key_values_length, 783 | ) 784 | 785 | layout_embedding_output = self.layout_embeddings( 786 | bbox=bbox, 787 | position_ids=position_ids, 788 | ) 789 | 790 | encoder_outputs, layout_encoder_outputs = self.encoder( 791 | embedding_output, 792 | layout_embedding_output, 793 | attention_mask=extended_attention_mask, 794 | head_mask=head_mask, 795 | encoder_hidden_states=encoder_hidden_states, 796 | encoder_attention_mask=encoder_extended_attention_mask, 797 | past_key_values=past_key_values, 798 | use_cache=use_cache, 799 | output_attentions=output_attentions, 800 | output_hidden_states=output_hidden_states, 801 | return_dict=return_dict, 802 | ) 803 | 804 | sequence_output = encoder_outputs[0] 805 | pooled_output = self.pooler(sequence_output) if self.pooler is not None else None 806 | 807 | if not return_dict: 808 | return (sequence_output, pooled_output) + encoder_outputs[1:] 809 | 810 | return BaseModelOutputWithPoolingAndCrossAttentions( 811 | last_hidden_state=sequence_output, 812 | pooler_output=pooled_output, 813 | past_key_values=encoder_outputs.past_key_values, 814 | hidden_states=encoder_outputs.hidden_states, 815 | attentions=encoder_outputs.attentions, 816 | cross_attentions=encoder_outputs.cross_attentions, 817 | ), layout_encoder_outputs 818 | 819 | 820 | 821 | class LiLTRobertaLikeForTokenClassification(LiLTRobertaLikePreTrainedModel): 822 | _keys_to_ignore_on_load_unexpected = [r"pooler"] 823 | _keys_to_ignore_on_load_missing = [r"position_ids"] 824 | 825 | def __init__(self, config): 826 | super().__init__(config) 827 | self.num_labels = config.num_labels 828 | 829 | self.lilt = LiLTRobertaLikeModel(config, add_pooling_layer=False) 830 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 831 | 832 | self.classifier = nn.Linear(config.hidden_size + config.hidden_size//config.channel_shrink_ratio, config.num_labels) 833 | 834 | self.init_weights() 835 | 836 | def forward( 837 | self, 838 | input_ids=None, 839 | bbox=None, 840 | attention_mask=None, 841 | token_type_ids=None, 842 | position_ids=None, 843 | head_mask=None, 844 | inputs_embeds=None, 845 | labels=None, 846 | 
output_attentions=None, 847 | output_hidden_states=None, 848 | return_dict=None, 849 | ): 850 | r""" 851 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 852 | Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - 853 | 1]``. 854 | """ 855 | 856 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 857 | 858 | outputs, layout_outputs = self.lilt( 859 | input_ids, 860 | bbox=bbox, 861 | attention_mask=attention_mask, 862 | token_type_ids=token_type_ids, 863 | position_ids=position_ids, 864 | head_mask=head_mask, 865 | inputs_embeds=inputs_embeds, 866 | output_attentions=output_attentions, 867 | output_hidden_states=output_hidden_states, 868 | return_dict=return_dict, 869 | ) 870 | 871 | sequence_output = outputs[0] 872 | 873 | sequence_output = torch.cat([sequence_output, layout_outputs], -1) 874 | sequence_output = self.dropout(sequence_output) 875 | logits = self.classifier(sequence_output) 876 | 877 | loss = None 878 | if labels is not None: 879 | loss_fct = CrossEntropyLoss() 880 | # Only keep active parts of the loss 881 | if attention_mask is not None: 882 | active_loss = attention_mask.view(-1) == 1 883 | active_logits = logits.view(-1, self.num_labels) 884 | active_labels = torch.where( 885 | active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) 886 | ) 887 | loss = loss_fct(active_logits, active_labels) 888 | else: 889 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 890 | 891 | if not return_dict: 892 | output = (logits,) + outputs[2:] 893 | return ((loss,) + output) if loss is not None else output 894 | 895 | return TokenClassifierOutput( 896 | loss=loss, 897 | logits=logits, 898 | hidden_states=outputs.hidden_states, 899 | attentions=outputs.attentions, 900 | ) 901 | 902 | 903 | from dataclasses import dataclass 904 | from typing import Dict, Optional, Tuple 905 | from transformers.file_utils import ModelOutput 906 | from ...modules.decoders.re import REDecoder 907 | from ...utils import ReOutput 908 | 909 | class LiLTRobertaLikeForRelationExtraction(LiLTRobertaLikePreTrainedModel): 910 | _keys_to_ignore_on_load_unexpected = [r"pooler"] 911 | _keys_to_ignore_on_load_missing = [r"position_ids"] 912 | def __init__(self, config): 913 | super().__init__(config) 914 | 915 | self.lilt = LiLTRobertaLikeModel(config, add_pooling_layer=False) 916 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 917 | self.extractor = REDecoder(config, config.hidden_size + config.hidden_size // config.channel_shrink_ratio) 918 | self.init_weights() 919 | 920 | def forward( 921 | self, 922 | input_ids=None, 923 | bbox=None, 924 | attention_mask=None, 925 | token_type_ids=None, 926 | position_ids=None, 927 | head_mask=None, 928 | inputs_embeds=None, 929 | labels=None, 930 | output_attentions=None, 931 | output_hidden_states=None, 932 | return_dict=None, 933 | entities=None, 934 | relations=None, 935 | ): 936 | 937 | outputs, layout_outputs = self.lilt( 938 | input_ids, 939 | bbox=bbox, 940 | attention_mask=attention_mask, 941 | token_type_ids=token_type_ids, 942 | position_ids=position_ids, 943 | head_mask=head_mask, 944 | inputs_embeds=inputs_embeds, 945 | output_attentions=output_attentions, 946 | output_hidden_states=output_hidden_states, 947 | return_dict=return_dict, 948 | ) 949 | 950 | seq_length = input_ids.size(1) 951 | sequence_output = outputs[0] 952 | sequence_output = torch.cat([sequence_output, layout_outputs], 
-1) 953 | 954 | sequence_output = self.dropout(sequence_output) 955 | loss, pred_relations = self.extractor(sequence_output, entities, relations) 956 | 957 | return ReOutput( 958 | loss=loss, 959 | entities=entities, 960 | relations=relations, 961 | pred_relations=pred_relations, 962 | ) 963 | 964 | 965 | def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): 966 | """ 967 | Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols 968 | are ignored. This is modified from fairseq's `utils.make_positions`. 969 | Args: 970 | input_ids (torch.Tensor): input token ids; padding_idx (int): padding token id; past_key_values_length (int, optional): length of any cached past key values. 971 | Returns: torch.Tensor of position ids with the same shape as `input_ids` 972 | """ 973 | # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. 974 | mask = input_ids.ne(padding_idx).int() 975 | incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask 976 | return incremental_indices.long() + padding_idx -------------------------------------------------------------------------------- /LiLTfinetune/models/LiLTRobertaLike/tokenization_LiLTRobertaLike.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | from transformers import RobertaTokenizer, XLMRobertaTokenizer 4 | from transformers.utils import logging 5 | 6 | logger = logging.get_logger(__name__) 7 | 8 | SPIECE_UNDERLINE = "▁" 9 | 10 | VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"} 11 | 12 | with open('tag.txt', 'r') as tagf: 13 | TAG = tagf.read().lower() 14 | assert TAG == 'monolingual' or TAG == 'multilingual', 'TAG is wrong. It should be monolingual or multilingual.' 15 | 16 | if TAG == 'monolingual': 17 | class LiLTRobertaLikeTokenizer(RobertaTokenizer): 18 | vocab_files_names = VOCAB_FILES_NAMES 19 | max_model_input_sizes = {"lilt-roberta-base": 512,} 20 | model_input_names = ["input_ids", "attention_mask"] 21 | 22 | def __init__(self, model_max_length=512, **kwargs): 23 | super().__init__(model_max_length=model_max_length, **kwargs) 24 | 25 | elif TAG == 'multilingual': 26 | class LiLTRobertaLikeTokenizer(XLMRobertaTokenizer): 27 | vocab_files_names = VOCAB_FILES_NAMES 28 | max_model_input_sizes = {"lilt-infoxlm-base": 512,} 29 | model_input_names = ["input_ids", "attention_mask"] 30 | 31 | def __init__(self, model_max_length=512, **kwargs): 32 | super().__init__(model_max_length=model_max_length, **kwargs) 33 | -------------------------------------------------------------------------------- /LiLTfinetune/models/LiLTRobertaLike/tokenization_LiLTRobertaLike_fast.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from transformers import RobertaTokenizerFast, XLMRobertaTokenizerFast 3 | from transformers.file_utils import is_sentencepiece_available 4 | from transformers.utils import logging 5 | 6 | 7 | if is_sentencepiece_available(): 8 | from .tokenization_LiLTRobertaLike import LiLTRobertaLikeTokenizer 9 | else: 10 | LiLTRobertaLikeTokenizer = None 11 | 12 | 13 | logger = logging.get_logger(__name__) 14 | 15 | VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"} 16 | 17 | with open('tag.txt', 'r') as tagf: 18 | TAG = tagf.read().lower() 19 | assert TAG == 'monolingual' or TAG == 'multilingual', 'TAG is wrong. It should be monolingual or multilingual.' 
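# NOTE: `tag.txt` is written by the fine-tuning entry points before this module is imported — examples/run_funsd.py writes 'monolingual', and the XFUND scripts presumably write 'multilingual' — so the branch below picks the matching fast tokenizer base class.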
20 | 21 | if TAG == 'monolingual': 22 | class LiLTRobertaLikeTokenizerFast(RobertaTokenizerFast): 23 | 24 | vocab_files_names = VOCAB_FILES_NAMES 25 | max_model_input_sizes = {"lilt-roberta-base": 512,} 26 | model_input_names = ["input_ids", "attention_mask"] 27 | slow_tokenizer_class = LiLTRobertaLikeTokenizer 28 | 29 | def __init__(self, model_max_length=512, **kwargs): 30 | super().__init__(model_max_length=model_max_length, **kwargs) 31 | 32 | elif TAG == 'multilingual': 33 | class LiLTRobertaLikeTokenizerFast(XLMRobertaTokenizerFast): 34 | 35 | vocab_files_names = VOCAB_FILES_NAMES 36 | max_model_input_sizes = {"lilt-infoxlm-base": 512,} 37 | model_input_names = ["input_ids", "attention_mask"] 38 | slow_tokenizer_class = LiLTRobertaLikeTokenizer 39 | 40 | def __init__(self, model_max_length=512, **kwargs): 41 | super().__init__(model_max_length=model_max_length, **kwargs) 42 | -------------------------------------------------------------------------------- /LiLTfinetune/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpWang/LiLT/a39930b2c5425da7250f0dde04252cf60ec3b1b7/LiLTfinetune/models/__init__.py -------------------------------------------------------------------------------- /LiLTfinetune/models/model_args.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | 4 | 5 | @dataclass 6 | class ModelArguments: 7 | """ 8 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 9 | """ 10 | 11 | model_name_or_path: str = field( 12 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 13 | ) 14 | config_name: Optional[str] = field( 15 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 16 | ) 17 | tokenizer_name: Optional[str] = field( 18 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 19 | ) 20 | cache_dir: Optional[str] = field( 21 | default=None, 22 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 23 | ) 24 | model_revision: str = field( 25 | default="main", 26 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 27 | ) 28 | use_auth_token: bool = field( 29 | default=False, 30 | metadata={ 31 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 32 | "with private models)." 
33 | }, 34 | ) 35 | -------------------------------------------------------------------------------- /LiLTfinetune/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpWang/LiLT/a39930b2c5425da7250f0dde04252cf60ec3b1b7/LiLTfinetune/modules/__init__.py -------------------------------------------------------------------------------- /LiLTfinetune/modules/decoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpWang/LiLT/a39930b2c5425da7250f0dde04252cf60ec3b1b7/LiLTfinetune/modules/decoders/__init__.py -------------------------------------------------------------------------------- /LiLTfinetune/modules/decoders/re.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import CrossEntropyLoss 6 | 7 | 8 | class BiaffineAttention(torch.nn.Module): 9 | """Implements a biaffine attention operator for binary relation classification. 10 | 11 | PyTorch implementation of the biaffine attention operator from "End-to-end neural relation 12 | extraction using deep biaffine attention" (https://arxiv.org/abs/1812.11275) which can be used 13 | as a classifier for binary relation classification. 14 | 15 | Args: 16 | in_features (int): The size of the feature dimension of the inputs. 17 | out_features (int): The size of the feature dimension of the output. 18 | 19 | Shape: 20 | - x_1: `(N, *, in_features)` where `N` is the batch dimension and `*` means any number of 21 | additional dimensions. 22 | - x_2: `(N, *, in_features)`, where `N` is the batch dimension and `*` means any number of 23 | additional dimensions. 24 | - Output: `(N, *, out_features)`, where `N` is the batch dimension and `*` means any number 25 | of additional dimensions. 
26 | 27 | Examples: 28 | >>> batch_size, in_features, out_features = 32, 100, 4 29 | >>> biaffine_attention = BiaffineAttention(in_features, out_features) 30 | >>> x_1 = torch.randn(batch_size, in_features) 31 | >>> x_2 = torch.randn(batch_size, in_features) 32 | >>> output = biaffine_attention(x_1, x_2) 33 | >>> print(output.size()) 34 | torch.Size([32, 4]) 35 | """ 36 | 37 | def __init__(self, in_features, out_features): 38 | super(BiaffineAttention, self).__init__() 39 | 40 | self.in_features = in_features 41 | self.out_features = out_features 42 | 43 | self.bilinear = torch.nn.Bilinear(in_features, in_features, out_features, bias=False) 44 | self.linear = torch.nn.Linear(2 * in_features, out_features, bias=True) 45 | 46 | self.reset_parameters() 47 | 48 | def forward(self, x_1, x_2): 49 | return self.bilinear(x_1, x_2) + self.linear(torch.cat((x_1, x_2), dim=-1)) 50 | 51 | def reset_parameters(self): 52 | self.bilinear.reset_parameters() 53 | self.linear.reset_parameters() 54 | 55 | 56 | class REDecoder(nn.Module): 57 | def __init__(self, config, input_size): 58 | super().__init__() 59 | self.entity_emb = nn.Embedding(3, input_size, scale_grad_by_freq=True) 60 | projection = nn.Sequential( 61 | nn.Linear(input_size * 2, config.hidden_size), 62 | nn.ReLU(), 63 | nn.Dropout(config.hidden_dropout_prob), 64 | nn.Linear(config.hidden_size, config.hidden_size // 2), 65 | nn.ReLU(), 66 | nn.Dropout(config.hidden_dropout_prob), 67 | ) 68 | self.ffnn_head = copy.deepcopy(projection) 69 | self.ffnn_tail = copy.deepcopy(projection) 70 | self.rel_classifier = BiaffineAttention(config.hidden_size // 2, 2) 71 | self.loss_fct = CrossEntropyLoss() 72 | 73 | def build_relation(self, relations, entities): 74 | batch_size = len(relations) 75 | new_relations = [] 76 | for b in range(batch_size): 77 | if len(entities[b]["start"]) <= 2: 78 | entities[b] = {"end": [1, 1], "label": [0, 0], "start": [0, 0]} 79 | all_possible_relations = set( 80 | [ 81 | (i, j) 82 | for i in range(len(entities[b]["label"])) 83 | for j in range(len(entities[b]["label"])) 84 | if entities[b]["label"][i] == 1 and entities[b]["label"][j] == 2 85 | ] 86 | ) 87 | if len(all_possible_relations) == 0: 88 | all_possible_relations = set([(0, 1)]) 89 | positive_relations = set(list(zip(relations[b]["head"], relations[b]["tail"]))) 90 | negative_relations = all_possible_relations - positive_relations 91 | positive_relations = set([i for i in positive_relations if i in all_possible_relations]) 92 | reordered_relations = list(positive_relations) + list(negative_relations) 93 | relation_per_doc = {"head": [], "tail": [], "label": []} 94 | relation_per_doc["head"] = [i[0] for i in reordered_relations] 95 | relation_per_doc["tail"] = [i[1] for i in reordered_relations] 96 | relation_per_doc["label"] = [1] * len(positive_relations) + [0] * ( 97 | len(reordered_relations) - len(positive_relations) 98 | ) 99 | assert len(relation_per_doc["head"]) != 0 100 | new_relations.append(relation_per_doc) 101 | return new_relations, entities 102 | 103 | def get_predicted_relations(self, logits, relations, entities): 104 | pred_relations = [] 105 | for i, pred_label in enumerate(logits.argmax(-1)): 106 | if pred_label != 1: 107 | continue 108 | rel = {} 109 | rel["head_id"] = relations["head"][i] 110 | rel["head"] = (entities["start"][rel["head_id"]], entities["end"][rel["head_id"]]) 111 | rel["head_type"] = entities["label"][rel["head_id"]] 112 | 113 | rel["tail_id"] = relations["tail"][i] 114 | rel["tail"] = (entities["start"][rel["tail_id"]], 
entities["end"][rel["tail_id"]]) 115 | rel["tail_type"] = entities["label"][rel["tail_id"]] 116 | rel["type"] = 1 117 | pred_relations.append(rel) 118 | return pred_relations 119 | 120 | def forward(self, hidden_states, entities, relations): 121 | batch_size, max_n_words, context_dim = hidden_states.size() 122 | device = hidden_states.device 123 | relations, entities = self.build_relation(relations, entities) 124 | loss = 0 125 | all_pred_relations = [] 126 | all_logits = [] 127 | all_labels = [] 128 | 129 | for b in range(batch_size): 130 | head_entities = torch.tensor(relations[b]["head"], device=device) 131 | tail_entities = torch.tensor(relations[b]["tail"], device=device) 132 | relation_labels = torch.tensor(relations[b]["label"], device=device) 133 | entities_start_index = torch.tensor(entities[b]["start"], device=device) 134 | entities_labels = torch.tensor(entities[b]["label"], device=device) 135 | head_index = entities_start_index[head_entities] 136 | head_label = entities_labels[head_entities] 137 | head_label_repr = self.entity_emb(head_label) 138 | 139 | tail_index = entities_start_index[tail_entities] 140 | tail_label = entities_labels[tail_entities] 141 | tail_label_repr = self.entity_emb(tail_label) 142 | 143 | head_repr = torch.cat( 144 | (hidden_states[b][head_index], head_label_repr), 145 | dim=-1, 146 | ) 147 | tail_repr = torch.cat( 148 | (hidden_states[b][tail_index], tail_label_repr), 149 | dim=-1, 150 | ) 151 | heads = self.ffnn_head(head_repr) 152 | tails = self.ffnn_tail(tail_repr) 153 | logits = self.rel_classifier(heads, tails) 154 | pred_relations = self.get_predicted_relations(logits, relations[b], entities[b]) 155 | all_pred_relations.append(pred_relations) 156 | all_logits.append(logits) 157 | all_labels.append(relation_labels) 158 | all_logits = torch.cat(all_logits, 0) 159 | all_labels = torch.cat(all_labels, 0) 160 | loss = self.loss_fct(all_logits, all_labels) 161 | return loss, all_pred_relations -------------------------------------------------------------------------------- /LiLTfinetune/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .funsd_trainer import FunsdTrainer 2 | from .xfun_trainer import XfunReTrainer, XfunSerTrainer 3 | -------------------------------------------------------------------------------- /LiLTfinetune/trainers/funsd_trainer.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Union 2 | 3 | import torch 4 | 5 | from transformers import Trainer 6 | 7 | 8 | class FunsdTrainer(Trainer): 9 | def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]: 10 | """ 11 | Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and 12 | handling potential state. 
13 | """ 14 | for k, v in inputs.items(): 15 | if hasattr(v, "to") and hasattr(v, "device"): 16 | inputs[k] = v.to(self.args.device) 17 | 18 | if self.args.past_index >= 0 and self._past is not None: 19 | inputs["mems"] = self._past 20 | 21 | return inputs 22 | -------------------------------------------------------------------------------- /LiLTfinetune/trainers/xfun_trainer.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import time 3 | from typing import Any, Dict, List, Optional, Tuple, Union 4 | 5 | import torch 6 | from packaging import version 7 | from torch import nn 8 | from torch.utils.data import DataLoader, Dataset 9 | 10 | from transformers.utils import logging 11 | from transformers.file_utils import is_sagemaker_mp_enabled 12 | from transformers.trainer_utils import EvalPrediction, PredictionOutput, speed_metrics, ShardedDDPOption 13 | from transformers.trainer_pt_utils import get_parameter_names 14 | from transformers.optimization import Adafactor, AdamW, get_scheduler 15 | 16 | from .funsd_trainer import FunsdTrainer 17 | 18 | 19 | if version.parse(torch.__version__) >= version.parse("1.6"): 20 | _is_native_amp_available = True 21 | from torch.cuda.amp import autocast 22 | 23 | logger = logging.get_logger(__name__) 24 | 25 | 26 | class XfunSerTrainer(FunsdTrainer): 27 | pass 28 | 29 | 30 | class XfunReTrainer(FunsdTrainer): 31 | def __init__(self, **kwargs): 32 | super().__init__(**kwargs) 33 | self.label_names.append("relations") 34 | 35 | def prediction_step( 36 | self, 37 | model: nn.Module, 38 | inputs: Dict[str, Union[torch.Tensor, Any]], 39 | prediction_loss_only: bool, 40 | ignore_keys: Optional[List[str]] = None, 41 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 42 | inputs = self._prepare_inputs(inputs) 43 | 44 | with torch.no_grad(): 45 | if self.use_amp: 46 | with autocast(): 47 | outputs = model(**inputs) 48 | else: 49 | outputs = model(**inputs) 50 | labels = tuple(inputs.get(name) for name in self.label_names) 51 | return outputs, labels 52 | 53 | def prediction_loop( 54 | self, 55 | dataloader: DataLoader, 56 | description: str, 57 | prediction_loss_only: Optional[bool] = None, 58 | ignore_keys: Optional[List[str]] = None, 59 | metric_key_prefix: str = "eval", 60 | ) -> PredictionOutput: 61 | """ 62 | Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`. 63 | 64 | Works both with or without labels. 
65 | """ 66 | if not isinstance(dataloader.dataset, collections.abc.Sized): 67 | raise ValueError("dataset must implement __len__") 68 | prediction_loss_only = ( 69 | prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only 70 | ) 71 | 72 | if self.args.deepspeed and not self.args.do_train: 73 | # no harm, but flagging to the user that deepspeed config is ignored for eval 74 | # flagging only for when --do_train wasn't passed as only then it's redundant 75 | logger.info("Detected the deepspeed argument but it will not be used for evaluation") 76 | 77 | model = self._wrap_model(self.model, training=False) 78 | 79 | # if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while 80 | # ``train`` is running, half it first and then put on device 81 | if not self.is_in_train and self.args.fp16_full_eval: 82 | model = model.half().to(self.args.device) 83 | 84 | batch_size = dataloader.batch_size 85 | num_examples = self.num_examples(dataloader) 86 | logger.info("***** Running %s *****", description) 87 | logger.info(" Num examples = %d", num_examples) 88 | logger.info(" Batch size = %d", batch_size) 89 | 90 | model.eval() 91 | 92 | self.callback_handler.eval_dataloader = dataloader 93 | 94 | re_labels = None 95 | pred_relations = None 96 | entities = None 97 | for step, inputs in enumerate(dataloader): 98 | outputs, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) 99 | re_labels = labels[1] if re_labels is None else re_labels + labels[1] 100 | pred_relations = ( 101 | outputs.pred_relations if pred_relations is None else pred_relations + outputs.pred_relations 102 | ) 103 | entities = outputs.entities if entities is None else entities + outputs.entities 104 | 105 | self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control) 106 | 107 | gt_relations = [] 108 | for b in range(len(re_labels)): 109 | rel_sent = [] 110 | for head, tail in zip(re_labels[b]["head"], re_labels[b]["tail"]): 111 | rel = {} 112 | rel["head_id"] = head 113 | rel["head"] = (entities[b]["start"][rel["head_id"]], entities[b]["end"][rel["head_id"]]) 114 | rel["head_type"] = entities[b]["label"][rel["head_id"]] 115 | 116 | rel["tail_id"] = tail 117 | rel["tail"] = (entities[b]["start"][rel["tail_id"]], entities[b]["end"][rel["tail_id"]]) 118 | rel["tail_type"] = entities[b]["label"][rel["tail_id"]] 119 | 120 | rel["type"] = 1 121 | 122 | rel_sent.append(rel) 123 | 124 | gt_relations.append(rel_sent) 125 | 126 | re_metrics = self.compute_metrics(EvalPrediction(predictions=pred_relations, label_ids=gt_relations)) 127 | 128 | re_metrics = { 129 | "precision": re_metrics["ALL"]["p"], 130 | "recall": re_metrics["ALL"]["r"], 131 | "f1": re_metrics["ALL"]["f1"], 132 | } 133 | re_metrics[f"{metric_key_prefix}_loss"] = outputs.loss.mean().item() 134 | 135 | metrics = {} 136 | 137 | # # Prefix all keys with metric_key_prefix + '_' 138 | for key in list(re_metrics.keys()): 139 | if not key.startswith(f"{metric_key_prefix}_"): 140 | metrics[f"{metric_key_prefix}_{key}"] = re_metrics.pop(key) 141 | else: 142 | metrics[f"{key}"] = re_metrics.pop(key) 143 | 144 | return metrics 145 | 146 | def evaluate( 147 | self, 148 | eval_dataset: Optional[Dataset] = None, 149 | ignore_keys: Optional[List[str]] = None, 150 | metric_key_prefix: str = "eval", 151 | ) -> Dict[str, float]: 152 | """ 153 | Run evaluation and returns metrics. 
154 | 155 | The calling script will be responsible for providing a method to compute metrics, as they are task-dependent 156 | (pass it to the init :obj:`compute_metrics` argument). 157 | 158 | You can also subclass and override this method to inject custom behavior. 159 | 160 | Args: 161 | eval_dataset (:obj:`Dataset`, `optional`): 162 | Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, 163 | columns not accepted by the ``model.forward()`` method are automatically removed. It must implement the 164 | :obj:`__len__` method. 165 | ignore_keys (:obj:`Lst[str]`, `optional`): 166 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when 167 | gathering predictions. 168 | metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`): 169 | An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named 170 | "eval_bleu" if the prefix is "eval" (default) 171 | 172 | Returns: 173 | A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The 174 | dictionary also contains the epoch number which comes from the training state. 175 | """ 176 | if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized): 177 | raise ValueError("eval_dataset must implement __len__") 178 | 179 | self.args.local_rank = -1 180 | eval_dataloader = self.get_eval_dataloader(eval_dataset) 181 | self.args.local_rank = torch.distributed.get_rank() 182 | 183 | start_time = time.time() 184 | 185 | metrics = self.prediction_loop( 186 | eval_dataloader, 187 | description="Evaluation", 188 | # No point gathering the predictions if there are no metrics, otherwise we defer to 189 | # self.args.prediction_loss_only 190 | prediction_loss_only=True if self.compute_metrics is None else None, 191 | ignore_keys=ignore_keys, 192 | metric_key_prefix=metric_key_prefix, 193 | ) 194 | 195 | n_samples = len(eval_dataset if eval_dataset is not None else self.eval_dataset) 196 | metrics.update(speed_metrics(metric_key_prefix, start_time, n_samples)) 197 | self.log(metrics) 198 | self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics) 199 | 200 | return metrics 201 | 202 | def create_optimizer(self, speedup_r=4.): 203 | if self.optimizer is None: 204 | decay_parameters = get_parameter_names(self.model, [torch.nn.LayerNorm]) 205 | decay_parameters = [name for name in decay_parameters if "bias" not in name] 206 | speedup_parameters = [name for name in get_parameter_names(self.model, []) if 'extractor' in name and 'rel_classifier' not in name] 207 | optimizer_grouped_parameters = [ 208 | { 209 | "params": [p for n, p in self.model.named_parameters() if n in decay_parameters and n in speedup_parameters], 210 | "weight_decay": self.args.weight_decay, 211 | "lr": self.args.learning_rate *speedup_r, 212 | }, 213 | { 214 | "params": [p for n, p in self.model.named_parameters() if n not in decay_parameters and n in speedup_parameters], 215 | "weight_decay": 0.0, 216 | "lr": self.args.learning_rate *speedup_r, 217 | }, 218 | { 219 | "params": [p for n, p in self.model.named_parameters() if n in decay_parameters and n not in speedup_parameters], 220 | "weight_decay": self.args.weight_decay, 221 | "lr": self.args.learning_rate, 222 | }, 223 | { 224 | "params": [p for n, p in self.model.named_parameters() if n not in decay_parameters and n not in speedup_parameters], 225 | "weight_decay": 0.0, 226 | "lr": 
self.args.learning_rate, 227 | }, 228 | ] 229 | optimizer_cls = Adafactor if self.args.adafactor else AdamW 230 | if self.args.adafactor: 231 | optimizer_cls = Adafactor 232 | optimizer_kwargs = {"scale_parameter": False, "relative_step": False} 233 | else: 234 | optimizer_cls = AdamW 235 | optimizer_kwargs = { 236 | "betas": (self.args.adam_beta1, self.args.adam_beta2), 237 | "eps": self.args.adam_epsilon, 238 | } 239 | 240 | if self.sharded_ddp == ShardedDDPOption.SIMPLE: 241 | self.optimizer = OSS( 242 | params=optimizer_grouped_parameters, 243 | optim=optimizer_cls, 244 | **optimizer_kwargs, 245 | ) 246 | else: 247 | self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) 248 | 249 | if is_sagemaker_mp_enabled(): 250 | import smdistributed.modelparallel.torch as smp 251 | self.optimizer = smp.DistributedOptimizer(self.optimizer) 252 | -------------------------------------------------------------------------------- /LiLTfinetune/utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, Optional, Tuple 3 | 4 | import torch 5 | 6 | from transformers.file_utils import ModelOutput 7 | 8 | 9 | @dataclass 10 | class ReOutput(ModelOutput): 11 | loss: Optional[torch.FloatTensor] = None 12 | logits: torch.FloatTensor = None 13 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 14 | attentions: Optional[Tuple[torch.FloatTensor]] = None 15 | entities: Optional[Dict] = None 16 | relations: Optional[Dict] = None 17 | pred_relations: Optional[Dict] = None 18 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: quality style 2 | check_dirs := LiLTfinetune examples 3 | # Check that source code meets quality standards 4 | 5 | quality: 6 | black --check $(check_dirs) 7 | isort --check-only $(check_dirs) 8 | flake8 $(check_dirs) 9 | 10 | # Format source code automatically 11 | 12 | style: 13 | black $(check_dirs) 14 | isort $(check_dirs) 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # What's New 2 | 3 | [2022/10] LiLT has been added to 🤗[huggingface/transformers](https://github.com/huggingface/transformers); see the model documentation [HERE](https://huggingface.co/docs/transformers/main/model_doc/lilt). 4 | 5 | [2022/03] Initial model and code release. 6 | 7 | # LiLT (ACL 2022) 8 | 9 | This is the official PyTorch implementation of the ACL 2022 paper: "LiLT: A Simple yet Effective Language-Independent Layout Transformer for Structured Document Understanding". [[official](https://aclanthology.org/2022.acl-long.534/)] [[arXiv](https://arxiv.org/abs/2202.13669)] 10 | 11 | ![framework](figs/framework.png) 12 | 13 | LiLT is pre-trained on the visually-rich documents of a single language (English) and can be directly fine-tuned on other languages with the corresponding off-the-shelf monolingual/multilingual pre-trained textual models. We hope the public availability of this work can help document intelligence research. 
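A fine-tuned checkpoint produced by the scripts below can then be used for inference much like any other 🤗 Transformers token-classification model. The snippet is only a minimal sketch for illustration (it is not part of the codebase): it assumes that importing `LiLTfinetune` registers the custom `LiLTRobertaLike` classes with the `Auto*` factories (as the example scripts rely on), and the checkpoint directory and bounding boxes are made up; real boxes should follow the preprocessing in `LiLTfinetune/data/`.

~~~python
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification

import LiLTfinetune  # assumption: importing the package registers the LiLTRobertaLike classes with the Auto* factories

tokenizer = AutoTokenizer.from_pretrained("roberta-base", add_prefix_space=True)
model = AutoModelForTokenClassification.from_pretrained("ser_funsd_lilt-roberta-en-base")  # hypothetical fine-tuned output dir

words = ["Invoice", "Number:", "12345"]                                # made-up example words
boxes = [[82, 41, 168, 58], [170, 41, 250, 58], [254, 41, 310, 58]]    # made-up word-level boxes

encoding = tokenizer(words, is_split_into_words=True, return_tensors="pt")
# Expand the word-level boxes to token level; special tokens get an all-zero box.
token_boxes = [[0, 0, 0, 0] if wid is None else boxes[wid] for wid in encoding.word_ids(batch_index=0)]
encoding["bbox"] = torch.tensor([token_boxes])

with torch.no_grad():
    logits = model(**encoding).logits   # (1, sequence_length, num_labels)
predicted_label_ids = logits.argmax(-1)
~~~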
14 | 15 | ## Installation 16 | 17 | For CUDA 11.X: 18 | 19 | ~~~bash 20 | conda create -n liltfinetune python=3.7 21 | conda activate liltfinetune 22 | conda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=11.0 -c pytorch 23 | python -m pip install detectron2==0.5 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu110/torch1.7/index.html 24 | git clone https://github.com/jpWang/LiLT 25 | cd LiLT 26 | pip install -r requirements.txt 27 | pip install -e . 28 | ~~~ 29 | 30 | Or check [Detectron2](https://github.com/facebookresearch/detectron2/releases)/[PyTorch](https://pytorch.org/get-started/previous-versions/) versions and modify the command lines accordingly. 31 | 32 | ## Datasets 33 | 34 | In this repository, we provide the fine-tuning codes for [FUNSD](https://guillaumejaume.github.io/FUNSD/) and [XFUND](https://github.com/doc-analysis/XFUND). 35 | 36 | You can download our **pre-processed data (~1.2GB)** from [**HERE**](https://1drv.ms/u/s!Ahd-h7H5akVZeZQvKieg8g5THV8?e=mBRnxw), and put the unzipped `xfund&funsd/` under `LiLT/`. 37 | 38 | ## Available Checkpoints 39 | 40 | | Model | Language | Size | Download | 41 | | ----------------------------- | --------- | ----- | ------------ | 42 | | `lilt-roberta-en-base` | EN | 293MB | [OneDrive](https://1drv.ms/u/s!Ahd-h7H5akVZfhPVHQQ1tOypA48?e=nraHn3) | 43 | | `lilt-infoxlm-base` | MUL | 846MB | [OneDrive](https://1drv.ms/u/s!Ahd-h7H5akVZfeIhAQ8KHELRvcc?e=WS1P82) | 44 | | `lilt-only-base` | None | 21MB | [OneDrive](https://1drv.ms/u/s!Ahd-h7H5akVZfEIRbCmcWKjhoSM?e=6tMGbe) | 45 | 46 | ## Or Generate Your Own Checkpoint (Optional) 47 | 48 | If you want to combine the pre-trained LiLT with **other language's *RoBERTa***, please download `lilt-only-base` and use `gen_weight_roberta_like.py` to generate your own pre-trained checkpoint. 
49 | 50 | **For example,** combine `lilt-only-base` with English `roberta-base`: 51 | 52 | ~~~bash 53 | mkdir roberta-en-base 54 | wget https://huggingface.co/roberta-base/resolve/main/config.json -O roberta-en-base/config.json 55 | wget https://huggingface.co/roberta-base/resolve/main/pytorch_model.bin -O roberta-en-base/pytorch_model.bin 56 | python gen_weight_roberta_like.py \ 57 | --lilt lilt-only-base/pytorch_model.bin \ 58 | --text roberta-en-base/pytorch_model.bin \ 59 | --config roberta-en-base/config.json \ 60 | --out lilt-roberta-en-base 61 | ~~~ 62 | 63 | Or combine `lilt-only-base` with `microsoft/infoxlm-base`: 64 | 65 | ~~~bash 66 | mkdir infoxlm-base 67 | wget https://huggingface.co/microsoft/infoxlm-base/resolve/main/config.json -O infoxlm-base/config.json 68 | wget https://huggingface.co/microsoft/infoxlm-base/resolve/main/pytorch_model.bin -O infoxlm-base/pytorch_model.bin 69 | python gen_weight_roberta_like.py \ 70 | --lilt lilt-only-base/pytorch_model.bin \ 71 | --text infoxlm-base/pytorch_model.bin \ 72 | --config infoxlm-base/config.json \ 73 | --out lilt-infoxlm-base 74 | ~~~ 75 | 76 | 77 | ## Fine-tuning 78 | 79 | 80 | ### Semantic Entity Recognition on FUNSD 81 | 82 | ``` 83 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 examples/run_funsd.py \ 84 | --model_name_or_path lilt-roberta-en-base \ 85 | --tokenizer_name roberta-base \ 86 | --output_dir ser_funsd_lilt-roberta-en-base \ 87 | --do_train \ 88 | --do_predict \ 89 | --max_steps 2000 \ 90 | --per_device_train_batch_size 8 \ 91 | --warmup_ratio 0.1 \ 92 | --fp16 93 | ``` 94 | 95 | ### Language-specific (For example, ZH) Semantic Entity Recognition on XFUND 96 | 97 | ``` 98 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 examples/run_xfun_ser.py \ 99 | --model_name_or_path lilt-infoxlm-base \ 100 | --tokenizer_name xlm-roberta-base \ 101 | --output_dir ls_ser_xfund_zh_lilt-infoxlm-base \ 102 | --do_train \ 103 | --do_eval \ 104 | --lang zh \ 105 | --max_steps 2000 \ 106 | --per_device_train_batch_size 16 \ 107 | --warmup_ratio 0.1 \ 108 | --fp16 109 | ``` 110 | 111 | ### Language-specific (For example, ZH) Relation Extraction on XFUND 112 | 113 | ``` 114 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 examples/run_xfun_re.py \ 115 | --model_name_or_path lilt-infoxlm-base \ 116 | --tokenizer_name xlm-roberta-base \ 117 | --output_dir ls_re_xfund_zh_lilt-infoxlm-base \ 118 | --do_train \ 119 | --do_eval \ 120 | --lang zh \ 121 | --max_steps 5000 \ 122 | --per_device_train_batch_size 8 \ 123 | --learning_rate 6.25e-6 \ 124 | --warmup_ratio 0.1 \ 125 | --fp16 126 | ``` 127 | 128 | ### Multi-task Semantic Entity Recognition on XFUND 129 | 130 | ``` 131 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 examples/run_xfun_ser.py \ 132 | --model_name_or_path lilt-infoxlm-base \ 133 | --tokenizer_name xlm-roberta-base \ 134 | --output_dir mt_ser_xfund_all_lilt-infoxlm-base \ 135 | --do_train \ 136 | --additional_langs all \ 137 | --max_steps 16000 \ 138 | --per_device_train_batch_size 16 \ 139 | --warmup_ratio 0.1 \ 140 | --fp16 141 | ``` 142 | 143 | ### Multi-task Relation Extraction on XFUND 144 | 145 | ``` 146 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 examples/run_xfun_re.py \ 147 | --model_name_or_path lilt-infoxlm-base \ 148 | --tokenizer_name xlm-roberta-base \ 149 | --output_dir mt_re_xfund_all_lilt-infoxlm-base \ 150 | 
--do_train \ 151 | --additional_langs all \ 152 | --max_steps 40000 \ 153 | --per_device_train_batch_size 8 \ 154 | --learning_rate 6.25e-6 \ 155 | --warmup_ratio 0.1 \ 156 | --fp16 157 | ``` 158 | 159 | 160 | ## Results 161 | 162 | ### Semantic Entity Recognition on FUNSD 163 | ![funsd](figs/funsd.png) 164 | 165 | ### Language-specific Fine-tuning on XFUND 166 | ![ls_xfund](figs/ls_xfund.png) 167 | 168 | ### Cross-lingual Zero-shot Transfer on XFUND 169 | ![cl_xfund](figs/cl_xfund.png) 170 | 171 | ### Multi-task Fine-tuning on XFUND 172 | ![mt_xfund](figs/mt_xfund.png) 173 | 174 | 175 | 176 | ## Acknowledgements 177 | 178 | The repository benefits greatly from [unilm/layoutlmft](https://github.com/microsoft/unilm/tree/master/layoutlmft). Thanks a lot for their excellent work. 179 | 180 | ## Citation 181 | If our paper helps your research, please cite it in your publication(s): 182 | ``` 183 | @inproceedings{wang-etal-2022-lilt, 184 | title = "{L}i{LT}: A Simple yet Effective Language-Independent Layout Transformer for Structured Document Understanding", 185 | author={Wang, Jiapeng and Jin, Lianwen and Ding, Kai}, 186 | booktitle = "Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)", 187 | month = may, 188 | year = "2022", 189 | publisher = "Association for Computational Linguistics", 190 | url = "https://aclanthology.org/2022.acl-long.534", 191 | doi = "10.18653/v1/2022.acl-long.534", 192 | pages = "7747--7757", 193 | } 194 | ``` 195 | 196 | ## Feedback 197 | Suggestions and discussions are very welcome. Please contact the authors by sending an email to `eejpwang@mail.scut.edu.cn`. 198 | -------------------------------------------------------------------------------- /examples/run_funsd.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | with open('tag.txt', 'w') as tagf: 4 | tagf.write('monolingual') 5 | import logging 6 | import os 7 | import sys 8 | from dataclasses import dataclass, field 9 | from typing import Optional 10 | 11 | import numpy as np 12 | from datasets import ClassLabel, load_dataset, load_metric 13 | 14 | import LiLTfinetune.data.datasets.funsd 15 | import transformers 16 | from LiLTfinetune.data import DataCollatorForKeyValueExtraction 17 | from LiLTfinetune.data.data_args import DataTrainingArguments 18 | from LiLTfinetune.models.model_args import ModelArguments 19 | from LiLTfinetune.trainers import FunsdTrainer as Trainer 20 | from transformers import ( 21 | AutoConfig, 22 | AutoModelForTokenClassification, 23 | AutoTokenizer, 24 | HfArgumentParser, 25 | PreTrainedTokenizerFast, 26 | TrainingArguments, 27 | set_seed, 28 | ) 29 | from transformers.trainer_utils import get_last_checkpoint, is_main_process 30 | from transformers.utils import check_min_version 31 | 32 | 33 | # Will error if the minimal version of Transformers is not installed. Remove at your own risk. 34 | check_min_version("4.5.0") 35 | 36 | logger = logging.getLogger(__name__) 37 | 38 | 39 | def main(): 40 | # See all possible arguments in layoutlmft/transformers/training_args.py 41 | # or by passing the --help flag to this script. 42 | # We now keep distinct sets of args, for a cleaner separation of concerns. 43 | 44 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 45 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 46 | # If we pass only one argument to the script and it's the path to a json file, 47 | # let's parse it to get our arguments. 
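# For example (hypothetical): `python examples/run_funsd.py args.json`, where args.json holds the same
# keys as the CLI flags, e.g. {"model_name_or_path": "lilt-roberta-en-base", "output_dir": "ser_funsd_lilt-roberta-en-base", "do_train": true}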
48 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 49 | else: 50 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 51 | 52 | # Detecting last checkpoint. 53 | last_checkpoint = None 54 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 55 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 56 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 57 | raise ValueError( 58 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 59 | "Use --overwrite_output_dir to overcome." 60 | ) 61 | elif last_checkpoint is not None: 62 | logger.info( 63 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 64 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 65 | ) 66 | 67 | # Setup logging 68 | logging.basicConfig( 69 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 70 | datefmt="%m/%d/%Y %H:%M:%S", 71 | handlers=[logging.StreamHandler(sys.stdout)], 72 | ) 73 | logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN) 74 | 75 | # Log on each process the small summary: 76 | logger.warning( 77 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 78 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 79 | ) 80 | # Set the verbosity to info of the Transformers logger (on main process only): 81 | if is_main_process(training_args.local_rank): 82 | transformers.utils.logging.set_verbosity_info() 83 | transformers.utils.logging.enable_default_handler() 84 | transformers.utils.logging.enable_explicit_format() 85 | logger.info(f"Training/evaluation parameters {training_args}") 86 | 87 | # Set seed before initializing model. 88 | set_seed(training_args.seed) 89 | 90 | datasets = load_dataset(os.path.abspath(LiLTfinetune.data.datasets.funsd.__file__)) 91 | 92 | if training_args.do_train: 93 | column_names = datasets["train"].column_names 94 | features = datasets["train"].features 95 | else: 96 | column_names = datasets["validation"].column_names 97 | features = datasets["validation"].features 98 | text_column_name = "tokens" if "tokens" in column_names else column_names[0] 99 | label_column_name = ( 100 | f"{data_args.task_name}_tags" if f"{data_args.task_name}_tags" in column_names else column_names[1] 101 | ) 102 | 103 | remove_columns = column_names 104 | 105 | # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the 106 | # unique labels. 107 | def get_label_list(labels): 108 | unique_labels = set() 109 | for label in labels: 110 | unique_labels = unique_labels | set(label) 111 | label_list = list(unique_labels) 112 | label_list.sort() 113 | return label_list 114 | 115 | if isinstance(features[label_column_name].feature, ClassLabel): 116 | label_list = features[label_column_name].feature.names 117 | # No need to convert the labels since they are already ints. 
118 | label_to_id = {i: i for i in range(len(label_list))} 119 | else: 120 | label_list = get_label_list(datasets["train"][label_column_name]) 121 | label_to_id = {l: i for i, l in enumerate(label_list)} 122 | num_labels = len(label_list) 123 | 124 | # Load pretrained model and tokenizer 125 | # 126 | # Distributed training: 127 | # The .from_pretrained methods guarantee that only one local process can concurrently 128 | # download model & vocab. 129 | config = AutoConfig.from_pretrained( 130 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 131 | num_labels=num_labels, 132 | finetuning_task=data_args.task_name, 133 | cache_dir=model_args.cache_dir, 134 | revision=model_args.model_revision, 135 | use_auth_token=True if model_args.use_auth_token else None, 136 | ) 137 | tokenizer = AutoTokenizer.from_pretrained( 138 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 139 | cache_dir=model_args.cache_dir, 140 | use_fast=True, 141 | revision=model_args.model_revision, 142 | use_auth_token=True if model_args.use_auth_token else None, 143 | add_prefix_space=True, 144 | ) 145 | model = AutoModelForTokenClassification.from_pretrained( 146 | model_args.model_name_or_path, 147 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 148 | config=config, 149 | cache_dir=model_args.cache_dir, 150 | revision=model_args.model_revision, 151 | use_auth_token=True if model_args.use_auth_token else None, 152 | ) 153 | 154 | # Tokenizer check: this script requires a fast tokenizer. 155 | if not isinstance(tokenizer, PreTrainedTokenizerFast): 156 | raise ValueError( 157 | "This example script only works for models that have a fast tokenizer. Checkout the big table of models " 158 | "at https://huggingface.co/transformers/index.html#bigtable to find the model types that meet this " 159 | "requirement" 160 | ) 161 | 162 | # Preprocessing the dataset 163 | # Padding strategy 164 | padding = "max_length" if data_args.pad_to_max_length else False 165 | 166 | # Tokenize all texts and align the labels with them. 167 | def tokenize_and_align_labels(examples): 168 | tokenized_inputs = tokenizer( 169 | examples[text_column_name], 170 | padding=padding, 171 | truncation=True, 172 | return_overflowing_tokens=True, 173 | # We use this argument because the texts in our dataset are lists of words (with a label for each word). 174 | is_split_into_words=True, 175 | ) 176 | 177 | labels = [] 178 | bboxes = [] 179 | images = [] 180 | for batch_index in range(len(tokenized_inputs["input_ids"])): 181 | word_ids = tokenized_inputs.word_ids(batch_index=batch_index) 182 | org_batch_index = tokenized_inputs["overflow_to_sample_mapping"][batch_index] 183 | 184 | label = examples[label_column_name][org_batch_index] 185 | bbox = examples["bboxes"][org_batch_index] 186 | image = examples["image"][org_batch_index] 187 | previous_word_idx = None 188 | label_ids = [] 189 | bbox_inputs = [] 190 | for word_idx in word_ids: 191 | # Special tokens have a word id that is None. We set the label to -100 so they are automatically 192 | # ignored in the loss function. 193 | if word_idx is None: 194 | label_ids.append(-100) 195 | bbox_inputs.append([0, 0, 0, 0]) 196 | # We set the label for the first token of each word. 
197 | elif word_idx != previous_word_idx: 198 | label_ids.append(label_to_id[label[word_idx]]) 199 | bbox_inputs.append(bbox[word_idx]) 200 | # For the other tokens in a word, we set the label to either the current label or -100, depending on 201 | # the label_all_tokens flag. 202 | else: 203 | label_ids.append(label_to_id[label[word_idx]] if data_args.label_all_tokens else -100) 204 | bbox_inputs.append(bbox[word_idx]) 205 | previous_word_idx = word_idx 206 | labels.append(label_ids) 207 | bboxes.append(bbox_inputs) 208 | images.append(image) 209 | tokenized_inputs["labels"] = labels 210 | tokenized_inputs["bbox"] = bboxes 211 | tokenized_inputs["image"] = images 212 | return tokenized_inputs 213 | 214 | if training_args.do_train: 215 | if "train" not in datasets: 216 | raise ValueError("--do_train requires a train dataset") 217 | train_dataset = datasets["train"] 218 | if data_args.max_train_samples is not None: 219 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 220 | train_dataset = train_dataset.map( 221 | tokenize_and_align_labels, 222 | batched=True, 223 | remove_columns=remove_columns, 224 | num_proc=data_args.preprocessing_num_workers, 225 | load_from_cache_file=not data_args.overwrite_cache, 226 | ) 227 | 228 | if training_args.do_eval: 229 | if "validation" not in datasets: 230 | raise ValueError("--do_eval requires a validation dataset") 231 | eval_dataset = datasets["validation"] 232 | if data_args.max_val_samples is not None: 233 | eval_dataset = eval_dataset.select(range(data_args.max_val_samples)) 234 | eval_dataset = eval_dataset.map( 235 | tokenize_and_align_labels, 236 | batched=True, 237 | remove_columns=remove_columns, 238 | num_proc=data_args.preprocessing_num_workers, 239 | load_from_cache_file=not data_args.overwrite_cache, 240 | ) 241 | 242 | if training_args.do_predict: 243 | if "test" not in datasets: 244 | raise ValueError("--do_predict requires a test dataset") 245 | test_dataset = datasets["test"] 246 | if data_args.max_test_samples is not None: 247 | test_dataset = test_dataset.select(range(data_args.max_test_samples)) 248 | test_dataset = test_dataset.map( 249 | tokenize_and_align_labels, 250 | batched=True, 251 | remove_columns=remove_columns, 252 | num_proc=data_args.preprocessing_num_workers, 253 | load_from_cache_file=not data_args.overwrite_cache, 254 | ) 255 | 256 | # Data collator 257 | data_collator = DataCollatorForKeyValueExtraction( 258 | tokenizer, 259 | pad_to_multiple_of=8 if training_args.fp16 else None, 260 | padding=padding, 261 | max_length=512, 262 | ) 263 | 264 | # Metrics 265 | metric = load_metric("seqeval") 266 | 267 | def compute_metrics(p): 268 | predictions, labels = p 269 | predictions = np.argmax(predictions, axis=2) 270 | 271 | # Remove ignored index (special tokens) 272 | true_predictions = [ 273 | [label_list[p] for (p, l) in zip(prediction, label) if l != -100] 274 | for prediction, label in zip(predictions, labels) 275 | ] 276 | true_labels = [ 277 | [label_list[l] for (p, l) in zip(prediction, label) if l != -100] 278 | for prediction, label in zip(predictions, labels) 279 | ] 280 | 281 | results = metric.compute(predictions=true_predictions, references=true_labels) 282 | if data_args.return_entity_level_metrics: 283 | # Unpack nested dictionaries 284 | final_results = {} 285 | for key, value in results.items(): 286 | if isinstance(value, dict): 287 | for n, v in value.items(): 288 | final_results[f"{key}_{n}"] = v 289 | else: 290 | final_results[key] = value 291 | return final_results 292 | else: 
293 | return { 294 | "precision": results["overall_precision"], 295 | "recall": results["overall_recall"], 296 | "f1": results["overall_f1"], 297 | "accuracy": results["overall_accuracy"], 298 | } 299 | 300 | # Initialize our Trainer 301 | trainer = Trainer( 302 | model=model, 303 | args=training_args, 304 | train_dataset=train_dataset if training_args.do_train else None, 305 | eval_dataset=eval_dataset if training_args.do_eval else None, 306 | tokenizer=tokenizer, 307 | data_collator=data_collator, 308 | compute_metrics=compute_metrics, 309 | ) 310 | 311 | # Training 312 | if training_args.do_train: 313 | checkpoint = last_checkpoint if last_checkpoint else None 314 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 315 | metrics = train_result.metrics 316 | trainer.save_model() # Saves the tokenizer too for easy upload 317 | 318 | max_train_samples = ( 319 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 320 | ) 321 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 322 | 323 | trainer.log_metrics("train", metrics) 324 | trainer.save_metrics("train", metrics) 325 | trainer.save_state() 326 | 327 | # Evaluation 328 | if training_args.do_eval: 329 | logger.info("*** Evaluate ***") 330 | 331 | metrics = trainer.evaluate() 332 | 333 | max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset) 334 | metrics["eval_samples"] = min(max_val_samples, len(eval_dataset)) 335 | 336 | trainer.log_metrics("eval", metrics) 337 | trainer.save_metrics("eval", metrics) 338 | 339 | # Predict 340 | if training_args.do_predict: 341 | logger.info("*** Predict ***") 342 | 343 | predictions, labels, metrics = trainer.predict(test_dataset) 344 | predictions = np.argmax(predictions, axis=2) 345 | 346 | # Remove ignored index (special tokens) 347 | true_predictions = [ 348 | [label_list[p] for (p, l) in zip(prediction, label) if l != -100] 349 | for prediction, label in zip(predictions, labels) 350 | ] 351 | 352 | trainer.log_metrics("test", metrics) 353 | trainer.save_metrics("test", metrics) 354 | 355 | # Save predictions 356 | output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt") 357 | if trainer.is_world_process_zero(): 358 | with open(output_test_predictions_file, "w") as writer: 359 | for prediction in true_predictions: 360 | writer.write(" ".join(prediction) + "\n") 361 | 362 | 363 | def _mp_fn(index): 364 | # For xla_spawn (TPUs) 365 | main() 366 | 367 | 368 | if __name__ == "__main__": 369 | main() 370 | -------------------------------------------------------------------------------- /examples/run_xfun_re.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | with open('tag.txt', 'w') as tagf: 4 | tagf.write('multilingual') 5 | import logging 6 | import os 7 | import sys 8 | 9 | import numpy as np 10 | from datasets import ClassLabel, load_dataset 11 | 12 | import LiLTfinetune.data.datasets.xfun 13 | import transformers 14 | from LiLTfinetune import AutoModelForRelationExtraction 15 | from LiLTfinetune.data.data_args import XFUNDataTrainingArguments 16 | from LiLTfinetune.data.data_collator import DataCollatorForKeyValueExtraction 17 | from LiLTfinetune.evaluation import re_score 18 | from LiLTfinetune.models.model_args import ModelArguments 19 | from LiLTfinetune.trainers import XfunReTrainer 20 | from transformers import ( 21 | AutoConfig, 22 | 
AutoTokenizer, 23 | HfArgumentParser, 24 | PreTrainedTokenizerFast, 25 | TrainingArguments, 26 | set_seed, 27 | ) 28 | from transformers.trainer_utils import get_last_checkpoint, is_main_process 29 | 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | def main(): 35 | # See all possible arguments in src/transformers/training_args.py 36 | # or by passing the --help flag to this script. 37 | # We now keep distinct sets of args, for a cleaner separation of concerns. 38 | 39 | parser = HfArgumentParser((ModelArguments, XFUNDataTrainingArguments, TrainingArguments)) 40 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 41 | # If we pass only one argument to the script and it's the path to a json file, 42 | # let's parse it to get our arguments. 43 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 44 | else: 45 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 46 | 47 | # Detecting last checkpoint. 48 | last_checkpoint = None 49 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 50 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 51 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 52 | raise ValueError( 53 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 54 | "Use --overwrite_output_dir to overcome." 55 | ) 56 | elif last_checkpoint is not None: 57 | logger.info( 58 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 59 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 60 | ) 61 | 62 | # Setup logging 63 | logging.basicConfig( 64 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 65 | datefmt="%m/%d/%Y %H:%M:%S", 66 | handlers=[logging.StreamHandler(sys.stdout)], 67 | ) 68 | logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN) 69 | 70 | # Log on each process the small summary: 71 | logger.warning( 72 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 73 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 74 | ) 75 | # Set the verbosity to info of the Transformers logger (on main process only): 76 | if is_main_process(training_args.local_rank): 77 | transformers.utils.logging.set_verbosity_info() 78 | transformers.utils.logging.enable_default_handler() 79 | transformers.utils.logging.enable_explicit_format() 80 | logger.info(f"Training/evaluation parameters {training_args}") 81 | 82 | # Set seed before initializing model. 83 | set_seed(training_args.seed) 84 | datasets = load_dataset( 85 | os.path.abspath(LiLTfinetune.data.datasets.xfun.__file__), 86 | f"xfun.{data_args.lang}", 87 | additional_langs=data_args.additional_langs, 88 | keep_in_memory=True, 89 | ) 90 | if training_args.do_train: 91 | column_names = datasets["train"].column_names 92 | features = datasets["train"].features 93 | else: 94 | column_names = datasets["validation"].column_names 95 | features = datasets["validation"].features 96 | text_column_name = "input_ids" 97 | label_column_name = "labels" 98 | 99 | remove_columns = column_names 100 | 101 | # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the 102 | # unique labels. 
103 | def get_label_list(labels): 104 | unique_labels = set() 105 | for label in labels: 106 | unique_labels = unique_labels | set(label) 107 | label_list = list(unique_labels) 108 | label_list.sort() 109 | return label_list 110 | 111 | if isinstance(features[label_column_name].feature, ClassLabel): 112 | label_list = features[label_column_name].feature.names 113 | # No need to convert the labels since they are already ints. 114 | label_to_id = {i: i for i in range(len(label_list))} 115 | else: 116 | label_list = get_label_list(datasets["train"][label_column_name]) 117 | label_to_id = {l: i for i, l in enumerate(label_list)} 118 | num_labels = len(label_list) 119 | 120 | # Load pretrained model and tokenizer 121 | # 122 | # Distributed training: 123 | # The .from_pretrained methods guarantee that only one local process can concurrently 124 | # download model & vocab. 125 | config = AutoConfig.from_pretrained( 126 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 127 | num_labels=num_labels, 128 | finetuning_task=data_args.task_name, 129 | cache_dir=model_args.cache_dir, 130 | revision=model_args.model_revision, 131 | use_auth_token=True if model_args.use_auth_token else None, 132 | ) 133 | tokenizer = AutoTokenizer.from_pretrained( 134 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 135 | cache_dir=model_args.cache_dir, 136 | use_fast=True, 137 | revision=model_args.model_revision, 138 | use_auth_token=True if model_args.use_auth_token else None, 139 | ) 140 | model = AutoModelForRelationExtraction.from_pretrained( 141 | model_args.model_name_or_path, 142 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 143 | config=config, 144 | cache_dir=model_args.cache_dir, 145 | revision=model_args.model_revision, 146 | use_auth_token=True if model_args.use_auth_token else None, 147 | ) 148 | 149 | # Tokenizer check: this script requires a fast tokenizer. 150 | if not isinstance(tokenizer, PreTrainedTokenizerFast): 151 | raise ValueError( 152 | "This example script only works for models that have a fast tokenizer. 
Checkout the big table of models " 153 | "at https://huggingface.co/transformers/index.html#bigtable to find the model types that meet this " 154 | "requirement" 155 | ) 156 | 157 | # Preprocessing the dataset 158 | # Padding strategy 159 | padding = "max_length" if data_args.pad_to_max_length else False 160 | 161 | if training_args.do_train: 162 | if "train" not in datasets: 163 | raise ValueError("--do_train requires a train dataset") 164 | train_dataset = datasets["train"] 165 | if data_args.max_train_samples is not None: 166 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 167 | 168 | if training_args.do_eval: 169 | if "validation" not in datasets: 170 | raise ValueError("--do_eval requires a validation dataset") 171 | eval_dataset = datasets["validation"] 172 | if data_args.max_val_samples is not None: 173 | eval_dataset = eval_dataset.select(range(data_args.max_val_samples)) 174 | 175 | if training_args.do_predict: 176 | if "test" not in datasets: 177 | raise ValueError("--do_predict requires a test dataset") 178 | test_dataset = datasets["test"] 179 | if data_args.max_test_samples is not None: 180 | test_dataset = test_dataset.select(range(data_args.max_test_samples)) 181 | 182 | # Data collator 183 | data_collator = DataCollatorForKeyValueExtraction( 184 | tokenizer, 185 | pad_to_multiple_of=8 if training_args.fp16 else None, 186 | padding=padding, 187 | max_length=512, 188 | ) 189 | 190 | def compute_metrics(p): 191 | pred_relations, gt_relations = p 192 | score = re_score(pred_relations, gt_relations, mode="boundaries") 193 | return score 194 | 195 | # Initialize our Trainer 196 | trainer = XfunReTrainer( 197 | model=model, 198 | args=training_args, 199 | train_dataset=train_dataset if training_args.do_train else None, 200 | eval_dataset=eval_dataset if training_args.do_eval else None, 201 | tokenizer=tokenizer, 202 | data_collator=data_collator, 203 | compute_metrics=compute_metrics, 204 | ) 205 | 206 | # Training 207 | if training_args.do_train: 208 | checkpoint = last_checkpoint if last_checkpoint else None 209 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 210 | metrics = train_result.metrics 211 | trainer.save_model() # Saves the tokenizer too for easy upload 212 | 213 | max_train_samples = ( 214 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 215 | ) 216 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 217 | 218 | trainer.log_metrics("train", metrics) 219 | trainer.save_metrics("train", metrics) 220 | trainer.save_state() 221 | 222 | # Evaluation 223 | if training_args.do_eval: 224 | logger.info("*** Evaluate ***") 225 | 226 | metrics = trainer.evaluate() 227 | 228 | max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset) 229 | metrics["eval_samples"] = min(max_val_samples, len(eval_dataset)) 230 | 231 | trainer.log_metrics("eval", metrics) 232 | trainer.save_metrics("eval", metrics) 233 | 234 | 235 | def _mp_fn(index): 236 | # For xla_spawn (TPUs) 237 | main() 238 | 239 | 240 | if __name__ == "__main__": 241 | main() 242 | -------------------------------------------------------------------------------- /examples/run_xfun_ser.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | with open('tag.txt', 'w') as tagf: 4 | tagf.write('multilingual') 5 | import logging 6 | import os 7 | import sys 8 | from dataclasses import dataclass, 
field 9 | from typing import Optional 10 | 11 | import numpy as np 12 | from datasets import ClassLabel, load_dataset, load_metric 13 | 14 | import LiLTfinetune.data.datasets.xfun 15 | import transformers 16 | from LiLTfinetune.data import DataCollatorForKeyValueExtraction 17 | from LiLTfinetune.data.data_args import XFUNDataTrainingArguments 18 | from LiLTfinetune.models.model_args import ModelArguments 19 | from LiLTfinetune.trainers import XfunSerTrainer 20 | from transformers import ( 21 | AutoConfig, 22 | AutoModelForTokenClassification, 23 | AutoTokenizer, 24 | HfArgumentParser, 25 | PreTrainedTokenizerFast, 26 | TrainingArguments, 27 | set_seed, 28 | ) 29 | from transformers.trainer_utils import get_last_checkpoint, is_main_process 30 | from transformers.utils import check_min_version 31 | 32 | 33 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 34 | check_min_version("4.5.0") 35 | 36 | logger = logging.getLogger(__name__) 37 | 38 | 39 | def main(): 40 | 41 | parser = HfArgumentParser((ModelArguments, XFUNDataTrainingArguments, TrainingArguments)) 42 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 43 | # If we pass only one argument to the script and it's the path to a json file, 44 | # let's parse it to get our arguments. 45 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 46 | else: 47 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 48 | 49 | # Detecting last checkpoint. 50 | last_checkpoint = None 51 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 52 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 53 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 54 | raise ValueError( 55 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 56 | "Use --overwrite_output_dir to overcome." 57 | ) 58 | elif last_checkpoint is not None: 59 | logger.info( 60 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 61 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 62 | ) 63 | 64 | # Setup logging 65 | logging.basicConfig( 66 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 67 | datefmt="%m/%d/%Y %H:%M:%S", 68 | handlers=[logging.StreamHandler(sys.stdout)], 69 | ) 70 | logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN) 71 | 72 | # Log on each process the small summary: 73 | logger.warning( 74 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 75 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 76 | ) 77 | # Set the verbosity to info of the Transformers logger (on main process only): 78 | if is_main_process(training_args.local_rank): 79 | transformers.utils.logging.set_verbosity_info() 80 | transformers.utils.logging.enable_default_handler() 81 | transformers.utils.logging.enable_explicit_format() 82 | logger.info(f"Training/evaluation parameters {training_args}") 83 | 84 | # Set seed before initializing model. 
85 | set_seed(training_args.seed) 86 | datasets = load_dataset( 87 | os.path.abspath(LiLTfinetune.data.datasets.xfun.__file__), 88 | f"xfun.{data_args.lang}", 89 | additional_langs=data_args.additional_langs, 90 | keep_in_memory=True, 91 | ) 92 | if training_args.do_train: 93 | column_names = datasets["train"].column_names 94 | features = datasets["train"].features 95 | else: 96 | column_names = datasets["validation"].column_names 97 | features = datasets["validation"].features 98 | text_column_name = "input_ids" 99 | label_column_name = "labels" 100 | 101 | remove_columns = column_names 102 | 103 | # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the 104 | # unique labels. 105 | def get_label_list(labels): 106 | unique_labels = set() 107 | for label in labels: 108 | unique_labels = unique_labels | set(label) 109 | label_list = list(unique_labels) 110 | label_list.sort() 111 | return label_list 112 | 113 | if isinstance(features[label_column_name].feature, ClassLabel): 114 | label_list = features[label_column_name].feature.names 115 | # No need to convert the labels since they are already ints. 116 | label_to_id = {i: i for i in range(len(label_list))} 117 | else: 118 | label_list = get_label_list(datasets["train"][label_column_name]) 119 | label_to_id = {l: i for i, l in enumerate(label_list)} 120 | num_labels = len(label_list) 121 | 122 | # Load pretrained model and tokenizer 123 | # 124 | # Distributed training: 125 | # The .from_pretrained methods guarantee that only one local process can concurrently 126 | # download model & vocab. 127 | config = AutoConfig.from_pretrained( 128 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 129 | num_labels=num_labels, 130 | finetuning_task=data_args.task_name, 131 | cache_dir=model_args.cache_dir, 132 | revision=model_args.model_revision, 133 | use_auth_token=True if model_args.use_auth_token else None, 134 | ) 135 | tokenizer = AutoTokenizer.from_pretrained( 136 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 137 | cache_dir=model_args.cache_dir, 138 | use_fast=True, 139 | revision=model_args.model_revision, 140 | use_auth_token=True if model_args.use_auth_token else None, 141 | ) 142 | model = AutoModelForTokenClassification.from_pretrained( 143 | model_args.model_name_or_path, 144 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 145 | config=config, 146 | cache_dir=model_args.cache_dir, 147 | revision=model_args.model_revision, 148 | use_auth_token=True if model_args.use_auth_token else None, 149 | ) 150 | 151 | # Tokenizer check: this script requires a fast tokenizer. 152 | if not isinstance(tokenizer, PreTrainedTokenizerFast): 153 | raise ValueError( 154 | "This example script only works for models that have a fast tokenizer. 
Checkout the big table of models " 155 | "at https://huggingface.co/transformers/index.html#bigtable to find the model types that meet this " 156 | "requirement" 157 | ) 158 | 159 | # Preprocessing the dataset 160 | # Padding strategy 161 | padding = "max_length" if data_args.pad_to_max_length else False 162 | 163 | if training_args.do_train: 164 | if "train" not in datasets: 165 | raise ValueError("--do_train requires a train dataset") 166 | train_dataset = datasets["train"] 167 | if data_args.max_train_samples is not None: 168 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 169 | 170 | if training_args.do_eval: 171 | if "validation" not in datasets: 172 | raise ValueError("--do_eval requires a validation dataset") 173 | eval_dataset = datasets["validation"] 174 | if data_args.max_val_samples is not None: 175 | eval_dataset = eval_dataset.select(range(data_args.max_val_samples)) 176 | 177 | if training_args.do_predict: 178 | if "test" not in datasets: 179 | raise ValueError("--do_predict requires a test dataset") 180 | test_dataset = datasets["test"] 181 | if data_args.max_test_samples is not None: 182 | test_dataset = test_dataset.select(range(data_args.max_test_samples)) 183 | 184 | # Data collator 185 | data_collator = DataCollatorForKeyValueExtraction( 186 | tokenizer, 187 | pad_to_multiple_of=8 if training_args.fp16 else None, 188 | padding=padding, 189 | max_length=512, 190 | ) 191 | 192 | # Metrics 193 | metric = load_metric("seqeval") 194 | 195 | def compute_metrics(p): 196 | predictions, labels = p 197 | predictions = np.argmax(predictions, axis=2) 198 | 199 | # Remove ignored index (special tokens) 200 | true_predictions = [ 201 | [label_list[p] for (p, l) in zip(prediction, label) if l != -100] 202 | for prediction, label in zip(predictions, labels) 203 | ] 204 | true_labels = [ 205 | [label_list[l] for (p, l) in zip(prediction, label) if l != -100] 206 | for prediction, label in zip(predictions, labels) 207 | ] 208 | 209 | results = metric.compute(predictions=true_predictions, references=true_labels) 210 | if data_args.return_entity_level_metrics: 211 | # Unpack nested dictionaries 212 | final_results = {} 213 | for key, value in results.items(): 214 | if isinstance(value, dict): 215 | for n, v in value.items(): 216 | final_results[f"{key}_{n}"] = v 217 | else: 218 | final_results[key] = value 219 | return final_results 220 | else: 221 | return { 222 | "precision": results["overall_precision"], 223 | "recall": results["overall_recall"], 224 | "f1": results["overall_f1"], 225 | "accuracy": results["overall_accuracy"], 226 | } 227 | 228 | # Initialize our Trainer 229 | trainer = XfunSerTrainer( 230 | model=model, 231 | args=training_args, 232 | train_dataset=train_dataset if training_args.do_train else None, 233 | eval_dataset=eval_dataset if training_args.do_eval else None, 234 | tokenizer=tokenizer, 235 | data_collator=data_collator, 236 | compute_metrics=compute_metrics, 237 | ) 238 | 239 | # Training 240 | if training_args.do_train: 241 | checkpoint = last_checkpoint if last_checkpoint else None 242 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 243 | metrics = train_result.metrics 244 | trainer.save_model() # Saves the tokenizer too for easy upload 245 | 246 | max_train_samples = ( 247 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 248 | ) 249 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 250 | 251 | trainer.log_metrics("train", metrics) 252 | 
trainer.save_metrics("train", metrics) 253 | trainer.save_state() 254 | 255 | # Evaluation 256 | if training_args.do_eval: 257 | logger.info("*** Evaluate ***") 258 | 259 | metrics = trainer.evaluate() 260 | 261 | max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset) 262 | metrics["eval_samples"] = min(max_val_samples, len(eval_dataset)) 263 | 264 | trainer.log_metrics("eval", metrics) 265 | trainer.save_metrics("eval", metrics) 266 | 267 | # Predict 268 | if training_args.do_predict: 269 | logger.info("*** Predict ***") 270 | 271 | predictions, labels, metrics = trainer.predict(test_dataset) 272 | predictions = np.argmax(predictions, axis=2) 273 | 274 | # Remove ignored index (special tokens) 275 | true_predictions = [ 276 | [label_list[p] for (p, l) in zip(prediction, label) if l != -100] 277 | for prediction, label in zip(predictions, labels) 278 | ] 279 | 280 | trainer.log_metrics("test", metrics) 281 | trainer.save_metrics("test", metrics) 282 | 283 | # Save predictions 284 | output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt") 285 | if trainer.is_world_process_zero(): 286 | with open(output_test_predictions_file, "w") as writer: 287 | for prediction in true_predictions: 288 | writer.write(" ".join(prediction) + "\n") 289 | 290 | 291 | def _mp_fn(index): 292 | # For xla_spawn (TPUs) 293 | main() 294 | 295 | 296 | if __name__ == "__main__": 297 | main() 298 | -------------------------------------------------------------------------------- /figs/cl_xfund.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpWang/LiLT/a39930b2c5425da7250f0dde04252cf60ec3b1b7/figs/cl_xfund.png -------------------------------------------------------------------------------- /figs/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpWang/LiLT/a39930b2c5425da7250f0dde04252cf60ec3b1b7/figs/framework.png -------------------------------------------------------------------------------- /figs/funsd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpWang/LiLT/a39930b2c5425da7250f0dde04252cf60ec3b1b7/figs/funsd.png -------------------------------------------------------------------------------- /figs/ls_xfund.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpWang/LiLT/a39930b2c5425da7250f0dde04252cf60ec3b1b7/figs/ls_xfund.png -------------------------------------------------------------------------------- /figs/mt_xfund.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpWang/LiLT/a39930b2c5425da7250f0dde04252cf60ec3b1b7/figs/mt_xfund.png -------------------------------------------------------------------------------- /gen_weight_roberta_like.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os, json 3 | import argparse 4 | from transformers import AutoConfig, AutoModel 5 | 6 | if __name__ == '__main__': 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--lilt', type=str, required=True, help='Path to LiLT model.') 10 | parser.add_argument('--text', type=str, required=True, help='Path to text model.') 11 | parser.add_argument('--config', type=str, required=True, help='Path to text config.') 
12 | parser.add_argument('--out', type=str, required=True, help='Path to output.') 13 | opt = parser.parse_args() 14 | 15 | with open(opt.config, 'r') as jf: 16 | config = json.load(jf) 17 | config['channel_shrink_ratio'] = 4 18 | config['max_2d_position_embeddings'] = 1024 19 | config['model_type'] = 'liltrobertalike' 20 | 21 | if not os.path.isdir(opt.out): 22 | os.makedirs(opt.out) 23 | with open(os.path.join(opt.out, 'config.json'), 'w') as jf: 24 | json.dump(config, jf, sort_keys=True, indent=2, separators=(',', ': '),) 25 | 26 | text_model = torch.load(opt.text) 27 | text_model = {k.replace('roberta.', 'lilt.'): v for (k, v) in text_model.items()} 28 | lilt_model = torch.load(opt.lilt) 29 | total_model = {**text_model, **lilt_model} 30 | torch.save(total_model, os.path.join(opt.out, 'pytorch_model.bin')) 31 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 119 3 | target-version = ['py35'] 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | datasets==1.6.2 2 | transformers==4.5.1 3 | seqeval==1.2.2 4 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | default_section = FIRSTPARTY 3 | ensure_newline_before_comments = True 4 | force_grid_wrap = 0 5 | include_trailing_comma = True 6 | known_first_party = LiLT 7 | known_third_party = 8 | datasets 9 | git 10 | h5py 11 | numpy 12 | packaging 13 | PIL 14 | seqeval 15 | torch 16 | torchvision 17 | tqdm 18 | 19 | line_length = 119 20 | lines_after_imports = 2 21 | multi_line_output = 3 22 | use_parentheses = True 23 | 24 | [flake8] 25 | ignore = E203, E501, E741, W503, W605 26 | max-line-length = 119 27 | per-file-ignores = __init__.py:F401 28 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from setuptools import find_packages, setup 3 | setup( 4 | name="LiLTfinetune", 5 | version="1.0", 6 | author="Deep Learning and Vision Computing Lab, SCUT", 7 | url="https://github.com/jpWang/LiLT", 8 | packages=find_packages(), 9 | python_requires=">=3.7", 10 | extras_require={"dev": ["flake8", "isort", "black"]}, 11 | ) --------------------------------------------------------------------------------
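As a closing note for readers adapting these scripts: after a combined checkpoint such as `lilt-roberta-en-base` has been generated and fine-tuned with the commands in the README above, it can be loaded through the standard Transformers Auto classes. The sketch below is a minimal, hypothetical usage example, not a file from this repository. It assumes that importing the `LiLTfinetune` package registers the `liltrobertalike` model type with the Auto classes, which is what `examples/run_funsd.py` relies on; the word list, label count, and all-zero bounding boxes are placeholders for illustration only.

```python
import torch

import LiLTfinetune  # noqa: F401  (registers the LiLTRobertaLike model type with the Auto classes)
from transformers import AutoConfig, AutoModelForTokenClassification, AutoTokenizer

# num_labels=7 matches FUNSD's BIO tag set (O, B/I-HEADER, B/I-QUESTION, B/I-ANSWER);
# adjust it for your own label scheme.
config = AutoConfig.from_pretrained("lilt-roberta-en-base", num_labels=7)
tokenizer = AutoTokenizer.from_pretrained("roberta-base", use_fast=True, add_prefix_space=True)
model = AutoModelForTokenClassification.from_pretrained("lilt-roberta-en-base", config=config)

words = ["Invoice", "Number", ":", "12345"]  # hypothetical pre-tokenized input
encoding = tokenizer(words, is_split_into_words=True, return_tensors="pt")

# LiLT also consumes layout information: one box per token, scaled to the 0-1000 range.
# All-zero boxes are used here purely to show the expected tensor shape.
encoding["bbox"] = torch.zeros((1, encoding["input_ids"].shape[1], 4), dtype=torch.long)

with torch.no_grad():
    outputs = model(**encoding)
print(outputs.logits.shape)  # (1, sequence_length, num_labels)
```

In real use, `bbox` would carry the OCR word boxes aligned token by token, as done in `tokenize_and_align_labels` in `examples/run_funsd.py`.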