├── .gitignore
├── RE_HuggingFace
│   ├── __pycache__
│   │   └── main.cpython-38.pyc
│   ├── download.py
│   ├── layoutlm_re.py
│   └── layoutlmv2_re.py
├── .vscode
│   ├── settings.json
│   └── launch.json
├── LICENSE
└── SER_HuggingFace
    ├── main.py
    └── download.py
/.gitignore: -------------------------------------------------------------------------------- 1 | dataset/ 2 | *.log 3 | *.pt 4 | /SER_HuggingFace/model 5 | /SER_HuggingFace/out 6 | /unilm 7 | /energy_tracker 8 | /wandb 9 | -------------------------------------------------------------------------------- /RE_HuggingFace/__pycache__/main.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang0369/Information_Extraction/HEAD/RE_HuggingFace/__pycache__/main.cpython-38.pyc -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.testing.pytestArgs": ["{workspace}/RE_2"], // set pytest root dir 3 | "python.terminal.activateEnvironment": true, 4 | "python.testing.unittestEnabled": false, 5 | "python.testing.pytestEnabled": true, 6 | 7 | // linting 8 | "python.linting.flake8Enabled": true, 9 | "python.linting.enabled": true, 10 | "python.linting.flake8Args": ["--select", "E20,E21,E22,E23,E24,E3,F401,F8", "--ignore=E226", "--verbose"], 11 | 12 | // pylance 13 | "python.analysis.extraPaths": [""], // set pylance paths, default import paths 14 | "python.autoComplete.extraPaths": [""], 15 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2023, Yang Kewen 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 19 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 20 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 21 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 22 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 23 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 24 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes.
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: Current File", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal", 13 | "stopOnEntry": true, 14 | "justMyCode": false, 15 | }, 16 | { 17 | "name": "Python: run layoutlmv2_re", 18 | "type": "python", 19 | "request": "launch", 20 | "program": "${workspaceFolder}/RE_HuggingFace/layoutlmv2_re.py", 21 | "console": "integratedTerminal", 22 | "justMyCode": true, 23 | "env": { 24 | // "CUDA_LAUNCH_BLOCKING": "1", 25 | "PYTHONPATH": "${workspaceFolder}" // add paths to sys.path; separate multiple paths with a semicolon on Windows and a colon on Linux 26 | } 27 | }, 28 | { 29 | "name": "Python: run layoutlm_re", 30 | "type": "python", 31 | "request": "launch", 32 | "program": "${workspaceFolder}/RE_HuggingFace/layoutlm_re.py", 33 | "console": "integratedTerminal", 34 | "justMyCode": true, 35 | "env": { 36 | // "CUDA_LAUNCH_BLOCKING": "1", 37 | "PYTHONPATH": "${workspaceFolder}" // add paths to sys.path; separate multiple paths with a semicolon on Windows and a colon on Linux 38 | } 39 | }, 40 | { 41 | "name": "Python: run SER_HuggingFace", 42 | "type": "python", 43 | "request": "launch", 44 | "program": "${workspaceFolder}/SER_HuggingFace/main.py", 45 | "console": "integratedTerminal", 46 | "justMyCode": true, 47 | "env": { 48 | "PYTHONPATH": "${workspaceFolder}" 49 | } 50 | } 51 | ] 52 | } -------------------------------------------------------------------------------- /SER_HuggingFace/main.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | from typing import Optional, Union 4 | 5 | import numpy as np 6 | from datasets import load_dataset, load_metric 7 | from torch.utils.data import DataLoader 8 | 9 | from transformers import (LayoutLMv2FeatureExtractor, 10 | LayoutLMv2ForTokenClassification, 11 | LayoutLMv2Processor, LayoutLMv2TokenizerFast, 12 | PreTrainedTokenizerBase, Trainer, TrainingArguments) 13 | from transformers.file_utils import PaddingStrategy 14 | 15 | ROOT = Path(__file__).parents[1] 16 | 17 | 18 | dataset = load_dataset((ROOT / "RE_HuggingFace/download.py").as_posix(), "en") 19 | 20 | labels = dataset['train'].features['labels'].feature.names 21 | id2label = {k: v for k, v in enumerate(labels)} 22 | label2id = {v: k for k, v in enumerate(labels)} 23 | 24 | feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False) 25 | tokenizer = LayoutLMv2TokenizerFast.from_pretrained("microsoft/layoutlmv2-base-uncased") 26 | 27 | 28 | @dataclass 29 | class DataCollatorForTokenClassification: 30 | """ 31 | Data collator that will dynamically pad the inputs received, as well as the labels. 32 | 33 | Args: 34 | tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`): 35 | The tokenizer used for encoding the data. 36 | padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`): 37 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index) 38 | among: 39 | * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single 40 | sequence is provided).
41 | * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the 42 | maximum acceptable input length for the model if that argument is not provided. 43 | * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of 44 | different lengths). 45 | max_length (:obj:`int`, `optional`): 46 | Maximum length of the returned list and optionally padding length (see above). 47 | pad_to_multiple_of (:obj:`int`, `optional`): 48 | If set will pad the sequence to a multiple of the provided value. 49 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 50 | 7.5 (Volta). 51 | label_pad_token_id (:obj:`int`, `optional`, defaults to -100): 52 | The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions). 53 | """ 54 | feature_extractor: LayoutLMv2FeatureExtractor 55 | tokenizer: PreTrainedTokenizerBase 56 | padding: Union[bool, str, PaddingStrategy] = True 57 | max_length: Optional[int] = None 58 | pad_to_multiple_of: Optional[int] = None 59 | label_pad_token_id: int = -100 60 | 61 | def __call__(self, features): 62 | # prepare image input 63 | image = self.feature_extractor([feature["original_image"] for feature in features], return_tensors="pt").pixel_values 64 | 65 | # prepare text input 66 | for feature in features: 67 | del feature["image"] 68 | del feature["id"] 69 | del feature["original_image"] 70 | del feature["entities"] 71 | del feature["relations"] 72 | 73 | batch = self.tokenizer.pad( 74 | features, 75 | padding=self.padding, 76 | max_length=self.max_length, 77 | pad_to_multiple_of=self.pad_to_multiple_of, 78 | return_tensors="pt" 79 | ) 80 | 81 | batch["image"] = image 82 | 83 | return batch 84 | 85 | 86 | data_collator = DataCollatorForTokenClassification( 87 | feature_extractor, 88 | tokenizer, 89 | pad_to_multiple_of=None, 90 | padding="max_length", 91 | max_length=512, 92 | ) 93 | 94 | train_dataset = dataset['train'] 95 | val_dataset = dataset['validation'] 96 | 97 | dataloader = DataLoader(train_dataset, batch_size=4, collate_fn=data_collator) 98 | 99 | 100 | model = LayoutLMv2ForTokenClassification.from_pretrained('microsoft/layoutlmv2-base-uncased', 101 | id2label=id2label, 102 | label2id=label2id) 103 | 104 | 105 | # Metrics 106 | metric = load_metric("seqeval") 107 | return_entity_level_metrics = True 108 | 109 | 110 | def compute_metrics(p): 111 | predictions, labels = p 112 | predictions = np.argmax(predictions, axis=2) 113 | 114 | # Remove ignored index (special tokens) 115 | true_predictions = [ 116 | [id2label[p] for (p, l) in zip(prediction, label) if l != -100] 117 | for prediction, label in zip(predictions, labels) 118 | ] 119 | true_labels = [ 120 | [id2label[l] for (p, l) in zip(prediction, label) if l != -100] 121 | for prediction, label in zip(predictions, labels) 122 | ] 123 | 124 | results = metric.compute(predictions=true_predictions, references=true_labels) 125 | if return_entity_level_metrics: 126 | # Unpack nested dictionaries 127 | final_results = {} 128 | for key, value in results.items(): 129 | if isinstance(value, dict): 130 | for n, v in value.items(): 131 | final_results[f"{key}_{n}"] = v 132 | else: 133 | final_results[key] = value 134 | return final_results 135 | else: 136 | return { 137 | "precision": results["overall_precision"], 138 | "recall": results["overall_recall"], 139 | "f1": results["overall_f1"], 140 | "accuracy": results["overall_accuracy"], 141 | } 142 | 143 | 144 |
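# Illustrative note on the metric plumbing above (hypothetical values, not from a real run):
# `p` is an EvalPrediction; p.predictions has shape (batch, seq_len, num_labels) and
# p.label_ids marks special tokens with -100, so they are filtered out before scoring.
# seqeval then receives aligned lists of label strings per sequence, e.g.
#   metric.compute(predictions=[["B-QUESTION", "O"]], references=[["B-QUESTION", "B-ANSWER"]])
# and returns "overall_precision", "overall_recall", "overall_f1", "overall_accuracy"
# plus one nested dict per entity type, which compute_metrics flattens into keys such as
# "QUESTION_f1" when return_entity_level_metrics is True.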
args = TrainingArguments( 145 | output_dir=ROOT / "SER" / "out", # name of directory to store the checkpoints 146 | overwrite_output_dir=True, 147 | # max_steps=1000, # we train for a maximum of 1,000 batches 148 | num_train_epochs=10, 149 | no_cuda=False, 150 | 151 | warmup_ratio=0.1, # we warmup a bit 152 | # fp16=True, # we use mixed precision (less memory consumption) 153 | per_device_train_batch_size=2, 154 | per_device_eval_batch_size=2, 155 | learning_rate=1e-5, 156 | remove_unused_columns=False, 157 | push_to_hub=False, # set to True to push the model to the Hub during training 158 | ) 159 | 160 | # Initialize our Trainer 161 | trainer = Trainer( 162 | model=model, 163 | args=args, 164 | train_dataset=train_dataset, 165 | eval_dataset=val_dataset, 166 | tokenizer=tokenizer, 167 | data_collator=data_collator, 168 | compute_metrics=compute_metrics, 169 | ) 170 | 171 | train_metrics = trainer.train() 172 | 173 | eval_metrics = trainer.evaluate() 174 | print(eval_metrics) 175 | trainer.save_model(ROOT / "SER" / "model") 176 | 177 | 178 | feature_extractor = LayoutLMv2FeatureExtractor(ocr_lang="eng") 179 | processor = LayoutLMv2Processor(feature_extractor, tokenizer) 180 | processor.save_pretrained(ROOT / "SER" / "model") 181 | -------------------------------------------------------------------------------- /SER_HuggingFace/download.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | import json 3 | import logging 4 | import os 5 | from pathlib import Path 6 | 7 | import datasets 8 | 9 | from PIL import Image 10 | import numpy as np 11 | 12 | ROOT = Path(__file__).parents[1] 13 | 14 | from transformers import AutoTokenizer 15 | 16 | 17 | def load_image(image_path, size=None): 18 | image = Image.open(image_path).convert("RGB") 19 | w, h = image.size 20 | if size is not None: 21 | # resize image 22 | image = image.resize((size, size)) 23 | image = np.asarray(image) 24 | image = image[:, :, ::-1] # flip color channels from RGB to BGR 25 | image = image.transpose(2, 0, 1) # move channels to first dimension 26 | return image, (w, h) 27 | 28 | 29 | def normalize_bbox(bbox, size): 30 | return [ 31 | int(1000 * bbox[0] / size[0]), 32 | int(1000 * bbox[1] / size[1]), 33 | int(1000 * bbox[2] / size[0]), 34 | int(1000 * bbox[3] / size[1]), 35 | ] 36 | 37 | 38 | def simplify_bbox(bbox): 39 | return [ 40 | min(bbox[0::2]), 41 | min(bbox[1::2]), 42 | max(bbox[2::2]), 43 | max(bbox[3::2]), 44 | ] 45 | 46 | 47 | def merge_bbox(bbox_list): 48 | x0, y0, x1, y1 = list(zip(*bbox_list)) 49 | return [min(x0), min(y0), max(x1), max(y1)] 50 | 51 | 52 | logger = logging.getLogger(__name__) 53 | 54 | 55 | class XFUNConfig(datasets.BuilderConfig): 56 | """BuilderConfig for XFUN.""" 57 | 58 | def __init__(self, **kwargs): 59 | """ 60 | Args: 61 | lang: string, language for the input text 62 | **kwargs: keyword arguments forwarded to super.
63 | """ 64 | super(XFUNConfig, self).__init__(**kwargs) 65 | 66 | 67 | class XFUN(datasets.GeneratorBasedBuilder): 68 | """XFUN dataset.""" 69 | 70 | BUILDER_CONFIGS = [XFUNConfig(name="en")] 71 | 72 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") 73 | 74 | def _info(self): 75 | return datasets.DatasetInfo( 76 | features=datasets.Features( 77 | { 78 | "id": datasets.Value("string"), 79 | "input_ids": datasets.Sequence(datasets.Value("int64")), 80 | "bbox": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))), 81 | "labels": datasets.Sequence( 82 | datasets.ClassLabel( 83 | names=["O", "B-QUESTION", "B-ANSWER", "I-ANSWER", "I-QUESTION"] 84 | ) 85 | ), 86 | "image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"), 87 | "original_image": datasets.features.Image(), 88 | "entities": datasets.Sequence( 89 | { 90 | "start": datasets.Value("int64"), 91 | "end": datasets.Value("int64"), 92 | "label": datasets.ClassLabel(names=["HEADER", "QUESTION", "ANSWER"]), 93 | } 94 | ), 95 | "relations": datasets.Sequence( 96 | { 97 | "head": datasets.Value("int64"), 98 | "tail": datasets.Value("int64"), 99 | "start_index": datasets.Value("int64"), 100 | "end_index": datasets.Value("int64"), 101 | } 102 | ), 103 | } 104 | ), 105 | supervised_keys=None, 106 | ) 107 | 108 | def _split_generators(self, dl_manager): 109 | """Returns SplitGenerators.""" 110 | return [ 111 | datasets.SplitGenerator( 112 | name=datasets.Split.TRAIN, gen_kwargs={"filepath": [ROOT / "dataset/en_sample/en.train.json", 113 | ROOT / "dataset/en_sample/"]}), 114 | datasets.SplitGenerator( 115 | name=datasets.Split.VALIDATION, gen_kwargs={"filepath": [ROOT / "dataset/en_sample/en.val.json", 116 | ROOT / "dataset/en_sample/"]}), 117 | 118 | # datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepaths": test_files_for_many_langs}), 119 | ] 120 | 121 | def _generate_examples(self, filepath): 122 | logger.info("Generating examples from = %s", filepath) 123 | with open(filepath[0], "r", encoding="utf-8") as f: 124 | data = json.load(f) 125 | 126 | for doc in data["documents"]: 127 | print(f'processing {doc["img"]["fname"]}') 128 | doc["img"]["fpath"] = os.path.join(filepath[1], doc["img"]["fname"]) 129 | image, size = load_image(doc["img"]["fpath"], size=224) 130 | original_image, _ = load_image(doc["img"]["fpath"]) 131 | document = doc["document"] 132 | tokenized_doc = {"input_ids": [], "bbox": [], "labels": []} 133 | entities = [] 134 | relations = [] 135 | id2label = {} 136 | entity_id_to_index_map = {} 137 | empty_entity = set() 138 | for line in document: 139 | if len(line["text"]) == 0: 140 | empty_entity.add(line["id"]) 141 | continue 142 | id2label[line["id"]] = line["label"] 143 | relations.extend([tuple(sorted(l)) for l in line["linking"]]) 144 | tokenized_inputs = self.tokenizer( 145 | line["text"], 146 | add_special_tokens=False, 147 | return_offsets_mapping=True, 148 | return_attention_mask=False, 149 | ) 150 | text_length = 0 151 | ocr_length = 0 152 | bbox = [] 153 | for token_id, offset in zip(tokenized_inputs["input_ids"], tokenized_inputs["offset_mapping"]): 154 | if token_id == 6: 155 | bbox.append(None) 156 | continue 157 | text_length += offset[1] - offset[0] 158 | tmp_box = [] 159 | while ocr_length < text_length: 160 | ocr_word = line["words"].pop(0) 161 | ocr_length += len( 162 | self.tokenizer._tokenizer.normalizer.normalize_str(ocr_word["text"].strip()) 163 | ) 164 | tmp_box.append(simplify_bbox(ocr_word["box"])) 165 | if len(tmp_box) == 0: 166 | tmp_box = last_box 167 | 
bbox.append(normalize_bbox(merge_bbox(tmp_box), size)) 168 | last_box = tmp_box # noqa 169 | bbox = [ 170 | [bbox[i + 1][0], bbox[i + 1][1], bbox[i + 1][0], bbox[i + 1][1]] if b is None else b 171 | for i, b in enumerate(bbox) 172 | ] 173 | 174 | # remove header from labels 175 | if line["label"] in ["other", "header"]: 176 | label = ["O"] * len(bbox) 177 | else: 178 | label = [f"I-{line['label'].upper()}"] * len(bbox) 179 | label[0] = f"B-{line['label'].upper()}" 180 | 181 | tokenized_inputs.update({"bbox": bbox, "labels": label}) 182 | if label[0] != "O": 183 | entity_id_to_index_map[line["id"]] = len(entities) 184 | entities.append( 185 | { 186 | "start": len(tokenized_doc["input_ids"]), 187 | "end": len(tokenized_doc["input_ids"]) + len(tokenized_inputs["input_ids"]), 188 | "label": line["label"].upper(), 189 | } 190 | ) 191 | for i in tokenized_doc: 192 | tokenized_doc[i] = tokenized_doc[i] + tokenized_inputs[i] 193 | 194 | relations = list(set(relations)) 195 | relations = [rel for rel in relations if rel[0] not in empty_entity and rel[1] not in empty_entity] 196 | kvrelations = [] 197 | for rel in relations: 198 | pair = [id2label[rel[0]], id2label[rel[1]]] 199 | if pair == ["question", "answer"]: 200 | kvrelations.append( 201 | {"head": entity_id_to_index_map[rel[0]], "tail": entity_id_to_index_map[rel[1]]} 202 | ) 203 | elif pair == ["answer", "question"]: 204 | kvrelations.append( 205 | {"head": entity_id_to_index_map[rel[1]], "tail": entity_id_to_index_map[rel[0]]} 206 | ) 207 | else: 208 | continue 209 | 210 | def get_relation_span(rel): 211 | bound = [] 212 | for entity_index in [rel["head"], rel["tail"]]: 213 | bound.append(entities[entity_index]["start"]) 214 | bound.append(entities[entity_index]["end"]) 215 | return min(bound), max(bound) 216 | 217 | relations = sorted( 218 | [ 219 | { 220 | "head": rel["head"], 221 | "tail": rel["tail"], 222 | "start_index": get_relation_span(rel)[0], 223 | "end_index": get_relation_span(rel)[1], 224 | } 225 | for rel in kvrelations 226 | ], 227 | key=lambda x: x["head"], 228 | ) 229 | chunk_size = 512 230 | for chunk_id, index in enumerate(range(0, len(tokenized_doc["input_ids"]), chunk_size)): 231 | item = {} 232 | for k in tokenized_doc: 233 | item[k] = tokenized_doc[k][index: index + chunk_size] 234 | entities_in_this_span = [] 235 | global_to_local_map = {} 236 | for entity_id, entity in enumerate(entities): 237 | if ( 238 | index <= entity["start"] < index + chunk_size 239 | and index <= entity["end"] < index + chunk_size 240 | ): 241 | entity["start"] = entity["start"] - index 242 | entity["end"] = entity["end"] - index 243 | global_to_local_map[entity_id] = len(entities_in_this_span) 244 | entities_in_this_span.append(entity) 245 | relations_in_this_span = [] 246 | for relation in relations: 247 | if ( 248 | index <= relation["start_index"] < index + chunk_size 249 | and index <= relation["end_index"] < index + chunk_size 250 | ): 251 | relations_in_this_span.append( 252 | { 253 | "head": global_to_local_map[relation["head"]], 254 | "tail": global_to_local_map[relation["tail"]], 255 | "start_index": relation["start_index"] - index, 256 | "end_index": relation["end_index"] - index, 257 | } 258 | ) 259 | item.update( 260 | { 261 | "id": f"{doc['id']}_{chunk_id}", 262 | "image": image, 263 | "original_image": original_image, 264 | "entities": entities_in_this_span, 265 | "relations": relations_in_this_span, 266 | } 267 | ) 268 | yield f"{doc['id']}_{chunk_id}", item 269 | 270 | 271 | """ 272 | Input Schema: 273 | 274 | { 275 | "id": 
image name, e.g. "en_train_0_0" 276 | "input_ids": List[token_ids], all_texts = "".join([tokenizer.decode(i) for i in dataset["train"][0]["input_ids"]]) 277 | "bbox": tensors, 278 | "labels": tensors, 279 | "hd_image": 280 | "entities": 281 | "relations": 282 | "attention_mask": 283 | "image": 284 | } 285 | 286 | for English, a token usually equals a word, but when a word is out of vocabulary, it will be tokenized into multiple subword tokens, e.g. "vashering" --> ['va', '##sher', '##ing'] 287 | """ -------------------------------------------------------------------------------- /RE_HuggingFace/download.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | import json 3 | import logging 4 | import os 5 | from pathlib import Path 6 | 7 | import datasets 8 | 9 | from PIL import Image 10 | import numpy as np 11 | import torch 12 | 13 | # TODO: log msg in file 14 | 15 | ROOT = Path(__file__).parents[1] 16 | 17 | from transformers import AutoTokenizer 18 | 19 | 20 | def load_image(image_path, size=None): 21 | image = Image.open(image_path).convert("RGB") 22 | w, h = image.size 23 | if size is not None: 24 | # resize image 25 | image = image.resize((size, size)) 26 | image = np.asarray(image) 27 | image = image[:, :, ::-1] # flip color channels from RGB to BGR 28 | image = image.transpose(2, 0, 1) # move channels to first dimension 29 | return image, (w, h) 30 | 31 | 32 | def normalize_bbox(bbox, size): 33 | return [ 34 | int(1000 * bbox[0] / size[0]), 35 | int(1000 * bbox[1] / size[1]), 36 | int(1000 * bbox[2] / size[0]), 37 | int(1000 * bbox[3] / size[1]), 38 | ] 39 | 40 | 41 | def simplify_bbox(bbox): 42 | return [ 43 | min(bbox[0::2]), 44 | min(bbox[1::2]), 45 | max(bbox[2::2]), 46 | max(bbox[3::2]), 47 | ] 48 | 49 | 50 | def merge_bbox(bbox_list): 51 | x0, y0, x1, y1 = list(zip(*bbox_list)) 52 | return [min(x0), min(y0), max(x1), max(y1)] 53 | 54 | 55 | logger = logging.getLogger(__name__) 56 | 57 | 58 | class XFUNConfig(datasets.BuilderConfig): 59 | """BuilderConfig for XFUN.""" 60 | 61 | def __init__(self, **kwargs): 62 | """ 63 | Args: 64 | lang: string, language for the input text 65 | **kwargs: keyword arguments forwarded to super.
66 | """ 67 | super(XFUNConfig, self).__init__(**kwargs) 68 | 69 | 70 | class XFUN(datasets.GeneratorBasedBuilder): 71 | """XFUN dataset.""" 72 | 73 | BUILDER_CONFIGS = [XFUNConfig(name="en")] 74 | 75 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") 76 | 77 | def _info(self): 78 | return datasets.DatasetInfo( 79 | features=datasets.Features( 80 | { 81 | "id": datasets.Value("string"), 82 | "input_ids": datasets.Sequence(datasets.Value("int64")), 83 | "bbox": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))), 84 | "labels": datasets.Sequence( 85 | datasets.ClassLabel( 86 | names=["O", "B-QUESTION", "B-ANSWER", "I-ANSWER", "I-QUESTION"] 87 | ) 88 | ), 89 | "image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"), 90 | "original_image": datasets.features.Image(), 91 | "entities": datasets.Sequence( 92 | { 93 | "start": datasets.Value("int64"), 94 | "end": datasets.Value("int64"), 95 | "label": datasets.ClassLabel(names=["HEADER", "QUESTION", "ANSWER"]), 96 | } 97 | ), 98 | "relations": datasets.Sequence( 99 | { 100 | "head": datasets.Value("int64"), 101 | "tail": datasets.Value("int64"), 102 | "start_index": datasets.Value("int64"), 103 | "end_index": datasets.Value("int64"), 104 | } 105 | ), 106 | } 107 | ), 108 | supervised_keys=None, 109 | ) 110 | 111 | def _split_generators(self, dl_manager): 112 | file_dir = "/home/kewen_yang/Information_Extraction/dataset/RE_Finetune_1/" 113 | """Returns SplitGenerators.""" 114 | return [ 115 | datasets.SplitGenerator( 116 | name=datasets.Split.TRAIN, gen_kwargs={"filepath": [file_dir + "train.json", 117 | file_dir]}), 118 | datasets.SplitGenerator( 119 | name=datasets.Split.VALIDATION, gen_kwargs={"filepath": [file_dir + "val.json", 120 | file_dir]}), 121 | 122 | # datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepaths": test_files_for_many_langs}), 123 | ] 124 | 125 | def _generate_examples(self, filepath): 126 | logger.info("Generating examples from = %s", filepath) 127 | with open(filepath[0], "r", encoding="utf-8") as f: 128 | data = json.load(f) 129 | 130 | for doc in data["documents"]: 131 | try: 132 | doc["img"]["fpath"] = os.path.join(filepath[1], doc["img"]["fname"]) 133 | size = (doc["img"]["width"], doc["img"]["height"]) 134 | image, _ = load_image(doc["img"]["fpath"], size=224) 135 | original_image, _ = load_image(doc["img"]["fpath"]) 136 | document = doc["document"] 137 | tokenized_doc = {"input_ids": [], "bbox": [], "labels": []} 138 | entities = [] 139 | relations = [] 140 | id2label = {} 141 | entity_id_to_index_map = {} 142 | empty_entity = set() 143 | for line in document: 144 | if len(line["text"]) == 0: 145 | empty_entity.add(line["id"]) 146 | continue 147 | id2label[line["id"]] = line["label"] 148 | relations.extend([tuple(sorted(l)) for l in line["linking"]]) 149 | tokenized_inputs = self.tokenizer( 150 | line["text"], 151 | add_special_tokens=False, 152 | return_offsets_mapping=True, 153 | return_attention_mask=False, 154 | ) 155 | text_length = 0 156 | ocr_length = 0 157 | bbox = [] 158 | for token_id, offset in zip(tokenized_inputs["input_ids"], tokenized_inputs["offset_mapping"]): 159 | if token_id == 6: 160 | bbox.append(None) 161 | continue 162 | text_length += offset[1] - offset[0] 163 | tmp_box = [] 164 | while ocr_length < text_length: 165 | ocr_word = line["words"].pop(0) 166 | ocr_length += len( 167 | self.tokenizer._tokenizer.normalizer.normalize_str(ocr_word["text"].strip()) 168 | ) 169 | tmp_box.append(simplify_bbox(ocr_word["box"])) 170 | if len(tmp_box) == 0: 171 | 
tmp_box = last_box 172 | bbox.append(normalize_bbox(merge_bbox(tmp_box), size)) 173 | last_box = tmp_box # noqa 174 | bbox = [ 175 | [bbox[i + 1][0], bbox[i + 1][1], bbox[i + 1][0], bbox[i + 1][1]] if b is None else b 176 | for i, b in enumerate(bbox) 177 | ] 178 | 179 | # remove header from labels 180 | if line["label"] in ["other", "header"]: 181 | label = ["O"] * len(bbox) 182 | else: 183 | label = [f"I-{line['label'].upper()}"] * len(bbox) 184 | label[0] = f"B-{line['label'].upper()}" 185 | 186 | tokenized_inputs.update({"bbox": bbox, "labels": label}) 187 | if label[0] != "O": 188 | entity_id_to_index_map[line["id"]] = len(entities) 189 | entities.append( 190 | { 191 | "start": len(tokenized_doc["input_ids"]), 192 | "end": len(tokenized_doc["input_ids"]) + len(tokenized_inputs["input_ids"]), 193 | "label": line["label"].upper(), 194 | } 195 | ) 196 | for i in tokenized_doc: 197 | tokenized_doc[i] = tokenized_doc[i] + tokenized_inputs[i] 198 | 199 | relations = list(set(relations)) 200 | relations = [rel for rel in relations if rel[0] not in empty_entity and rel[1] not in empty_entity] 201 | kvrelations = [] 202 | for rel in relations: 203 | pair = [id2label[rel[0]], id2label[rel[1]]] 204 | if pair == ["question", "answer"]: 205 | kvrelations.append( 206 | {"head": entity_id_to_index_map[rel[0]], "tail": entity_id_to_index_map[rel[1]]} 207 | ) 208 | elif pair == ["answer", "question"]: 209 | kvrelations.append( 210 | {"head": entity_id_to_index_map[rel[1]], "tail": entity_id_to_index_map[rel[0]]} 211 | ) 212 | else: 213 | continue 214 | 215 | def get_relation_span(rel): 216 | bound = [] 217 | for entity_index in [rel["head"], rel["tail"]]: 218 | bound.append(entities[entity_index]["start"]) 219 | bound.append(entities[entity_index]["end"]) 220 | return min(bound), max(bound) 221 | 222 | relations = sorted( 223 | [ 224 | { 225 | "head": rel["head"], 226 | "tail": rel["tail"], 227 | "start_index": get_relation_span(rel)[0], 228 | "end_index": get_relation_span(rel)[1], 229 | } 230 | for rel in kvrelations 231 | ], 232 | key=lambda x: x["head"], 233 | ) 234 | chunk_size = 512 235 | for chunk_id, index in enumerate(range(0, len(tokenized_doc["input_ids"]), chunk_size)): 236 | item = {} 237 | for k in tokenized_doc: 238 | item[k] = tokenized_doc[k][index: index + chunk_size] 239 | entities_in_this_span = [] 240 | global_to_local_map = {} 241 | for entity_id, entity in enumerate(entities): 242 | if ( 243 | index <= entity["start"] < index + chunk_size 244 | and index <= entity["end"] < index + chunk_size 245 | ): 246 | entity["start"] = entity["start"] - index 247 | entity["end"] = entity["end"] - index 248 | global_to_local_map[entity_id] = len(entities_in_this_span) 249 | entities_in_this_span.append(entity) 250 | relations_in_this_span = [] 251 | for relation in relations: 252 | if ( 253 | index <= relation["start_index"] < index + chunk_size 254 | and index <= relation["end_index"] < index + chunk_size 255 | ): 256 | relations_in_this_span.append( 257 | { 258 | "head": global_to_local_map[relation["head"]], 259 | "tail": global_to_local_map[relation["tail"]], 260 | "start_index": relation["start_index"] - index, 261 | "end_index": relation["end_index"] - index, 262 | } 263 | ) 264 | # check if any bbox value > 1000 265 | if max(torch.tensor(item["bbox"])[:, 1]) > 1000: 266 | raise ValueError(doc["img"]["fpath"]) 267 | 268 | item.update( 269 | { 270 | "id": f"{doc['id']}_{chunk_id}", 271 | "image": image, 272 | "original_image": original_image, 273 | "entities": entities_in_this_span, 274 
| "relations": relations_in_this_span, 275 | } 276 | ) 277 | yield f"{doc['id']}_{chunk_id}", item 278 | except: 279 | print(f'>>>>>>>>>>>>>>problem with {doc["id"]}') 280 | 281 | 282 | """ 283 | Input Schema: 284 | 285 | { 286 | "id": image name, e.g. "en_train_0_0" 287 | "input_ids": List[token_ids], all_texts = "".join([tokenizer.decode(i) for i in dataset["train"][0]["input_ids"]]) 288 | "bbox": tensors, 289 | "labels": tensors, 290 | "hd_image": 291 | "entities": 292 | "relations": 293 | "attention_mask": 294 | "image": 295 | } 296 | 297 | for English, tokens usually equals to words, but when word is out of vacab, it will be tokenized into multiple common tokens. e.g. "vashering" --> ['va', '##sher', '##ing'] 298 | """ -------------------------------------------------------------------------------- /RE_HuggingFace/layoutlm_re.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import enum 3 | from functools import reduce 4 | import json 5 | import logging 6 | from math import sqrt 7 | import os 8 | import sys 9 | from collections import defaultdict 10 | from dataclasses import dataclass 11 | from datetime import datetime 12 | from pathlib import Path 13 | from typing import Dict, List, Optional, Union 14 | from PIL import Image, JpegImagePlugin 15 | from torch.utils.data import DataLoader 16 | from transformers import LayoutLMTokenizer 17 | import pandas as pd 18 | import wandb 19 | JpegImagePlugin._getmp = lambda: None 20 | import matplotlib 21 | import tornado 22 | 23 | matplotlib.use('WebAgg') 24 | import matplotlib.pyplot as plt 25 | import numpy as np 26 | import torch 27 | # enable only if using DGX machine to plot visuals 28 | from datasets import load_dataset 29 | 30 | 31 | from pytz import timezone 32 | from transformers import (AutoModelForTokenClassification, AutoProcessor, 33 | LayoutLMForRelationExtraction, 34 | PreTrainedTokenizerBase, 35 | TrainingArguments) 36 | from transformers.file_utils import PaddingStrategy 37 | 38 | from unilm.layoutlmft.layoutlmft.evaluation import re_score 39 | from unilm.layoutlmft.layoutlmft.trainers import XfunReTrainer 40 | 41 | torch.backends.cudnn.benchmark = False 42 | 43 | ROOT = Path(__file__).parents[1] 44 | 45 | TZ = timezone('Asia/Singapore') 46 | CURRENT = datetime.now(tz=TZ) 47 | TIME = CURRENT.strftime("%Y_%m_%d_%H_%M") 48 | MODDEL_DIR = ROOT / "RE_HuggingFace" / f"model/checkpoint_{TIME}.pt" 49 | DEVICE = "cuda" 50 | # DEVICE = "cpu" 51 | 52 | 53 | @enum.unique 54 | class Task(enum.Enum): 55 | FINETUNING = 0 56 | INFERENCE = 1 57 | XTRACT_INFER = 2 58 | 59 | 60 | task = Task.XTRACT_INFER 61 | 62 | 63 | @dataclass 64 | class DataCollatorForKeyValueExtraction: 65 | """ 66 | Data collator that will dynamically pad the inputs received, as well as the labels. 67 | 68 | Args: 69 | tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`): 70 | The tokenizer used for encoding the data. 71 | padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`): 72 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index) 73 | among: 74 | * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single 75 | sequence if provided). 
76 | * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the 77 | maximum acceptable input length for the model if that argument is not provided. 78 | * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of 79 | different lengths). 80 | max_length (:obj:`int`, `optional`): 81 | Maximum length of the returned list and optionally padding length (see above). 82 | pad_to_multiple_of (:obj:`int`, `optional`): 83 | If set will pad the sequence to a multiple of the provided value. 84 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 85 | 7.5 (Volta). 86 | label_pad_token_id (:obj:`int`, `optional`, defaults to -100): 87 | The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions). 88 | """ 89 | tokenizer: PreTrainedTokenizerBase 90 | padding: Union[bool, str, PaddingStrategy] = True 91 | max_length: Optional[int] = None 92 | pad_to_multiple_of: Optional[int] = None 93 | label_pad_token_id: int = -100 94 | 95 | def __call__(self, features): 96 | # prepare text input 97 | entities = [] 98 | relations = [] 99 | for feature in features: 100 | del feature["image"] 101 | del feature["id"] 102 | del feature["labels"] 103 | del feature["original_image"] 104 | entities.append(feature["entities"]) 105 | del feature["entities"] 106 | relations.append(feature["relations"]) 107 | del feature["relations"] 108 | 109 | # set a break point here and check out the data schema based on Appendix 110 | batch = self.tokenizer.pad( 111 | features, 112 | padding=self.padding, 113 | max_length=self.max_length, 114 | pad_to_multiple_of=self.pad_to_multiple_of, 115 | return_tensors="pt" 116 | ) 117 | 118 | # batch["image"] = image 119 | batch["entities"] = entities 120 | batch["relations"] = relations 121 | 122 | return batch 123 | 124 | 125 | def relate_to_entity(to_merge: set, list_of_entities: list) -> list: 126 | """ 127 | replace id with actual entity 128 | Args: 129 | to_merge: a set of all merging ids 130 | list_of_entities: list of entities 131 | Returns: list of entities to be merged, with no order 132 | """ 133 | out = list() 134 | for e in list_of_entities: 135 | if e.get("id", "") in to_merge: 136 | out.append(e) 137 | return out 138 | 139 | 140 | def merge_two_entities(e_0: dict, e_1: dict) -> dict: 141 | """ 142 | merge two entities 143 | Args: 144 | e_0: entity 145 | e_1: entity 146 | Returns: the combined entity 147 | """ 148 | if "bbox" not in e_0.keys(): 149 | return e_1 150 | if "bbox" not in e_1.keys(): 151 | return e_0 152 | 153 | # decide the base entity 154 | # compare the y0, smaller y0 -> base 155 | if e_0["bbox"][1] == e_1["bbox"][1]: 156 | if e_0["bbox"][0] <= e_1["bbox"][0]: 157 | base_entity = e_0 158 | adding_entity = e_1 159 | else: 160 | base_entity = e_1 161 | adding_entity = e_0 162 | elif e_0["bbox"][1] < e_1["bbox"][1]: 163 | base_entity = e_0 164 | adding_entity = e_1 165 | else: 166 | base_entity = e_1 167 | adding_entity = e_0 168 | 169 | base_entity = merge_text(base_entity, adding_entity) 170 | base_entity = merge_bbox(base_entity, adding_entity) 171 | return base_entity 172 | 173 | 174 | def merge_text(base_entity: dict, adding_entity: dict) -> dict: 175 | """ 176 | merge text from two entities 177 | Args: 178 | base_entity: the entity that text starts 179 | adding_entity: the entity to be added to base 180 | Returns: an entity with combined text 181 | """ 182 | base_entity["text"] =
base_entity.get("text", "") + " " + adding_entity.get("text", "") 183 | return base_entity 184 | 185 | 186 | def merge_bbox(base_entity: dict, adding_entity: dict) -> dict: 187 | """ 188 | merge bbox from two entities 189 | Args: 190 | base_entity: the entity that bbox starts 191 | adding_entity: the entity to be added to base 192 | Returns: an entity with combined bbox 193 | """ 194 | base_entity["bbox"] = [ 195 | min(base_entity["bbox"][0], adding_entity["bbox"][0]), 196 | min(base_entity["bbox"][1], adding_entity["bbox"][1]), 197 | max(base_entity["bbox"][2], adding_entity["bbox"][2]), 198 | max(base_entity["bbox"][3], adding_entity["bbox"][3]) 199 | ] 200 | return base_entity 201 | 202 | 203 | def unnormalize_box(bbox, width, height): 204 | return [ 205 | width * (bbox[0] / 1000), 206 | height * (bbox[1] / 1000), 207 | width * (bbox[2] / 1000), 208 | height * (bbox[3] / 1000), 209 | ] 210 | 211 | 212 | def compute_metrics(p): 213 | pred_relations, gt_relations = p 214 | score = re_score(pred_relations, gt_relations, mode="boundaries") 215 | return score 216 | 217 | 218 | def post_process_entities(entity_list: list, threshold: int = 2) -> tuple: 219 | """ 220 | perform some cleaning and format changing task on entity_list 221 | Args: 222 | entity_list: all the entities after pairing step 223 | threshold: the largest vertical difference between two merging entities 224 | if diff > threshold, then no merge for these two entities. 225 | Returns: 226 | """ 227 | # relation diagram to illustrate multi-key and multi-val linking 228 | # Given key_0 and key_1 form a key, and its corresponding val is formed by val_0, val_1 and val_2, 229 | # their linking relation could be as below: 230 | # / val_0 / val_0 231 | # key_0- val_1 key_1- val_1 val_0 - key_1, val_1 - key_1, val_2 - key_1 232 | # \ val_2 \ val_2 233 | # caveat: for the linking in value entity, it only shows key_1 instead of key_1 and key_2 234 | 235 | # remove entities without any linking, and change key "box" to "bbox" 236 | unpaired = list() 237 | en_list_ori = copy.deepcopy(entity_list) 238 | for idx, entity in enumerate(en_list_ori): 239 | if "id" not in entity.keys(): 240 | unpaired.append(idx) 241 | continue 242 | 243 | # change id type to str 244 | entity["id"] = str(entity["id"]) 245 | 246 | # remove invalid entity 247 | if "linking" not in entity.keys() or len(entity["linking"]) == 0: 248 | unpaired.append(idx) 249 | continue 250 | 251 | entity_list = [i for idx, i in enumerate(en_list_ori) if idx not in unpaired] 252 | if len(unpaired) != 0: 253 | unpaired = [i for idx, i in enumerate(en_list_ori) if idx in unpaired] 254 | 255 | # all entities in the list should have "id" and "linking" keys afterwards 256 | merge_list = list() 257 | kv_indexes = [e["id"] for e in entity_list] 258 | 259 | def append_merge_set(merge_list: list, 260 | merge_set: set) -> list: 261 | """ 262 | append merge_set to merge_list 263 | Args: 264 | merge_list: merge_list 265 | merge_set: merge_set 266 | Returns: list 267 | """ 268 | if len(merge_list) == 0: 269 | merge_list.append(merge_set) 270 | return merge_list 271 | 272 | if merge_set in merge_list: 273 | return merge_list 274 | 275 | # there is overlaps between existing and merge_set 276 | for i in merge_list: 277 | if len(i.union(merge_set)) < len(i) + len(merge_set): 278 | i.update(merge_set) 279 | return merge_list 280 | 281 | merge_list.append(merge_set) 282 | return merge_list 283 | 284 | def get_counterpart(link: List, entity: Dict) -> str: 285 | # get the counterpart from a link 286 | if 
len(link) != 2: 287 | raise ValueError("link is not between 2 entities") 288 | 289 | id = entity["id"] 290 | for i in link: 291 | if str(i) != id: 292 | return str(i) 293 | 294 | # merge answers linked to the same question 295 | for entity in entity_list: 296 | # find the question entity which links multiple answers 297 | if entity.get("label", "") == "question" and len(entity["linking"]) > 1: 298 | merge_set = set() 299 | for link in entity["linking"]: 300 | cp = get_counterpart(link, entity) 301 | if cp in kv_indexes: 302 | merge_set.add(cp) 303 | 304 | if len(merge_set) >= 2: 305 | merge_list = append_merge_set(merge_list, merge_set) 306 | 307 | # merge questions with same linkings 308 | # for temp_dict, key: sorted list of counterparts; value: question id 309 | temp_dict = defaultdict(lambda: set()) 310 | for idx, entity in enumerate(entity_list): 311 | if entity.get("label", "") == "question": 312 | cps = sorted([get_counterpart(i, entity) for i in entity["linking"]]) 313 | temp_dict[tuple(cps)].add(entity["id"]) 314 | 315 | # add ids of merging keys to merge_list 316 | for q_ids in temp_dict.values(): 317 | if len(q_ids) > 1: 318 | merge_list = append_merge_set(merge_list, q_ids) 319 | 320 | if len(merge_list) > 0: 321 | for merge_set in merge_list: 322 | en_list = relate_to_entity(to_merge=merge_set, list_of_entities=entity_list) 323 | 324 | # filter the big gap in merge list 325 | en_list = sorted(en_list, key=lambda x: (x["bbox"][1], x["bbox"][0])) 326 | for idx in range(len(en_list) - 1): 327 | y_diff = en_list[idx + 1]["bbox"][1] - en_list[idx]["bbox"][3] 328 | char_height = abs(en_list[idx]["bbox"][3] - en_list[idx]["bbox"][1]) 329 | if y_diff > threshold * char_height: 330 | # stop merging at idx, **Noted that all entities are sorted by y0 331 | en_list = en_list[:(idx + 1)] 332 | break 333 | base = reduce(merge_two_entities, en_list) 334 | # add base and drop entities in merge_set 335 | entity_list = [e for e in entity_list if e.get('id', -1) not in merge_set] 336 | entity_list.append(base) 337 | 338 | # update valid ids 339 | kv_indexes = [e["id"] for e in entity_list] 340 | 341 | # change linking from list to single index 342 | for entity in entity_list: 343 | if len(entity["linking"]) == 1: 344 | entity["linking"] = get_counterpart(entity["linking"][0], entity) 345 | else: 346 | for link in entity["linking"]: 347 | index = get_counterpart(link, entity) 348 | if index in kv_indexes: 349 | entity["linking"] = index 350 | break 351 | 352 | keyvalue = [e for e in entity_list if isinstance(e["linking"], str)] 353 | return keyvalue, unpaired 354 | 355 | 356 | if task.name == "FINETUNING": 357 | 358 | logger = logging.getLogger(__name__) 359 | logger.setLevel(logging.INFO) 360 | logger.addHandler(logging.StreamHandler(stream=sys.stdout)) 361 | logger_path = ROOT / "RE_HuggingFace" / "artifacts" / f"experiment_{TIME}.log" 362 | fileHandler = logging.FileHandler(f"{logger_path.as_posix()}") 363 | logFormatter = logging.Formatter("[%(levelname)s] - %(message)s") 364 | fileHandler.setFormatter(logFormatter) 365 | logger.addHandler(fileHandler) 366 | 367 | # start a new wandb run to track this script 368 | wandb.init( 369 | # set the wandb project where this run will be logged 370 | project="RE", 371 | 372 | # track hyperparameters and run metadata 373 | config={ 374 | "dataset": "RE_FUNSD", 375 | } 376 | ) 377 | 378 | dataset = load_dataset(path=(ROOT / "RE_HuggingFace" / "download.py").as_posix(), name="en") 379 | 380 | # # check if any bbox value > 1000, use it only for debugging 
>1000 error 381 | # for i in range(len(dataset["train"]["bbox"])): 382 | # print(max(torch.tensor(dataset["train"]["bbox"][i])[:, 1])) 383 | 384 | model_card = "/home/kewen_yang/Information_Extraction/RE_HuggingFace/model/checkpoint_2023_08_01_15_26.pt" 385 | # model_card = "microsoft/layoutlm-base-uncased" 386 | 387 | model = LayoutLMForRelationExtraction.from_pretrained(model_card) 388 | tokenizer = LayoutLMTokenizer.from_pretrained(model_card) 389 | 390 | logger.info(f"finetuning model on top of {model_card}") 391 | logger.info(f"finetuning with dataset - RE_Finetune_1") 392 | data_collator = DataCollatorForKeyValueExtraction( 393 | tokenizer, 394 | pad_to_multiple_of=1, 395 | padding="max_length", 396 | max_length=512, 397 | ) 398 | 399 | train_dataset = dataset['train'] 400 | val_dataset = dataset['validation'] 401 | 402 | # Define TrainingArguments 403 | # See thread for hyperparameters: https://github.com/microsoft/unilm/issues/586 404 | training_args = TrainingArguments( 405 | output_dir=MODDEL_DIR, 406 | overwrite_output_dir=True, 407 | remove_unused_columns=False, 408 | # fp16=True, -> led to a loss of 0 409 | 410 | max_steps=20000, 411 | # max_steps = 10, 412 | evaluation_strategy="steps", 413 | 414 | # num_train_epochs=1, 415 | # evaluation_strategy="epoch", 416 | save_strategy="no", 417 | no_cuda=(DEVICE == "cpu"), 418 | per_device_train_batch_size=2, 419 | per_device_eval_batch_size=1, 420 | warmup_ratio=0.1, 421 | learning_rate=1e-5, 422 | push_to_hub=False, 423 | report_to="wandb" 424 | ) 425 | 426 | # Initialize our Trainer 427 | trainer = XfunReTrainer( 428 | model=model, 429 | args=training_args, 430 | train_dataset=train_dataset, 431 | eval_dataset=val_dataset, 432 | tokenizer=tokenizer, 433 | data_collator=data_collator, 434 | compute_metrics=compute_metrics, 435 | ) 436 | logger.info("start training model") 437 | train_metrics = trainer.train(resume_from_checkpoint=False) 438 | logger.info(f"training_metrics: {train_metrics}") 439 | 440 | logger.info("start evaluating performance") 441 | eval_metrics = trainer.evaluate() 442 | logger.info(f"evaluation metrics: {eval_metrics}") 443 | trainer.save_model(MODDEL_DIR) 444 | 445 | learning_curve = pd.DataFrame(trainer.state.log_history) 446 | logger.info('\n\t' + learning_curve.to_string().replace('\n', '\n\t')) 447 | 448 | 449 | elif task.name == "INFERENCE": 450 | """do inference by huggingface pipeline 451 | """ 452 | 453 | # test_image = train_dataset[48]['original_image'] 454 | test_image = Image.open(ROOT / 'dataset/test/AUTOVACSTORE-1-2-Bing-image_010.jpg') 455 | # plt.imshow(test_image) 456 | # plt.show() 457 | 458 | # load model + processor from the hub 459 | processor = AutoProcessor.from_pretrained(ROOT / "SER_HuggingFace" / "model") 460 | model = AutoModelForTokenClassification.from_pretrained(ROOT / "SER_HuggingFace" / "model") 461 | # prepare inputs for the model 462 | # we set `return_offsets_mapping=True` as we use the offsets to know which tokens are subwords and which aren't 463 | inputs = processor(test_image, return_offsets_mapping=True, padding="max_length", max_length=512, truncation=True, return_tensors="pt") 464 | 465 | original_text = processor.tokenizer.convert_tokens_to_string([processor.tokenizer.decode(i, skip_special_tokens=True) for i in inputs["input_ids"][0].tolist()]) 466 | # all_token_text = [processor.tokenizer.decode(i, skip_special_tokens=True) for i in inputs["input_ids"][0].tolist()] 467 | 468 | inputs = inputs.to(DEVICE) 469 | model.to(DEVICE) 470 | 471 | # offset_mapping: indicates 
the start and end index of the actual subword w.r.t each token text, e.g. '##omi' with offset [2, 5] -> 'omi' 472 | offset_mapping = inputs.pop("offset_mapping") 473 | 474 | # word_ids: indicates if the subtokens belong to the same word. 475 | word_ids = inputs.encodings[0].word_ids 476 | 477 | token_ids = inputs.input_ids[0].tolist() 478 | 479 | if_special = inputs.encodings[0].special_tokens_mask 480 | 481 | # forward pass 482 | with torch.no_grad(): 483 | outputs = model(**inputs) 484 | 485 | # take argmax on last dimension to get predicted class ID per token 486 | predictions = outputs.logits.argmax(-1).squeeze().tolist() 487 | 488 | # # check if it's subwords 489 | # is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0 490 | 491 | # merge subwords into word-level based on word_ids 492 | word_pred = defaultdict(lambda: -1) 493 | words = defaultdict(list) 494 | for idx, tp in enumerate(zip(if_special, token_ids, predictions, word_ids)): 495 | if idx == 0 or bool(tp[0]): 496 | continue 497 | 498 | words[tp[-1]].append(idx) 499 | if word_pred[tp[-1]] == -1: 500 | word_pred[tp[-1]] = tp[2] 501 | 502 | id2label = {"QUESTION": 1, "ANSWER": 2} 503 | 504 | # finally, store recognized "question" and "answer" entities in a list 505 | entities = [] 506 | current_entity = None 507 | start = None 508 | end = None 509 | 510 | for idx, (id, pred) in enumerate(zip(words.values(), word_pred.values())): 511 | predicted_label = model.config.id2label[pred] 512 | if predicted_label == "O": 513 | continue 514 | 515 | if predicted_label.startswith("B") and current_entity is None: 516 | # means we're at the start of a new entity 517 | current_entity = predicted_label.replace("B-", "") 518 | start = min(id) 519 | print(f"--------------New entity: at index {start}", current_entity) 520 | 521 | if current_entity is not None and current_entity not in predicted_label: 522 | # means we're at the end of a new entity 523 | end = max(words[idx - 1]) 524 | print("---------------End of new entity") 525 | entities.append((start, end, current_entity, id2label[current_entity])) 526 | current_entity = None 527 | 528 | if predicted_label.startswith("B") and current_entity is None: 529 | # means we're at the start of a new entity 530 | current_entity = predicted_label.replace("B-", "") 531 | start = min(id) 532 | print(f"--------------New entity: at index {start}", current_entity) 533 | 534 | # step 2: run LayoutLMForRelationExtraction 535 | entity_dict = {'start': [entity[0] for entity in entities], 536 | 'end': [entity[1] for entity in entities], 537 | 'label': [entity[3] for entity in entities]} 538 | 539 | relation_extraction_model = LayoutLMForRelationExtraction.from_pretrained("/home/kewen_yang/Information_Extraction/RE_HuggingFace/model/checkpoint_2023_07_11_15_49.pt/checkpoint-5000") 540 | # relation_extraction_model = LayoutLMForRelationExtraction.from_pretrained("nielsr/layoutxlm-finetuned-xfund-fr-re") 541 | relation_extraction_model.to(DEVICE) 542 | 543 | with torch.no_grad(): 544 | # inputs: {'input_ids', 'token_type_ids', 'attention_mask', 'bbox', 'image'} 545 | outputs = relation_extraction_model(**inputs, 546 | entities=[entity_dict], 547 | relations=[{'start_index': [], 'end_index': [], 'head': [], 'tail': []}]) 548 | 549 | # show predicted key-values 550 | for relation in outputs.pred_relations[0]: 551 | head_start, head_end = relation['head'] 552 | tail_start, tail_end = relation['tail'] 553 | print("Question:", processor.decode(inputs.input_ids[0][head_start:head_end])) 554 | 
print("Answer:", processor.decode(inputs.input_ids[0][tail_start:tail_end])) 555 | print("----------") 556 | 557 | elif task.name == "XTRACT_INFER": 558 | """do inference by Xtract customized pipeline 559 | """ 560 | print("initiating inferencing ...") 561 | model_dir = "/home/kewen_yang/Information_Extraction/RE_HuggingFace/model/checkpoint_2023_08_01_15_26.pt" 562 | relation_extraction_model = LayoutLMForRelationExtraction.from_pretrained(model_dir) 563 | relation_extraction_model.to(DEVICE) 564 | tokenizer = LayoutLMTokenizer.from_pretrained(model_dir) 565 | 566 | for file_name in os.listdir("/home/kewen_yang/Information_Extraction/dataset/Xtract_json_Batch_3"): 567 | 568 | # file_name = "10.json" 569 | 570 | with open(ROOT / "dataset/Xtract_json_Batch_3" / file_name, "rb") as f: 571 | file = json.load(f) 572 | 573 | entity_dict = file["entity_dict"] 574 | 575 | # entity_dict = {k: v[8:10] for k, v in entity_dict.items()} 576 | 577 | inputs = file["input"] 578 | 579 | # del inputs["image"] # image is only applicable to v2 model now 580 | 581 | print("---------------------------------------------------------------------------------------------") 582 | print("key-values before feeding to RE model:") 583 | print(f'questions: {[tokenizer.decode([i for i in inputs["input_ids"][0][s:e]]) for s, e, l in zip(entity_dict["start"], entity_dict["end"], entity_dict["label"]) if l == 1]}') 584 | print(f'answers: {[tokenizer.decode([i for i in inputs["input_ids"][0][s:e]]) for s, e, l in zip(entity_dict["start"], entity_dict["end"], entity_dict["label"]) if l == 2]}') 585 | print("---------------------------------------------------------------------------------------------") 586 | 587 | for k, v in inputs.items(): 588 | inputs[k] = torch.tensor(inputs[k]) 589 | 590 | inputs = {k: v.to(DEVICE) for k, v in inputs.items()} 591 | 592 | if len(entity_dict["label"]) == 0: 593 | pred_rel = [] 594 | 595 | else: 596 | with torch.no_grad(): 597 | # inputs: {'input_ids', 'token_type_ids', 'attention_mask', 'bbox'} 598 | outputs = relation_extraction_model(**inputs, 599 | entities=[entity_dict], 600 | relations=[{'start_index': [], 'end_index': [], 'head': [], 'tail': []}]) 601 | pred_rel = outputs.pred_relations[0] 602 | 603 | res = defaultdict(list) 604 | print("---------------------------------------------------------------------------------------------") 605 | print("key-value pairs by RE model:") 606 | for relation in pred_rel: 607 | head_start, head_end = relation['head'] 608 | tail_start, tail_end = relation['tail'] 609 | key_d = {} 610 | key_d["id"] = relation["head_id"] 611 | key_d["text"] = tokenizer.decode(inputs['input_ids'][0][head_start:head_end]) 612 | key_d["label"] = "question" 613 | key_d["bbox"] = torch.cat((inputs['bbox'][0][head_start:head_end].min(0).values[:2], inputs['bbox'][0][head_start:head_end].max(0).values[2:]), 0).tolist() 614 | key_d["linking"] = [[relation['head_id'], relation['tail_id']]] 615 | key_d["row_idx"] = entity_dict["row_idx"][entity_dict["start"].index(head_start)] 616 | res[relation["head_id"]].append(key_d) 617 | print(f"Question: {key_d['text']}") 618 | 619 | val_d = {} 620 | val_d["id"] = relation["tail_id"] 621 | val_d["text"] = tokenizer.decode(inputs['input_ids'][0][tail_start:tail_end]) 622 | val_d["label"] = "answer" 623 | val_d["bbox"] = torch.cat((inputs['bbox'][0][tail_start:tail_end].min(0).values[:2], inputs['bbox'][0][tail_start:tail_end].max(0).values[2:]), 0).tolist() 624 | val_d["linking"] = [[relation['head_id'], relation['tail_id']]] 625 | 
val_d["row_idx"] = entity_dict["row_idx"][entity_dict["start"].index(tail_start)] 626 | res[relation["tail_id"]].append(val_d) 627 | print(f"Answer:, {val_d['text']}") 628 | print("-------------------------") 629 | 630 | # remove duplicates 631 | def concat_links(lst): 632 | """concantenate all the links for an entity 633 | 634 | Args: 635 | lst: list of entities with same id, to be concatenated 636 | 637 | Returns: 638 | Dict: entity 639 | """ 640 | if len(lst) == 1: 641 | return lst[0] 642 | 643 | out = lst.pop(0) 644 | for e in lst: 645 | out["linking"].append(copy.deepcopy(e["linking"][0])) 646 | 647 | return out 648 | 649 | def choose_qentity(qns, ans): 650 | if len(qns) == 1: 651 | return None 652 | 653 | qns = sorted(qns, key=lambda q: (abs(q["row_idx"] - ans["row_idx"]), abs(ans["bbox"][0] - q["bbox"][2]))) 654 | 655 | return qns[0] 656 | 657 | res = {k: concat_links(v) for k, v in res.items()} 658 | 659 | # remove redundant links based on distance 660 | for i in res.keys(): 661 | if res[i]['label'] == "question": 662 | continue 663 | 664 | all_qns = [res[l[0]] for l in res[i]["linking"]] 665 | qns = choose_qentity(all_qns, res[i]) 666 | if qns is not None: 667 | res[i]["linking"] = [[int(qns["id"]), int(res[i]["id"])]] 668 | res[qns["id"]]["linking"] = [[int(qns["id"]), int(res[i]["id"])]] 669 | 670 | for i in res.keys(): 671 | if res[i]["label"] == "question" and len(res[i]["linking"]) > 1: 672 | 673 | drops = list() 674 | for l in res[i]["linking"]: 675 | if res[l[1]]["linking"][0][0] != i: 676 | drops.append(l) 677 | 678 | if drops: 679 | for l in drops: 680 | res[i]["linking"].remove(l) 681 | 682 | res = list(res.values()) 683 | 684 | keyvalue, unpaired = post_process_entities(res) 685 | 686 | out = {"keyvalue": keyvalue, "unpaired": unpaired} 687 | 688 | # print key value text 689 | printable = {e["id"]: e for e in out["keyvalue"]} 690 | print("---------------------------------------------------------------------------------------------") 691 | print("key-value pairs after post-processing:") 692 | for e in printable.values(): 693 | if e["label"] == "question": 694 | print(f'{e["text"]}:{printable[e["linking"]]["text"]}') 695 | 696 | with open(ROOT / "dataset/RE_pred_Batch_3" / file_name, 'w') as f: 697 | json.dump(out, f, indent=2) 698 | print("---------------------------------------------------------------------------------------------") 699 | 700 | ################################################ Appendix ################################################ 701 | """ 702 | Input Schema: 703 | 704 | { 705 | "id": image name, e.g. "en_train_0_0" 706 | "input_ids": List[token_ids], all_texts = "".join([tokenizer.decode(i) for i in dataset["train"][0]["input_ids"]]) 707 | "bbox": tensors, 708 | "labels": tensors, 709 | "hd_image": 710 | "entities": 711 | "relations": 712 | "attention_mask": 713 | "image": 714 | } 715 | 716 | for English, tokens usually equals to words, but when word is out of vacab, it will be tokenized into multiple common tokens. e.g. "vashering" --> ['va', '##sher', '##ing'] 717 | 718 | 719 | >>>>input feature: 720 | {"input_ids": [], "bbox": []} 721 | input_ids: the token ids 722 | bbox: bbox 723 | e.g. 724 | # show all the tokens 725 | [self.tokenizer.decode(i) for i in feature["input_ids"]] 726 | >> ['2', 'w', '##m', '##f', 'w', '##m', '##f', 'consumer', 'electric', 'gmbh', 'mess', '##ers', '##ch', '##mit', ...] 
700 | ################################################ Appendix ################################################
701 | """
702 | Input Schema:
703 | 
704 | {
705 |     "id": image name, e.g. "en_train_0_0"
706 |     "input_ids": List[token_ids], all_texts = "".join([tokenizer.decode(i) for i in dataset["train"][0]["input_ids"]])
707 |     "bbox": tensors,
708 |     "labels": tensors,
709 |     "hd_image":
710 |     "entities":
711 |     "relations":
712 |     "attention_mask":
713 |     "image":
714 | }
715 | 
716 | for English, tokens usually correspond to words, but when a word is out of vocabulary it is tokenized into multiple sub-word tokens, e.g. "vashering" --> ['va', '##sher', '##ing']
717 | 
718 | 
719 | >>>>input feature:
720 | {"input_ids": [], "bbox": []}
721 |     input_ids: the token ids
722 |     bbox: bbox
723 | e.g.
724 | # show all the tokens
725 | [self.tokenizer.decode(i) for i in feature["input_ids"]]
726 | >> ['2', 'w', '##m', '##f', 'w', '##m', '##f', 'consumer', 'electric', 'gmbh', 'mess', '##ers', '##ch', '##mit', ...]
727 | 
728 | # get the original texts
729 | self.tokenizer.convert_tokens_to_string([self.tokenizer.decode(i) for i in feature["input_ids"]])
730 | >> '2 wmf wmf consumer electric gmbh messerschmittstrabe : d - 89343 jettingen - scheppach wmf kult x mono induction hob art. nr. : 04 1524 8811 2x art. nr. : 04 _ 1524 _ 8811 ean : 421112 9145688 cmmf : 3200001681'
731 | 
732 | 
733 | >>>>entities:
734 | {"start": [], "end": [], "label": []}
735 | 
736 |     "start": starting index of input_ids for this entity
737 |     "end": ending index of input_ids for this entity
738 |     "label": label
739 | 
740 | e.g.
741 | ens = [self.tokenizer.convert_tokens_to_string([self.tokenizer.decode(i) for i in feature["input_ids"][s:e]]) for s, e in zip(entities[1]["start"], entities[1]["end"])]
742 | >> ['art. nr. : 04', '04 1524 8811 2', 'art. nr. : 04', '04 _ 1524 _ 8811 ea', 'ean : 421', '421112 9145688 cm', 'cmmf : 320', '3200001681']
743 | 
744 | 
745 | >>>>relations:
746 | {"head": [], "tail": [], "start_index": [], "end_index": []}
747 | 
748 |     "start_index": starting index of input_ids for this Question-Answer pair
749 |     "end_index": ending index of input_ids for this Question-Answer pair
750 | 
751 | e.g.
752 | [self.tokenizer.convert_tokens_to_string([self.tokenizer.decode(i) for i in feature["input_ids"][s:e]]) for s, e in zip(relations[1]["start_index"], relations[1]["end_index"])]
753 | >> ['art. nr. : 04 1524 8811', 'art. nr. : 04 _ 1524 _ 8811', 'ean : 421112 9145688', 'cmmf : 3200001681']
754 | 
755 | "head": question index w.r.t. the index of "entities"
756 | "tail": answer index
757 | 
758 | e.g.
759 | [(ens[h], ens[t]) for h, t in zip(relations[1]["head"], relations[1]["tail"])]
760 | >> [('art. nr. :', '04 1524 8811'), ('art. nr. :', '04 _ 1524 _ 8811'), ('ean :', '421112 9145688'), ('cmmf :', '3200001681')]
761 | """
--------------------------------------------------------------------------------
/RE_HuggingFace/layoutlmv2_re.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import enum
3 | from functools import reduce
4 | import json
5 | import logging
6 | from math import sqrt
7 | import sys
8 | from collections import defaultdict
9 | from dataclasses import dataclass
10 | from datetime import datetime
11 | from pathlib import Path
12 | from typing import Dict, List, Optional, Union
13 | from PIL import Image, JpegImagePlugin
14 | from torch.utils.data import DataLoader
15 | from transformers import LayoutLMTokenizer
16 | import pandas as pd
17 | from codecarbon import track_emissions
18 | from codecarbon import EmissionsTracker
19 | import wandb
20 | JpegImagePlugin._getmp = lambda: None  # workaround: skip MPO metadata parsing, which crashes PIL on some JPEGs
21 | import matplotlib
22 | import tornado
23 | 
24 | matplotlib.use('WebAgg')  # enable only if using the DGX machine to plot visuals
25 | import matplotlib.pyplot as plt
26 | import numpy as np
27 | import torch
28 | 
29 | from datasets import load_dataset
30 | 
31 | 
32 | from pytz import timezone
33 | from transformers import (AutoModelForTokenClassification, AutoProcessor,
34 |                           LayoutLMv2Tokenizer,
35 |                           LayoutLMv2FeatureExtractor,
36 |                           LayoutLMv2ForRelationExtraction,
37 |                           LayoutLMv2ForTokenClassification,
38 |                           LayoutLMv2Processor, PreTrainedTokenizerBase,
39 |                           TrainingArguments)
40 | from transformers.file_utils import PaddingStrategy
41 | 
42 | from unilm.layoutlmft.layoutlmft.evaluation import re_score
43 | from unilm.layoutlmft.layoutlmft.trainers import XfunReTrainer
44 | 
45 | torch.backends.cudnn.benchmark = False
46 | 
47 | ROOT = Path(__file__).parents[1]
48 | 
49 | TZ = timezone('Asia/Singapore')
50 | CURRENT = datetime.now(tz=TZ)
51 | TIME = CURRENT.strftime("%Y_%m_%d_%H_%M")
52 | MODEL_DIR = ROOT / "RE_HuggingFace" / f"model/checkpoint_{TIME}.pt"
53 | DEVICE = "cuda"
54 | # DEVICE = "cpu"
55 | 
56 | 
57 | @enum.unique
58 | class Task(enum.Enum):
59 |     FINETUNING = 0
60 |     INFERENCE = 1
61 |     XTRACT_INFER = 2
62 | 
63 | 
64 | task = Task.XTRACT_INFER
65 | 
66 | 
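"""
Usage note (added for clarity): exactly one of the three branches below runs,
selected by the `task` constant above; for example, set

    task = Task.FINETUNING

to train instead of running the customized Xtract inference. Fine-tuned
checkpoints are written to MODEL_DIR, whose name embeds the current timestamp.
"""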
67 | @dataclass
68 | class DataCollatorForKeyValueExtraction:
69 |     """
70 |     Data collator that will dynamically pad the inputs received, as well as the labels.
71 | 
72 |     Args:
73 |         tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
74 |             The tokenizer used for encoding the data.
75 |         padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):
76 |             Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
77 |             among:
78 |             * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
79 |               sequence is provided).
80 |             * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
81 |               maximum acceptable input length for the model if that argument is not provided.
82 |             * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
83 |               different lengths).
84 |         max_length (:obj:`int`, `optional`):
85 |             Maximum length of the returned list and optionally padding length (see above).
86 |         pad_to_multiple_of (:obj:`int`, `optional`):
87 |             If set will pad the sequence to a multiple of the provided value.
88 |             This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
89 |             7.5 (Volta).
90 |         label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
91 |             The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
92 |     """
93 |     feature_extractor: LayoutLMv2FeatureExtractor
94 |     tokenizer: PreTrainedTokenizerBase
95 |     padding: Union[bool, str, PaddingStrategy] = True
96 |     max_length: Optional[int] = None
97 |     pad_to_multiple_of: Optional[int] = None
98 |     label_pad_token_id: int = -100
99 | 
100 |     def __call__(self, features):
101 |         # prepare image input
102 |         image = self.feature_extractor([feature["original_image"] for feature in features], return_tensors="pt").pixel_values
103 | 
104 |         # prepare text input
105 |         entities = []
106 |         relations = []
107 |         for feature in features:
108 |             del feature["image"]
109 |             del feature["id"]
110 |             del feature["labels"]
111 |             del feature["original_image"]
112 |             entities.append(feature["entities"])
113 |             del feature["entities"]
114 |             relations.append(feature["relations"])
115 |             del feature["relations"]
116 | 
117 |         # set a breakpoint here and check the data schema against the Appendix
118 |         batch = self.tokenizer.pad(
119 |             features,
120 |             padding=self.padding,
121 |             max_length=self.max_length,
122 |             pad_to_multiple_of=self.pad_to_multiple_of,
123 |             return_tensors="pt"
124 |         )
125 | 
126 |         batch["image"] = image
127 |         batch["entities"] = entities
128 |         batch["relations"] = relations
129 | 
130 |         return batch
131 | 
132 | 
133 | def relate_to_entity(to_merge: set, list_of_entities: list) -> list:
134 |     """
135 |     replace id with actual entity
136 |     Args:
137 |         to_merge: a set of all merging ids
138 |         list_of_entities: list of entities
139 |     Returns: list of entities to be merged, in no particular order
140 |     """
141 |     out = list()
142 |     for e in list_of_entities:
143 |         if e.get("id", "") in to_merge:
144 |             out.append(e)
145 |     return out
146 | 
147 | 
148 | def merge_two_entities(e_0: dict, e_1: dict) -> dict:
149 |     """
150 |     merge two entities
151 |     Args:
152 |         e_0: entity
153 |         e_1: entity
154 |     Returns: the combined entity
155 |     """
156 |     if "bbox" not in e_0.keys():
157 |         return e_1
158 |     if "bbox" not in e_1.keys():
159 |         return e_0
160 | 
161 |     # decide the base entity
162 |     # compare y0: the entity with the smaller y0 becomes the base
163 |     if e_0["bbox"][1] == e_1["bbox"][1]:
164 |         if e_0["bbox"][0] <= e_1["bbox"][0]:
165 |             base_entity = e_0
166 |             adding_entity = e_1
167 |         else:
168 |             base_entity = e_1
169 |             adding_entity = e_0
170 |     elif e_0["bbox"][1] < e_1["bbox"][1]:
171 |         base_entity = e_0
172 |         adding_entity = e_1
173 |     else:
174 |         base_entity = e_1
175 |         adding_entity = e_0
176 | 
177 |     base_entity = merge_text(base_entity, adding_entity)
178 |     base_entity = merge_bbox(base_entity, adding_entity)
179 |     return base_entity
180 | 
181 | 
182 | def merge_text(base_entity: dict, adding_entity: dict) -> dict:
183 |     """
184 |     merge text from two entities
185 |     Args:
186 |         base_entity: the entity where the text starts
187 |         adding_entity: the entity to be added to the base
188 |     Returns: an entity with combined text
189 |     """
190 |     base_entity["text"] = base_entity.get("text", "") + " " + adding_entity.get("text", "")
191 |     return base_entity
192 | 
193 | 
194 | def merge_bbox(base_entity: dict, adding_entity: dict) -> dict:
195 |     """
196 |     merge bbox from two entities
197 |     Args:
198 |         base_entity: the entity where the bbox starts
199 |         adding_entity: the entity to be added to the base
200 |     Returns: an entity with combined bbox
201 |     """
202 |     base_entity["bbox"] = [
203 |         min(base_entity["bbox"][0], adding_entity["bbox"][0]),
204 |         min(base_entity["bbox"][1], adding_entity["bbox"][1]),
205 |         max(base_entity["bbox"][2], adding_entity["bbox"][2]),
206 |         max(base_entity["bbox"][3], adding_entity["bbox"][3])
207 |     ]
208 |     return base_entity
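"""
Illustrative example for the merge helpers above (hypothetical entities): the
entity that starts higher on the page (smaller y0, ties broken by x0) becomes
the base; texts are concatenated and bboxes unioned.

    e_0 = {"text": "Invoice", "bbox": [10, 10, 60, 20]}
    e_1 = {"text": "No.", "bbox": [10, 25, 40, 35]}
    merge_two_entities(e_0, e_1)
    # -> {"text": "Invoice No.", "bbox": [10, 10, 60, 35]}
"""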
209 | 
210 | 
211 | def unnormalize_box(bbox, width, height):
212 |     return [
213 |         width * (bbox[0] / 1000),
214 |         height * (bbox[1] / 1000),
215 |         width * (bbox[2] / 1000),
216 |         height * (bbox[3] / 1000),
217 |     ]
218 | 
219 | 
220 | def compute_metrics(p):
221 |     pred_relations, gt_relations = p
222 |     score = re_score(pred_relations, gt_relations, mode="boundaries")
223 |     return score
224 | 
225 | 
226 | def post_process_entities(entity_list: list, threshold: int = 2) -> tuple:
227 |     """
228 |     perform some cleaning and format-changing tasks on entity_list
229 |     Args:
230 |         entity_list: all the entities after the pairing step
231 |         threshold: the largest vertical difference between two merging entities;
232 |             if diff > threshold, then the two entities are not merged.
233 |     Returns: tuple of (keyvalue, unpaired) entity lists
234 |     """
235 |     # relation diagram to illustrate multi-key and multi-val linking
236 |     # Given key_0 and key_1 form a key, and its corresponding val is formed by val_0, val_1 and val_2,
237 |     # their linking relation could be as below:
238 |     #         / val_0         / val_0
239 |     #  key_0 -  val_1   key_1 -  val_1     val_0 - key_1, val_1 - key_1, val_2 - key_1
240 |     #         \ val_2         \ val_2
241 |     # caveat: for the linking in a value entity, it only shows key_1 instead of key_0 and key_1
242 | 
243 |     # remove entities without any linking, and change key "box" to "bbox"
244 |     unpaired = list()
245 |     en_list_ori = copy.deepcopy(entity_list)
246 |     for idx, entity in enumerate(en_list_ori):
247 |         if "id" not in entity.keys():
248 |             unpaired.append(idx)
249 |             continue
250 | 
251 |         # change id type to str
252 |         entity["id"] = str(entity["id"])
253 | 
254 |         # remove invalid entity
255 |         if "linking" not in entity.keys() or len(entity["linking"]) == 0:
256 |             unpaired.append(idx)
257 |             continue
258 | 
259 |     entity_list = [i for idx, i in enumerate(en_list_ori) if idx not in unpaired]
260 |     if len(unpaired) != 0:
261 |         unpaired = [i for idx, i in enumerate(en_list_ori) if idx in unpaired]
262 | 
263 |     # all entities in the list should have "id" and "linking" keys afterwards
264 |     merge_list = list()
265 |     kv_indexes = [e["id"] for e in entity_list]
266 | 
267 |     def append_merge_set(merge_list: list,
268 |                          merge_set: set) -> list:
269 |         """
270 |         append merge_set to merge_list
271 |         Args:
272 |             merge_list: merge_list
273 |             merge_set: merge_set
274 |         Returns: list
275 |         """
276 |         if len(merge_list) == 0:
277 |             merge_list.append(merge_set)
278 |             return merge_list
279 | 
280 |         if merge_set in merge_list:
281 |             return merge_list
282 | 
283 |         # there are overlaps between an existing set and merge_set
284 |         for i in merge_list:
285 |             if len(i.union(merge_set)) < len(i) + len(merge_set):
286 |                 i.update(merge_set)
287 |                 return merge_list
288 | 
289 |         merge_list.append(merge_set)
290 |         return merge_list
291 | 
292 |     def get_counterpart(link: List, entity: Dict) -> str:
293 |         # get the counterpart from a link
294 |         if len(link) != 2:
295 |             raise ValueError("link is not between 2 entities")
296 | 
297 |         id = entity["id"]
298 |         for i in link:
299 |             if str(i) != id:
300 |                 return str(i)
301 | 
302 |     # merge answers linked to the same question
303 |     for entity in entity_list:
304 |         # find the question entity which links multiple answers
305 |         if entity.get("label", "") == "question" and len(entity["linking"]) > 1:
306 |             merge_set = set()
307 |             for link in entity["linking"]:
308 |                 cp = get_counterpart(link, entity)
309 |                 if cp in kv_indexes:
310 |                     merge_set.add(cp)
311 | 
312 |             if len(merge_set) >= 2:
313 |                 merge_list = append_merge_set(merge_list, merge_set)
314 | 
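    """
    Sketch of append_merge_set's behaviour (hypothetical ids): a new set is
    unioned into an existing one whenever they share an id, otherwise it is
    appended as its own group.

        merge_list = append_merge_set([], {"1", "2"})          # [{"1", "2"}]
        merge_list = append_merge_set(merge_list, {"2", "3"})  # [{"1", "2", "3"}]
        merge_list = append_merge_set(merge_list, {"7", "8"})  # [{"1", "2", "3"}, {"7", "8"}]
    """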
315 |     # merge questions that share the same linking
316 |     # for temp_dict, key: sorted tuple of counterparts; value: set of question ids
317 |     temp_dict = defaultdict(lambda: set())
318 |     for idx, entity in enumerate(entity_list):
319 |         if entity.get("label", "") == "question":
320 |             cps = sorted([get_counterpart(i, entity) for i in entity["linking"]])
321 |             temp_dict[tuple(cps)].add(entity["id"])
322 | 
323 |     # add ids of merging keys to merge_list
324 |     for q_ids in temp_dict.values():
325 |         if len(q_ids) > 1:
326 |             merge_list = append_merge_set(merge_list, q_ids)
327 | 
328 |     if len(merge_list) > 0:
329 |         for merge_set in merge_list:
330 |             en_list = relate_to_entity(to_merge=merge_set, list_of_entities=entity_list)
331 | 
332 |             # filter out big vertical gaps in the merge list
333 |             en_list = sorted(en_list, key=lambda x: (x["bbox"][1], x["bbox"][0]))
334 |             for idx in range(len(en_list) - 1):
335 |                 y_diff = en_list[idx + 1]["bbox"][1] - en_list[idx]["bbox"][3]
336 |                 char_height = abs(en_list[idx]["bbox"][3] - en_list[idx]["bbox"][1])
337 |                 if y_diff > threshold * char_height:
338 |                     # stop merging at idx; note that all entities are sorted by y0
339 |                     en_list = en_list[:(idx + 1)]
340 |                     break
341 |             base = reduce(merge_two_entities, en_list)
342 |             # add base and drop entities in merge_set
343 |             entity_list = [e for e in entity_list if e.get('id', -1) not in merge_set]
344 |             entity_list.append(base)
345 | 
346 |     # update valid ids
347 |     kv_indexes = [e["id"] for e in entity_list]
348 | 
349 |     # change linking from list to single index
350 |     for entity in entity_list:
351 |         if len(entity["linking"]) == 1:
352 |             entity["linking"] = get_counterpart(entity["linking"][0], entity)
353 |         else:
354 |             for link in entity["linking"]:
355 |                 index = get_counterpart(link, entity)
356 |                 if index in kv_indexes:
357 |                     entity["linking"] = index
358 |                     break
359 | 
360 |     keyvalue = [e for e in entity_list if isinstance(e["linking"], str)]
361 |     return keyvalue, unpaired
362 | 
363 | 
364 | if task.name == "FINETUNING":
365 | 
366 |     logger = logging.getLogger(__name__)
367 |     logger.setLevel(logging.INFO)
368 |     logger.addHandler(logging.StreamHandler(stream=sys.stdout))
369 |     logger_path = ROOT / "RE_HuggingFace" / "artifacts" / f"experiment_{TIME}.log"
370 |     fileHandler = logging.FileHandler(f"{logger_path.as_posix()}")
371 |     logFormatter = logging.Formatter("[%(levelname)s] - %(message)s")
372 |     fileHandler.setFormatter(logFormatter)
373 |     logger.addHandler(fileHandler)
374 | 
375 |     # start a new wandb run to track this script
376 |     wandb.init(
377 |         # set the wandb project where this run will be logged
378 |         project="RE",
379 | 
380 |         # track hyperparameters and run metadata
381 |         config={
382 |             "dataset": "1st Batch",
383 |         }
384 |     )
385 | 
386 |     dataset = load_dataset(path=(ROOT / "RE_HuggingFace" / "download.py").as_posix(), name="en")
387 | 
388 |     # check if any bbox value > 1000; use only when debugging the >1000 error
389 |     # for i in range(len(dataset["train"]["bbox"])):
390 |     #     print(max(torch.tensor(dataset["train"]["bbox"][i])[:, 1]))
391 | 
392 |     # model_card = "/home/kewen_yang/Information_Extraction/RE_HuggingFace/model/checkpoint_2023_07_13_17_37.pt"
393 |     model_card = "microsoft/layoutlmv2-base-uncased"
394 | 
395 |     model = LayoutLMv2ForRelationExtraction.from_pretrained(model_card)
396 |     tokenizer = LayoutLMv2Tokenizer.from_pretrained(model_card)
397 | 
398 |     logger.info(f"finetuning model on top of {model_card}")
399 | 
400 |     feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
401 | 
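    """
    For orientation (illustrative; shapes assume the collator settings used
    below, padding="max_length" with max_length=512): one collated batch looks
    roughly like

        batch["input_ids"]       # [batch_size, 512]
        batch["attention_mask"]  # [batch_size, 512]
        batch["bbox"]            # [batch_size, 512, 4], coordinates normalized to 0-1000
        batch["image"]           # [batch_size, 3, 224, 224] pixel values
        batch["entities"]        # list of {"start", "end", "label"} dicts, one per document
        batch["relations"]       # list of {"head", "tail", "start_index", "end_index"} dicts

    entities and relations stay as plain Python lists because they are ragged
    across documents.
    """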
402 |     data_collator = DataCollatorForKeyValueExtraction(
403 |         feature_extractor,
404 |         tokenizer,
405 |         pad_to_multiple_of=1,
406 |         padding="max_length",
407 |         max_length=512,
408 |     )
409 | 
410 |     train_dataset = dataset['train']
411 |     val_dataset = dataset['validation']
412 | 
413 |     # Define TrainingArguments
414 |     # See this thread for hyperparameters: https://github.com/microsoft/unilm/issues/586
415 |     training_args = TrainingArguments(
416 |         output_dir=MODEL_DIR,
417 |         overwrite_output_dir=True,
418 |         remove_unused_columns=False,
419 |         # fp16=True,  # disabled: led to a loss of 0
420 | 
421 |         # max_steps=1000 + 5000 + 5000 + 5000,
422 |         max_steps=10,
423 |         evaluation_strategy="steps",
424 | 
425 |         # num_train_epochs=1,
426 |         # evaluation_strategy="epoch",
427 | 
428 |         no_cuda=(DEVICE == "cpu"),
429 |         per_device_train_batch_size=2,
430 |         per_device_eval_batch_size=1,
431 |         warmup_ratio=0.1,
432 |         learning_rate=1e-5,
433 |         push_to_hub=False,
434 |         report_to="wandb"
435 |     )
436 | 
437 |     # Initialize our Trainer
438 |     trainer = XfunReTrainer(
439 |         model=model,
440 |         args=training_args,
441 |         train_dataset=train_dataset,
442 |         eval_dataset=val_dataset,
443 |         tokenizer=tokenizer,
444 |         data_collator=data_collator,
445 |         compute_metrics=compute_metrics,
446 |     )
447 |     logger.info("start training model")
448 |     train_metrics = trainer.train(resume_from_checkpoint=False)
449 |     logger.info(f"training_metrics: {train_metrics}")
450 | 
451 |     logger.info("start evaluating performance")
452 |     eval_metrics = trainer.evaluate()
453 |     logger.info(f"evaluation metrics: {eval_metrics}")
454 |     trainer.save_model(MODEL_DIR)
455 | 
456 |     learning_curve = pd.DataFrame(trainer.state.log_history)
457 |     logger.info('\n\t' + learning_curve.to_string().replace('\n', '\n\t'))
458 | 
459 | 
460 | elif task.name == "INFERENCE":
461 |     """run inference with the HuggingFace pipeline
462 |     """
463 | 
464 |     # test_image = train_dataset[48]['original_image']
465 |     test_image = Image.open(ROOT / 'dataset/test/AUTOVACSTORE-1-2-Bing-image_010.jpg')
466 |     # plt.imshow(test_image)
467 |     # plt.show()
468 | 
469 |     # load model + processor from the hub
470 |     processor = AutoProcessor.from_pretrained(ROOT / "SER_HuggingFace" / "model")
471 |     model = AutoModelForTokenClassification.from_pretrained(ROOT / "SER_HuggingFace" / "model")
472 |     # prepare inputs for the model
473 |     # we set `return_offsets_mapping=True` as we use the offsets to know which tokens are subwords and which aren't
474 |     inputs = processor(test_image, return_offsets_mapping=True, padding="max_length", max_length=512, truncation=True, return_tensors="pt")
475 | 
476 |     original_text = processor.tokenizer.convert_tokens_to_string([processor.tokenizer.decode(i, skip_special_tokens=True) for i in inputs["input_ids"][0].tolist()])
477 |     # all_token_text = [processor.tokenizer.decode(i, skip_special_tokens=True) for i in inputs["input_ids"][0].tolist()]
478 | 
479 |     inputs = inputs.to(DEVICE)
480 |     model.to(DEVICE)
481 | 
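    """
    Illustration of the two alignment aids used below (tokenization borrowed
    from the Appendix; offsets are hypothetical): for an out-of-vocabulary word
    such as "vashering" --> ['va', '##sher', '##ing'],

        word_ids       -> [0, 0, 0]                 # all three subtokens map to word 0
        offset_mapping -> [(0, 2), (2, 6), (6, 9)]  # character spans within the word

    word_ids is what the merging loop below uses to group subtokens back into
    words before decoding entities.
    """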
482 |     # offset_mapping: indicates the start and end index of the actual subword w.r.t. each token text, e.g. '##omi' with offset [2, 5] -> 'omi'
483 |     offset_mapping = inputs.pop("offset_mapping")
484 | 
485 |     # word_ids: indicates whether subtokens belong to the same word
486 |     word_ids = inputs.encodings[0].word_ids
487 | 
488 |     token_ids = inputs.input_ids[0].tolist()
489 | 
490 |     if_special = inputs.encodings[0].special_tokens_mask
491 | 
492 |     # forward pass
493 |     with torch.no_grad():
494 |         outputs = model(**inputs)
495 | 
496 |     # take argmax on last dimension to get predicted class ID per token
497 |     predictions = outputs.logits.argmax(-1).squeeze().tolist()
498 | 
499 |     # check if it's a subword
500 |     # is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0
501 | 
502 |     # merge subwords into word-level based on word_ids
503 |     word_pred = defaultdict(lambda: -1)
504 |     words = defaultdict(list)
505 |     for idx, tp in enumerate(zip(if_special, token_ids, predictions, word_ids)):
506 |         if idx == 0 or bool(tp[0]):
507 |             continue
508 | 
509 |         words[tp[-1]].append(idx)
510 |         if word_pred[tp[-1]] == -1:
511 |             word_pred[tp[-1]] = tp[2]
512 | 
513 |     label2id = {"QUESTION": 1, "ANSWER": 2}
514 | 
515 |     # finally, store recognized "question" and "answer" entities in a list
516 |     entities = []
517 |     current_entity = None
518 |     start = None
519 |     end = None
520 | 
521 |     for idx, (id, pred) in enumerate(zip(words.values(), word_pred.values())):
522 |         predicted_label = model.config.id2label[pred]
523 |         if predicted_label == "O":
524 |             continue
525 | 
526 |         if predicted_label.startswith("B") and current_entity is None:
527 |             # means we're at the start of a new entity
528 |             current_entity = predicted_label.replace("B-", "")
529 |             start = min(id)
530 |             print(f"--------------New entity: at index {start}", current_entity)
531 | 
532 |         if current_entity is not None and current_entity not in predicted_label:
533 |             # means we're at the end of the current entity
534 |             end = max(words[idx - 1])
535 |             print("---------------End of the current entity")
536 |             entities.append((start, end, current_entity, label2id[current_entity]))
537 |             current_entity = None
538 | 
539 |         if predicted_label.startswith("B") and current_entity is None:
540 |             # the token that closed the previous entity may itself start a new one
541 |             current_entity = predicted_label.replace("B-", "")
542 |             start = min(id)
543 |             print(f"--------------New entity: at index {start}", current_entity)
544 | 
545 |     # step 2: run LayoutLMv2ForRelationExtraction
546 |     entity_dict = {'start': [entity[0] for entity in entities],
547 |                    'end': [entity[1] for entity in entities],
548 |                    'label': [entity[3] for entity in entities]}
549 | 
550 |     relation_extraction_model = LayoutLMv2ForRelationExtraction.from_pretrained("/home/kewen_yang/Information_Extraction/RE_HuggingFace/model/checkpoint_2023_07_11_15_49.pt/checkpoint-5000")
551 |     # relation_extraction_model = LayoutLMv2ForRelationExtraction.from_pretrained("nielsr/layoutxlm-finetuned-xfund-fr-re")
552 |     relation_extraction_model.to(DEVICE)
553 | 
554 |     with torch.no_grad():
555 |         # inputs: {'input_ids', 'token_type_ids', 'attention_mask', 'bbox', 'image'}
556 |         outputs = relation_extraction_model(**inputs,
557 |                                             entities=[entity_dict],
558 |                                             relations=[{'start_index': [], 'end_index': [], 'head': [], 'tail': []}])
559 | 
560 |     # show predicted key-values
561 |     for relation in outputs.pred_relations[0]:
562 |         head_start, head_end = relation['head']
563 |         tail_start, tail_end = relation['tail']
564 |         print("Question:", processor.decode(inputs.input_ids[0][head_start:head_end]))
565 |         print("Answer:", processor.decode(inputs.input_ids[0][tail_start:tail_end]))
566 |         print("----------")
567 | 
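    """
    Shape of one predicted relation, as consumed above and in the branch below
    (field names taken from how this file uses them):

        relation = {
            "head": (head_start, head_end),   # token span of the question entity
            "tail": (tail_start, tail_end),   # token span of the answer entity
            "head_id": ...,                   # entity id of the question
            "tail_id": ...,                   # entity id of the answer
        }
    """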
568 | elif task.name == "XTRACT_INFER":
569 |     """run inference with the customized Xtract pipeline
570 |     """
571 |     print("running inference ...")
572 |     relation_extraction_model = LayoutLMv2ForRelationExtraction.from_pretrained("/home/kewen_yang/Information_Extraction/RE_HuggingFace/model/checkpoint_2023_07_12_14_54.pt")
573 |     relation_extraction_model.to(DEVICE)
574 |     tokenizer = LayoutLMv2Tokenizer.from_pretrained("/home/kewen_yang/Information_Extraction/RE_HuggingFace/model/checkpoint_2023_07_12_14_54.pt")
575 | 
576 |     with open(ROOT / "dataset/test/AUTOVACSTORE-1-2-Bing-image_010.json", "rb") as f:
577 |         file = json.load(f)
578 | 
579 |     entity_dict = file["entity_dict"]
580 | 
581 |     entity_dict = {k: v[:4] for k, v in entity_dict.items()}  # NOTE: keeps only the first 4 entities
582 | 
583 |     inputs = file["input"]
584 | 
585 |     # load image
586 |     test_image = Image.open(ROOT / 'dataset/test/AUTOVACSTORE-1-2-Bing-image_010.jpg')
587 |     # rescale image as the RE model requires 224 x 224
588 |     test_image = test_image.resize((224, 224), resample=Image.Resampling.BILINEAR)
589 |     test_image = np.array(test_image)
590 |     # channel first
591 |     test_image = test_image.transpose(2, 0, 1)
592 |     # flip color channels from RGB to BGR (as Detectron2 requires this)
593 |     test_image = test_image[::-1, :, :]
594 | 
595 |     inputs["image"] = [test_image]
596 |     print("----------------------------entities----------------------------------------")
597 |     print(f'questions: {[tokenizer.decode([i for i in inputs["input_ids"][0][s:e]]) for s, e, l in zip(entity_dict["start"], entity_dict["end"], entity_dict["label"]) if l == 1]}')
598 |     print(f'answers: {[tokenizer.decode([i for i in inputs["input_ids"][0][s:e]]) for s, e, l in zip(entity_dict["start"], entity_dict["end"], entity_dict["label"]) if l == 2]}')
599 |     print("----------------------------------------------------------------------------")
600 | 
601 |     for k, v in inputs.items():
602 |         inputs[k] = torch.tensor(v)
603 | 
604 |     inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
605 | 
606 |     with torch.no_grad():
607 |         # inputs: {'input_ids', 'token_type_ids', 'attention_mask', 'bbox', 'image'}
608 |         outputs = relation_extraction_model(**inputs,
609 |                                             entities=[entity_dict],
610 |                                             relations=[{'start_index': [], 'end_index': [], 'head': [], 'tail': []}])
611 | 
612 |     res = defaultdict(list)
613 |     done = set()  # ids already added to res
614 |     for relation in outputs.pred_relations[0]:
615 |         head_start, head_end = relation['head']
616 |         tail_start, tail_end = relation['tail']
617 |         key_d = {}
618 |         key_d["id"] = relation["head_id"]
619 |         key_d["text"] = tokenizer.decode(inputs['input_ids'][0][head_start:head_end])
620 |         key_d["label"] = "question"
621 |         key_d["bbox"] = torch.cat((inputs['bbox'][0][head_start:head_end].min(0).values[:2], inputs['bbox'][0][head_start:head_end].max(0).values[2:]), 0).tolist()
622 |         key_d["linking"] = [[relation['head_id'], relation['tail_id']]]
623 |         res[relation["head_id"]].append(key_d)
624 |         done.add(key_d["id"])
625 |         print(f"Question: {key_d['text']}")
626 | 
627 |         val_d = {}
628 |         val_d["id"] = relation["tail_id"]
629 |         val_d["text"] = tokenizer.decode(inputs['input_ids'][0][tail_start:tail_end])
630 |         val_d["label"] = "answer"
631 |         val_d["bbox"] = torch.cat((inputs['bbox'][0][tail_start:tail_end].min(0).values[:2], inputs['bbox'][0][tail_start:tail_end].max(0).values[2:]), 0).tolist()
632 |         val_d["linking"] = [[relation['head_id'], relation['tail_id']]]
633 |         res[relation["tail_id"]].append(val_d)
634 |         done.add(val_d["id"])
635 |         print(f"Answer: {val_d['text']}")
636 |         print("----------")
637 | 
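    """
    Sketch of the link cleanup implemented below (hypothetical ids): an answer
    may initially link to several questions; remove_dup first concatenates the
    links collected per entity id, then get_potential_que keeps only the
    question whose bottom-right corner is Euclidean-closest to the answer:

        distance = sqrt((ans_x0 - q_x1) ** 2 + (ans_y1 - q_y1) ** 2)

    so an answer "5" linked as [["1", "5"], ["2", "5"]] ends up with the single
    link [["2", "5"]] if question "2" is the nearer one.
    """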
638 |     # merge duplicate entries for the same entity id
639 |     def remove_dup(lst):  # concatenate the links of entities that share an id
640 |         if len(lst) == 1:
641 |             return lst[0]
642 | 
643 |         out = lst.pop(0)
644 |         for e in lst:
645 |             out["linking"].append(copy.deepcopy(e["linking"][0]))
646 | 
647 |         return out
648 | 
649 |     def get_potential_que(avail_qns, ans, res):  # pick the question closest to the answer
650 |         if len(avail_qns) == 1:
651 |             return None
652 | 
653 |         qns = [res[id] for id in avail_qns]
654 |         out = qns.pop(0)
655 |         dis = sqrt((ans["bbox"][0] - out["bbox"][2])**2 + (ans["bbox"][3] - out["bbox"][3])**2)
656 |         for q in qns:
657 |             if dis >= sqrt((ans["bbox"][0] - q["bbox"][2])**2 + (ans["bbox"][3] - q["bbox"][3])**2):
658 |                 out = q
659 |                 dis = sqrt((ans["bbox"][0] - q["bbox"][2])**2 + (ans["bbox"][3] - q["bbox"][3])**2)
660 | 
661 |         return out
662 | 
663 |     res = {k: remove_dup(v) for k, v in res.items()}
664 |     # remove links based on distance
665 |     for i in res.keys():
666 |         if res[i]['label'] == "question":
667 |             continue
668 | 
669 |         avail_qs = [l[0] for l in res[i]["linking"]]
670 |         qns = get_potential_que(avail_qs, res[i], res)
671 |         if qns is not None:
672 |             res[i]["linking"] = [[int(qns["id"]), int(res[i]["id"])]]
673 |             res[qns["id"]]["linking"] = [[int(qns["id"]), int(res[i]["id"])]]
674 | 
675 |     for i in res.keys():
676 |         if res[i]["label"] == "question" and len(res[i]["linking"]) > 1:
677 | 
678 |             drops = list()
679 |             for l in res[i]["linking"]:
680 |                 if res[l[1]]["linking"][0][0] != i:
681 |                     drops.append(l)
682 | 
683 |             if drops:
684 |                 for l in drops:
685 |                     res[i]["linking"].remove(l)
686 | 
687 |     res = list(res.values())
688 | 
689 |     keyvalue, unpaired = post_process_entities(res)
690 | 
691 |     out = {"keyvalue": keyvalue, "unpaired": unpaired}
692 | 
693 |     # print key-value text
694 |     printable = {e["id"]: e for e in out["keyvalue"]}
695 |     print("Extracted Key-value pairs are:")
696 |     for e in printable.values():
697 |         if e["label"] == "question":
698 |             print(f'{e["text"]}:{printable[e["linking"]]["text"]}')
699 | 
700 |     with open("/home/kewen_yang/Information_Extraction/dataset/RE_res/1.json", 'w') as f:
701 |         json.dump(out, f)
702 |     print()
703 | 
704 | ################################################ Appendix ################################################
705 | """
706 | Input Schema:
707 | 
708 | {
709 |     "id": image name, e.g. "en_train_0_0"
710 |     "input_ids": List[token_ids], all_texts = "".join([tokenizer.decode(i) for i in dataset["train"][0]["input_ids"]])
711 |     "bbox": tensors,
712 |     "labels": tensors,
713 |     "hd_image":
714 |     "entities":
715 |     "relations":
716 |     "attention_mask":
717 |     "image":
718 | }
719 | 
720 | for English, tokens usually correspond to words, but when a word is out of vocabulary it is tokenized into multiple sub-word tokens, e.g. "vashering" --> ['va', '##sher', '##ing']
721 | 
722 | 
723 | >>>>input feature:
724 | {"input_ids": [], "bbox": []}
725 |     input_ids: the token ids
726 |     bbox: bbox
727 | e.g.
728 | # show all the tokens
729 | [self.tokenizer.decode(i) for i in feature["input_ids"]]
730 | >> ['2', 'w', '##m', '##f', 'w', '##m', '##f', 'consumer', 'electric', 'gmbh', 'mess', '##ers', '##ch', '##mit', ...]
731 | 
732 | # get the original texts
733 | self.tokenizer.convert_tokens_to_string([self.tokenizer.decode(i) for i in feature["input_ids"]])
734 | >> '2 wmf wmf consumer electric gmbh messerschmittstrabe : d - 89343 jettingen - scheppach wmf kult x mono induction hob art. nr. : 04 1524 8811 2x art. nr. : 04 _ 1524 _ 8811 ean : 421112 9145688 cmmf : 3200001681'
735 | 
736 | 
737 | >>>>entities:
738 | {"start": [], "end": [], "label": []}
739 | 
740 |     "start": starting index of input_ids for this entity
741 |     "end": ending index of input_ids for this entity
742 |     "label": label
743 | 
744 | e.g.
745 | ens = [self.tokenizer.convert_tokens_to_string([self.tokenizer.decode(i) for i in feature["input_ids"][s:e]]) for s, e in zip(entities[1]["start"], entities[1]["end"])] 746 | >> ['art. nr. : 04', '04 1524 8811 2', 'art. nr. : 04', '04 _ 1524 _ 8811 ea', 'ean : 421', '421112 9145688 cm', 'cmmf : 320', '3200001681'] 747 | 748 | 749 | >>>>relations: 750 | {"head": [], "tail": [], "start_index": [], "end_index" : []} 751 | 752 | "start_index": starting index of input_ids for this Question-Answer pair 753 | "end_index": ending index of input_ids for this Question-Answer pair 754 | 755 | e.g. 756 | [self.tokenizer.convert_tokens_to_string([self.tokenizer.decode(i) for i in feature["input_ids"][s:e]]) for s, e in zip(relations[1]["start_index"], relations[1]["end_index"])] 757 | >> ['art. nr. : 04 1524 8811', 'art. nr. : 04 _ 1524 _ 8811', 'ean : 421112 9145688', 'cmmf : 3200001681'] 758 | 759 | "head": question index w.r.t the index of "entities" 760 | "tail": answer index 761 | 762 | e.g. 763 | [(ens[h], ens[t]) for h, t in zip(relations[1]["head"], relations[1]["tail"])] 764 | >> [('art. nr. :', '04 1524 8811'), ('art. nr. :', '04 _ 1524 _ 8811'), ('ean :', '421112 9145688'), ('cmmf :', '3200001681')] 765 | """ --------------------------------------------------------------------------------