├── .gitignore ├── README.md ├── data ├── __init__.py ├── cord │ ├── cord.py │ └── graph_cord.py └── data_collator.py ├── examples └── run_cord.py ├── metrics └── seqeval │ └── seqeval.py ├── model ├── __init__.py ├── basemodel.py ├── configuration_graphlayoutlm.py ├── graphlayoutlm.py └── tokenization_graphlayoutlm_fast.py ├── requirements.txt └── utils ├── __init__.py ├── graph_builder_uitls.py ├── graph_utils.py └── image_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | /pretrained/ 3 | /datasets/ 4 | /path/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GraphLayoutLM 2 | 3 | ## Installation 4 | 5 | ``` 6 | git clone https://github.com/Line-Kite/GraphLayoutLM 7 | cd GraphLayoutLM 8 | conda create -n graphlayoutlm python=3.7 9 | conda activate graphlayoutlm 10 | pip install torch==1.10.0+cu111 torchvision==0.11.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html 11 | python -m pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu111/torch1.10/index.html 12 | pip install -r requirements.txt 13 | ``` 14 | 15 | 16 | ## Pre-trained Models 17 | 18 | **Password: 2023** 19 | 20 | | Model | Model Name (Path) | 21 | |---------------------|----------------------------------------------------------------------------------------------------------------| 22 | | graphlayoutlm-base | [graphlayoutlm-base](https://pan.baidu.com/s/1xc6kDOc_CWTXYbGMwrocHQ) | 23 | | graphlayoutlm-large | [graphlayoutlm-large](https://pan.baidu.com/s/1uyF-dS7vcY0-fUT5MXibdA) | 24 | 25 | 26 | ## Finetuning Examples 27 | 28 | ### CORD 29 | 30 | **Password: 2023** 31 | 32 | |Model on CORD | precision | recall | f1 | accuracy | 33 | |:---------------------------------------------------------------------------------------------------------------------------:|:---------:|:------:|:--------:|:--------:| 34 | | [graphlayout-base-finetuned-cord](https://pan.baidu.com/s/1lLiDR4Cw07HRcnlZ4qjSdw) | 0.9724 | 0.9760 | 0.9742 | 0.9813 | 35 | | [graphlayout-large-finetuned-cord](https://pan.baidu.com/s/1tZs60aTzQp1esaj0Bw8C9g) | 0.9791 | 0.9805 | 0.9798 | 0.9839 | 36 | 37 | #### finetune 38 | 39 | Download the model weights and move it to a new directory named "pretrained". 40 | 41 | Download the [CORD](https://drive.google.com/drive/folders/14OEWr86qotVBMAsWk7lymMytxn5u-kM6) dataset and move it to a new directory named "datasets". 
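A quick way to confirm the layout before launching is the small check below. This is only a sketch: the `pretrained/` and `datasets/CORD` directory names are the ones assumed by the commands that follow (and by the dataset loader's default paths), so adjust them if your layout differs.

```
import os

# Expected layout, relative to the repository root (assumed from the commands below):
#   pretrained/graphlayoutlm-base/   <- downloaded model weights
#   datasets/CORD/{train,dev,test}/  <- CORD splits, each containing json/ and image/
for path in ["pretrained", "datasets/CORD/train/json", "datasets/CORD/train/image"]:
    print(path, "ok" if os.path.isdir(path) else "MISSING")
```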
42 | 
43 | **base**
44 | 
45 | ```
46 | cd examples
47 | CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --main_process_port 20655 run_cord.py \
48 | --dataset_name cord \
49 | --do_train \
50 | --do_eval \
51 | --model_name_or_path ../pretrained/graphlayoutlm-base \
52 | --output_dir ../path/cord/base/test \
53 | --segment_level_layout 1 --visual_embed 1 --input_size 224 \
54 | --max_steps 2000 --save_steps -1 --evaluation_strategy steps --eval_steps 100 \
55 | --learning_rate 5e-5 --per_device_train_batch_size 2 --gradient_accumulation_steps 1 \
56 | --dataloader_num_workers 8 --overwrite_output_dir
57 | ```
58 | 
59 | **large**
60 | 
61 | ```
62 | cd examples
63 | CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --main_process_port 20655 run_cord.py \
64 | --dataset_name cord \
65 | --do_train \
66 | --do_eval \
67 | --model_name_or_path ../pretrained/graphlayoutlm-large \
68 | --output_dir ../path/cord/large/test \
69 | --segment_level_layout 1 --visual_embed 1 --input_size 224 \
70 | --max_steps 4000 --save_steps -1 --evaluation_strategy steps --eval_steps 100 \
71 | --learning_rate 5e-5 --per_device_train_batch_size 2 --gradient_accumulation_steps 1 \
72 | --dataloader_num_workers 8 --overwrite_output_dir
73 | ```
74 | 
75 | 
76 | ## Citation
77 | Please cite our paper if the work helps you.
78 | ```
79 | @inproceedings{li2023enhancing,
80 | title={Enhancing Visually-Rich Document Understanding via Layout Structure Modeling},
81 | author={Li, Qiwei and Li, Zuchao and Cai, Xiantao and Du, Bo and Zhao, Hai},
82 | booktitle={Proceedings of the 31st ACM International Conference on Multimedia},
83 | pages={4513--4523},
84 | year={2023}
85 | }
86 | ```
87 | 
88 | 
89 | ## Note
90 | 
91 | More fine-tuning examples will be added in follow-up updates.
92 | -------------------------------------------------------------------------------- /data/__init__.py: --------------------------------------------------------------------------------
1 | 
2 | -------------------------------------------------------------------------------- /data/cord/cord.py: --------------------------------------------------------------------------------
1 | '''
2 | Reference: https://huggingface.co/datasets/pierresi/cord/blob/main/cord.py
3 | '''
4 | 
5 | 
6 | import json
7 | import os
8 | from pathlib import Path
9 | import datasets
10 | from utils.image_utils import load_image, normalize_bbox
11 | from data.cord.graph_cord import graph_builder
12 | logger = datasets.logging.get_logger(__name__)
13 | _CITATION = """\
14 | @article{park2019cord,
15 | title={CORD: A Consolidated Receipt Dataset for Post-OCR Parsing},
16 | author={Park, Seunghyun and Shin, Seung and Lee, Bado and Lee, Junyeop and Surh, Jaeheung and Seo, Minjoon and Lee, Hwalsuk},
17 | booktitle={Document Intelligence Workshop at Neural Information Processing Systems},
18 | year={2019}
19 | }
20 | """
21 | _DESCRIPTION = """\
22 | https://github.com/clovaai/cord/
23 | """
24 | 
25 | def quad_to_box(quad):
26 | # test 87 is wrongly annotated
27 | box = (
28 | max(0, quad["x1"]),
29 | max(0, quad["y1"]),
30 | quad["x3"],
31 | quad["y3"]
32 | )
33 | if box[3] < box[1]:
34 | bbox = list(box)
35 | tmp = bbox[3]
36 | bbox[3] = bbox[1]
37 | bbox[1] = tmp
38 | box = tuple(bbox)
39 | if box[2] < box[0]:
40 | bbox = list(box)
41 | tmp = bbox[2]
42 | bbox[2] = bbox[0]
43 | bbox[0] = tmp
44 | box = tuple(bbox)
45 | return box
46 | 
47 | 
48 | class CordConfig(datasets.BuilderConfig):
49 | """BuilderConfig for CORD"""
50 | def __init__(self, **kwargs):
51 | """BuilderConfig for CORD.
52 | Args: 53 | **kwargs: keyword arguments forwarded to super. 54 | """ 55 | super(CordConfig, self).__init__(**kwargs) 56 | 57 | class Cord(datasets.GeneratorBasedBuilder): 58 | BUILDER_CONFIGS = [ 59 | CordConfig(name="cord", version=datasets.Version("1.0.0"), description="CORD dataset"), 60 | ] 61 | 62 | def _info(self): 63 | return datasets.DatasetInfo( 64 | description=_DESCRIPTION, 65 | features=datasets.Features( 66 | { 67 | "id": datasets.Value("string"), 68 | "words": datasets.Sequence(datasets.Value("string")), 69 | "bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))), 70 | "node_ids": datasets.Sequence(datasets.Value("int64")), 71 | "edges": datasets.Sequence( 72 | { 73 | "head": datasets.Value("int64"), 74 | "tail": datasets.Value("int64"), 75 | "rel": datasets.Value("string"), 76 | } 77 | ), 78 | "ner_tags": datasets.Sequence( 79 | datasets.features.ClassLabel( 80 | names=["O","B-MENU.NM","B-MENU.NUM","B-MENU.UNITPRICE","B-MENU.CNT","B-MENU.DISCOUNTPRICE","B-MENU.PRICE","B-MENU.ITEMSUBTOTAL","B-MENU.VATYN","B-MENU.ETC","B-MENU.SUB_NM","B-MENU.SUB_UNITPRICE","B-MENU.SUB_CNT","B-MENU.SUB_PRICE","B-MENU.SUB_ETC","B-VOID_MENU.NM","B-VOID_MENU.PRICE","B-SUB_TOTAL.SUBTOTAL_PRICE","B-SUB_TOTAL.DISCOUNT_PRICE","B-SUB_TOTAL.SERVICE_PRICE","B-SUB_TOTAL.OTHERSVC_PRICE","B-SUB_TOTAL.TAX_PRICE","B-SUB_TOTAL.ETC","B-TOTAL.TOTAL_PRICE","B-TOTAL.TOTAL_ETC","B-TOTAL.CASHPRICE","B-TOTAL.CHANGEPRICE","B-TOTAL.CREDITCARDPRICE","B-TOTAL.EMONEYPRICE","B-TOTAL.MENUTYPE_CNT","B-TOTAL.MENUQTY_CNT","I-MENU.NM","I-MENU.NUM","I-MENU.UNITPRICE","I-MENU.CNT","I-MENU.DISCOUNTPRICE","I-MENU.PRICE","I-MENU.ITEMSUBTOTAL","I-MENU.VATYN","I-MENU.ETC","I-MENU.SUB_NM","I-MENU.SUB_UNITPRICE","I-MENU.SUB_CNT","I-MENU.SUB_PRICE","I-MENU.SUB_ETC","I-VOID_MENU.NM","I-VOID_MENU.PRICE","I-SUB_TOTAL.SUBTOTAL_PRICE","I-SUB_TOTAL.DISCOUNT_PRICE","I-SUB_TOTAL.SERVICE_PRICE","I-SUB_TOTAL.OTHERSVC_PRICE","I-SUB_TOTAL.TAX_PRICE","I-SUB_TOTAL.ETC","I-TOTAL.TOTAL_PRICE","I-TOTAL.TOTAL_ETC","I-TOTAL.CASHPRICE","I-TOTAL.CHANGEPRICE","I-TOTAL.CREDITCARDPRICE","I-TOTAL.EMONEYPRICE","I-TOTAL.MENUTYPE_CNT","I-TOTAL.MENUQTY_CNT"] 81 | ) 82 | ), 83 | "image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"), 84 | "image_path": datasets.Value("string"), 85 | } 86 | ), 87 | supervised_keys=None, 88 | citation=_CITATION, 89 | homepage="https://github.com/clovaai/cord/", 90 | ) 91 | 92 | def _split_generators(self, dl_manager): 93 | """Returns SplitGenerators.""" 94 | """Uses local files located with data_dir""" 95 | dest=r"../datasets/CORD" 96 | if not os.path.exists(os.path.join(dest,"train","graph")): 97 | graph_builder(dest) 98 | return [ 99 | datasets.SplitGenerator( 100 | name=datasets.Split.TRAIN, gen_kwargs={"filepath": dest+"/train"} 101 | ), 102 | datasets.SplitGenerator( 103 | name=datasets.Split.VALIDATION, gen_kwargs={"filepath": dest+"/dev"} 104 | ), 105 | datasets.SplitGenerator( 106 | name=datasets.Split.TEST, gen_kwargs={"filepath": dest+"/test"} 107 | ), 108 | ] 109 | 110 | def get_line_bbox(self, bboxs): 111 | x = [bboxs[i][j] for i in range(len(bboxs)) for j in range(0, len(bboxs[i]), 2)] 112 | y = [bboxs[i][j] for i in range(len(bboxs)) for j in range(1, len(bboxs[i]), 2)] 113 | 114 | x0, y0, x1, y1 = min(x), min(y), max(x), max(y) 115 | 116 | assert x1 >= x0 and y1 >= y0 117 | bbox = [[x0, y0, x1, y1] for _ in range(len(bboxs))] 118 | return bbox 119 | 120 | def _generate_examples(self, filepath): 121 | logger.info("⏳ Generating examples from = %s", filepath) 122 | ann_dir = 
os.path.join(filepath, "reordered_json") 123 | graph_dir=os.path.join(filepath, "graph") 124 | img_dir = os.path.join(filepath, "image") 125 | for guid, file in enumerate(sorted(os.listdir(ann_dir))): 126 | words = [] 127 | bboxes = [] 128 | ner_tags = [] 129 | node_ids=[] 130 | edges=[] 131 | 132 | with open(os.path.join(graph_dir, file), "r", encoding="utf8") as f: 133 | graph_data = json.load(f) 134 | edges=graph_data["edges"] 135 | if len(edges)==0: 136 | print("len error") 137 | exit(0) 138 | 139 | file_path = os.path.join(ann_dir, file) 140 | with open(file_path, "r", encoding="utf8") as f: 141 | data = json.load(f) 142 | image_path = os.path.join(img_dir, file) 143 | image_path = image_path.replace("json", "png") 144 | image, size = load_image(image_path) 145 | for item in data["valid_line"]: 146 | cur_line_bboxes = [] 147 | line_words, label = item["words"], item["category"] 148 | line_words = [w for w in line_words if w["text"].strip() != ""] 149 | if len(line_words) == 0: 150 | continue 151 | if label == "other": 152 | for w in line_words: 153 | words.append(w["text"]) 154 | ner_tags.append("O") 155 | cur_line_bboxes.append(normalize_bbox(quad_to_box(w["quad"]), size)) 156 | node_ids.append(item["id"]) 157 | else: 158 | words.append(line_words[0]["text"]) 159 | ner_tags.append("B-" + label.upper()) 160 | cur_line_bboxes.append(normalize_bbox(quad_to_box(line_words[0]["quad"]), size)) 161 | node_ids.append(item["id"]) 162 | for w in line_words[1:]: 163 | words.append(w["text"]) 164 | ner_tags.append("I-" + label.upper()) 165 | cur_line_bboxes.append(normalize_bbox(quad_to_box(w["quad"]), size)) 166 | node_ids.append(item["id"]) 167 | # by default: --segment_level_layout 1 168 | # if do not want to use segment_level_layout, comment the following line 169 | cur_line_bboxes = self.get_line_bbox(cur_line_bboxes) 170 | bboxes.extend(cur_line_bboxes) 171 | # yield guid, {"id": str(guid), "words": words, "bboxes": bboxes, "ner_tags": ner_tags, "image": image} 172 | yield guid, {"id": str(guid), "words": words, "bboxes": bboxes, "ner_tags": ner_tags, 173 | "node_ids":node_ids,"edges":edges, 174 | "image": image, "image_path": image_path} 175 | -------------------------------------------------------------------------------- /data/cord/graph_cord.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append("..") 5 | 6 | from utils.graph_builder_uitls import OPPOSITE, TreeNode, json_loader, json_saver, node_box_update, posotion_judge 7 | 8 | 9 | def get_item(id,list): 10 | for item in list: 11 | if id==item["id"]: 12 | return item 13 | 14 | 15 | def get_line_bbox(bboxs): 16 | x = [bboxs[i][j] for i in range(len(bboxs)) for j in range(0, len(bboxs[i]), 2)] 17 | y = [bboxs[i][j] for i in range(len(bboxs)) for j in range(1, len(bboxs[i]), 2)] 18 | 19 | x0, y0, x1, y1 = min(x), min(y), max(x), max(y) 20 | 21 | assert x1 >= x0 and y1 >= y0 22 | bbox = [x0, y0, x1, y1] 23 | return bbox 24 | 25 | 26 | def quad_to_box(quad): 27 | box = ( 28 | max(0, quad["x1"]), 29 | max(0, quad["y1"]), 30 | quad["x3"], 31 | quad["y3"] 32 | ) 33 | if box[3] < box[1]: 34 | bbox = list(box) 35 | tmp = bbox[3] 36 | bbox[3] = bbox[1] 37 | bbox[1] = tmp 38 | box = tuple(bbox) 39 | if box[2] < box[0]: 40 | bbox = list(box) 41 | tmp = bbox[2] 42 | bbox[2] = bbox[0] 43 | bbox[0] = tmp 44 | box = tuple(bbox) 45 | return box 46 | 47 | 48 | def data_preprocess(data): 49 | if "id" not in data["valid_line"][0].keys(): 50 | for i,item in 
enumerate(data["valid_line"]): 51 | item["id"]=i 52 | return data 53 | 54 | 55 | def insertion_reorder(nodes): 56 | sorted_nodes = [] 57 | for node in nodes: 58 | len_sorted=len(sorted_nodes) 59 | if len_sorted==0: 60 | sorted_nodes.append(node) 61 | continue 62 | for i in range(len_sorted): 63 | rel = posotion_judge(sorted_nodes[i].box,node.box) 64 | if rel=="up-left" or rel=="left" or rel=="up" or rel=="up-right": 65 | sorted_nodes.insert(i, node) 66 | break 67 | else: 68 | if i==len_sorted-1: 69 | sorted_nodes.append(node) 70 | assert len(sorted_nodes)==len(nodes) 71 | return sorted_nodes 72 | 73 | 74 | def get_relationship(node_a,node_list,edges): 75 | for node in node_list: 76 | if node.id==node_a.id: 77 | continue 78 | rel=posotion_judge(node_a.box,node.box) 79 | if rel=="left" or rel=="right" or rel=="up" or rel=="down": 80 | edges["edges"].append({"head": node_a.id, "tail": node.id, "rel": rel}) 81 | 82 | 83 | def tree_builder(data): 84 | img_width=data["meta"]["image_size"]["width"] 85 | img_height = data["meta"]["image_size"]["height"] 86 | root=TreeNode(-1,[0,0,img_width,img_height]) 87 | groups={} 88 | for item in data["valid_line"]: 89 | words_box=[] 90 | for word in item["words"]: 91 | quad=word["quad"] 92 | words_box.append(quad_to_box(quad)) 93 | if item["group_id"] not in groups.keys(): 94 | groups[item["group_id"]]=[] 95 | groups[item["group_id"]].append(TreeNode(item["id"], get_line_bbox(words_box))) 96 | for key in groups.keys(): 97 | sorted_nodes=insertion_reorder(groups[key]) 98 | child_tree_root=None 99 | for i,node in enumerate(sorted_nodes): 100 | if i==0: 101 | child_tree_root=node 102 | node.parent=root 103 | root.children.append(node) 104 | node.children=sorted_nodes[1:] 105 | continue 106 | else: 107 | node.parent=child_tree_root 108 | node_box_update(child_tree_root,node) 109 | sorted_group_nodes=insertion_reorder(root.children) 110 | root.children=sorted_group_nodes 111 | return root 112 | 113 | 114 | def get_graph_and_reordered_nodes(root,data): 115 | sorted_valid_line=[] 116 | edges={"edges":[]} 117 | valid_line=data["valid_line"] 118 | for child in root.children: 119 | sorted_valid_line.append(get_item(child.id,valid_line)) 120 | get_relationship(child,root.children,edges) 121 | for node in child.children: 122 | edges["edges"].append({"head": child.id, "tail": node.id, "rel": "child"}) 123 | edges["edges"].append({"head": node.id, "tail": child.id, "rel": "parent"}) 124 | sorted_valid_line.append(get_item(node.id, valid_line)) 125 | get_relationship(node, child.children, edges) 126 | assert len(valid_line)==len(sorted_valid_line) 127 | data["valid_line"]=sorted_valid_line 128 | com_edge=[] 129 | for edge in edges["edges"]: 130 | com_edge.append(edge) 131 | if {"head": edge["tail"], "tail": edge["head"], "rel": OPPOSITE[edge["rel"]]} not in edges["edges"]: 132 | com_edge.append({"head": edge["tail"], "tail": edge["head"], "rel": OPPOSITE[edge["rel"]]}) 133 | edges["edges"]=com_edge 134 | return edges,data 135 | 136 | 137 | def graph_builder(path): 138 | data_types=["train", "dev", "test"] 139 | for data_type in data_types: 140 | json_path=os.path.join(path,data_type,"json") 141 | graph_path=os.path.join(path,data_type,"graph") 142 | if not os.path.exists(graph_path): 143 | os.mkdir(graph_path) 144 | reordered_json_path=os.path.join(path,data_type,"reordered_json") 145 | if not os.path.exists(reordered_json_path): 146 | os.mkdir(reordered_json_path) 147 | for filename in os.listdir(json_path): 148 | file_path=os.path.join(json_path,filename) 149 | 
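            # Descriptive note: for each <split>/json/<name>.json, the statements below build two
            # output paths with the same filename, <split>/reordered_json/<name>.json (valid_line
            # re-ordered into reading order) and <split>/graph/<name>.json (the edge list), which
            # are the directories that data/cord/cord.py later reads.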
reordered_file_path = os.path.join(reordered_json_path, filename) 150 | graph_file_path = os.path.join(graph_path, filename) 151 | data=json_loader(file_path) 152 | data=data_preprocess(data) 153 | tree=tree_builder(data) 154 | edges,reordered_data=get_graph_and_reordered_nodes(tree,data) 155 | assert len(data["valid_line"])==len(reordered_data["valid_line"]) 156 | json_saver(edges,graph_file_path) 157 | json_saver(reordered_data,reordered_file_path) -------------------------------------------------------------------------------- /data/data_collator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dataclasses import dataclass 3 | from typing import Any, Dict, List, Optional, Tuple, Union 4 | 5 | from transformers import BatchEncoding, PreTrainedTokenizerBase 6 | from transformers.data.data_collator import ( 7 | DataCollatorMixin, 8 | _torch_collate_batch, 9 | ) 10 | from transformers.file_utils import PaddingStrategy 11 | 12 | from typing import NewType 13 | InputDataClass = NewType("InputDataClass", Any) 14 | 15 | def pre_calc_rel_mat(segment_ids): 16 | valid_span = torch.zeros((segment_ids.shape[0], segment_ids.shape[1], segment_ids.shape[1]), 17 | device=segment_ids.device, dtype=torch.bool) 18 | for i in range(segment_ids.shape[0]): 19 | for j in range(segment_ids.shape[1]): 20 | valid_span[i, j, :] = segment_ids[i, :] == segment_ids[i, j] 21 | 22 | return valid_span 23 | 24 | @dataclass 25 | class DataCollatorForKeyValueExtraction(DataCollatorMixin): 26 | """ 27 | Data collator that will dynamically pad the inputs received, as well as the labels. 28 | Args: 29 | tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`): 30 | The tokenizer used for encoding the data. 31 | padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`): 32 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index) 33 | among: 34 | * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single 35 | sequence if provided). 36 | * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the 37 | maximum acceptable input length for the model if that argument is not provided. 38 | * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of 39 | different lengths). 40 | max_length (:obj:`int`, `optional`): 41 | Maximum length of the returned list and optionally padding length (see above). 42 | pad_to_multiple_of (:obj:`int`, `optional`): 43 | If set will pad the sequence to a multiple of the provided value. 44 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 45 | 7.5 (Volta). 46 | label_pad_token_id (:obj:`int`, `optional`, defaults to -100): 47 | The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions). 
48 | """ 49 | 50 | tokenizer: PreTrainedTokenizerBase 51 | padding: Union[bool, str, PaddingStrategy] = True 52 | max_length: Optional[int] = None 53 | pad_to_multiple_of: Optional[int] = None 54 | label_pad_token_id: int = -100 55 | 56 | def __call__(self, features): 57 | label_name = "label" if "label" in features[0].keys() else "labels" 58 | labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None 59 | 60 | images = None 61 | if "images" in features[0]: 62 | images = torch.stack([torch.tensor(d.pop("images")) for d in features]) 63 | IMAGE_LEN = int(images.shape[-1] / 16) * int(images.shape[-1] / 16) + 1 64 | 65 | batch = self.tokenizer.pad( 66 | features, 67 | padding=self.padding, 68 | max_length=self.max_length, 69 | pad_to_multiple_of=self.pad_to_multiple_of, 70 | # Conversion to tensors will fail if we have labels as they are not of the same length yet. 71 | return_tensors="pt" if labels is None else None, 72 | ) 73 | 74 | if images is not None: 75 | batch["images"] = images 76 | batch = {k: torch.tensor(v, dtype=torch.int64) if isinstance(v[0], list) and k == 'attention_mask' else v 77 | for k, v in batch.items()} 78 | visual_attention_mask = torch.ones((len(batch['input_ids']), IMAGE_LEN), dtype=torch.long) 79 | batch["attention_mask"] = torch.cat([batch['attention_mask'], visual_attention_mask], dim=1) 80 | 81 | if labels is None: 82 | return batch 83 | 84 | has_bbox_input = "bbox" in features[0] 85 | has_position_input = "position_ids" in features[0] 86 | padding_idx=self.tokenizer.pad_token_id 87 | sequence_length = torch.tensor(batch["input_ids"]).shape[1] 88 | padding_side = self.tokenizer.padding_side 89 | if padding_side == "right": 90 | batch["labels"] = [label + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels] 91 | if has_bbox_input: 92 | batch["bbox"] = [bbox + [[0, 0, 0, 0]] * (sequence_length - len(bbox)) for bbox in batch["bbox"]] 93 | if has_position_input: 94 | batch["position_ids"] = [position_id + [padding_idx] * (sequence_length - len(position_id)) 95 | for position_id in batch["position_ids"]] 96 | 97 | else: 98 | batch["labels"] = [[self.label_pad_token_id] * (sequence_length - len(label)) + label for label in labels] 99 | if has_bbox_input: 100 | batch["bbox"] = [[[0, 0, 0, 0]] * (sequence_length - len(bbox)) + bbox for bbox in batch["bbox"]] 101 | if has_position_input: 102 | batch["position_ids"] = [[padding_idx] * (sequence_length - len(position_id)) 103 | + position_id for position_id in batch["position_ids"]] 104 | 105 | if 'segment_ids' in batch: 106 | assert 'position_ids' in batch 107 | for i in range(len(batch['segment_ids'])): 108 | batch['segment_ids'][i] = batch['segment_ids'][i] + [batch['segment_ids'][i][-1] + 1] * (sequence_length - len(batch['segment_ids'][i])) + [ 109 | batch['segment_ids'][i][-1] + 2] * IMAGE_LEN 110 | 111 | batch = {k: torch.tensor(v, dtype=torch.int64) if isinstance(v[0], list) else v for k, v in batch.items()} 112 | 113 | if 'segment_ids' in batch: 114 | valid_span = pre_calc_rel_mat( 115 | segment_ids=batch['segment_ids'] 116 | ) 117 | batch['valid_span'] = valid_span 118 | del batch['segment_ids'] 119 | 120 | if images is not None: 121 | visual_labels = torch.ones((len(batch['input_ids']), IMAGE_LEN), dtype=torch.long) * -100 122 | batch["labels"] = torch.cat([batch['labels'], visual_labels], dim=1) 123 | 124 | return batch 125 | -------------------------------------------------------------------------------- /examples/run_cord.py: 
-------------------------------------------------------------------------------- 1 | # Refer to code of layoutlmv3 2 | #!/usr/bin/env python 3 | # coding=utf-8 4 | import logging 5 | import os 6 | import sys 7 | from dataclasses import dataclass, field 8 | from typing import Optional 9 | 10 | import numpy as np 11 | from datasets import ClassLabel, load_dataset, load_metric 12 | 13 | import transformers 14 | 15 | from transformers import ( 16 | HfArgumentParser, 17 | PreTrainedTokenizerFast, 18 | Trainer, 19 | TrainingArguments, 20 | set_seed, 21 | ) 22 | from transformers.trainer_utils import get_last_checkpoint, is_main_process 23 | 24 | 25 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 26 | # check_min_version("4.5.0") 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | from timm.data.constants import \ 31 | IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 32 | from torchvision import transforms 33 | import torch 34 | 35 | sys.path.append("..") 36 | 37 | from data.data_collator import DataCollatorForKeyValueExtraction 38 | 39 | from utils.image_utils import RandomResizedCropAndInterpolationWithTwoPic, pil_loader, Compose 40 | from utils.graph_utils import set_nodes, set_edges 41 | 42 | 43 | 44 | @dataclass 45 | class ModelArguments: 46 | """ 47 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 48 | """ 49 | 50 | model_name_or_path: str = field( 51 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 52 | ) 53 | config_name: Optional[str] = field( 54 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 55 | ) 56 | tokenizer_name: Optional[str] = field( 57 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 58 | ) 59 | cache_dir: Optional[str] = field( 60 | default=None, 61 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 62 | ) 63 | model_revision: str = field( 64 | default="main", 65 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 66 | ) 67 | use_auth_token: bool = field( 68 | default=False, 69 | metadata={ 70 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 71 | "with private models)." 72 | }, 73 | ) 74 | 75 | 76 | @dataclass 77 | class DataTrainingArguments: 78 | """ 79 | Arguments pertaining to what data we are going to input our model for training and eval. 
80 | """ 81 | 82 | task_name: Optional[str] = field(default="ner", metadata={"help": "The name of the task (ner, pos...)."}) 83 | dataset_name: Optional[str] = field( 84 | default='funsd', metadata={"help": "The name of the dataset to use (via the datasets library)."} 85 | ) 86 | dataset_config_name: Optional[str] = field( 87 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 88 | ) 89 | train_file: Optional[str] = field( 90 | default=None, metadata={"help": "The input training data file (a csv or JSON file)."} 91 | ) 92 | validation_file: Optional[str] = field( 93 | default=None, 94 | metadata={"help": "An optional input evaluation data file to evaluate on (a csv or JSON file)."}, 95 | ) 96 | test_file: Optional[str] = field( 97 | default=None, 98 | metadata={"help": "An optional input test data file to predict on (a csv or JSON file)."}, 99 | ) 100 | overwrite_cache: bool = field( 101 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 102 | ) 103 | preprocessing_num_workers: Optional[int] = field( 104 | default=None, 105 | metadata={"help": "The number of processes to use for the preprocessing."}, 106 | ) 107 | pad_to_max_length: bool = field( 108 | default=True, 109 | metadata={ 110 | "help": "Whether to pad all samples to model maximum sentence length. " 111 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 112 | "efficient on GPU but very bad for TPU." 113 | }, 114 | ) 115 | max_train_samples: Optional[int] = field( 116 | default=None, 117 | metadata={ 118 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 119 | "value if set." 120 | }, 121 | ) 122 | max_val_samples: Optional[int] = field( 123 | default=None, 124 | metadata={ 125 | "help": "For debugging purposes or quicker training, truncate the number of validation examples to this " 126 | "value if set." 127 | }, 128 | ) 129 | max_test_samples: Optional[int] = field( 130 | default=None, 131 | metadata={ 132 | "help": "For debugging purposes or quicker training, truncate the number of test examples to this " 133 | "value if set." 134 | }, 135 | ) 136 | label_all_tokens: bool = field( 137 | default=False, 138 | metadata={ 139 | "help": "Whether to put the label for one word on all tokens of generated by that word or just on the " 140 | "one (in which case the other tokens will have a padding index)." 
141 | }, 142 | ) 143 | return_entity_level_metrics: bool = field( 144 | default=False, 145 | metadata={"help": "Whether to return all the entity levels during evaluation or just the overall ones."}, 146 | ) 147 | segment_level_layout: bool = field(default=True) 148 | visual_embed: bool = field(default=True) 149 | data_dir: Optional[str] = field(default=None) 150 | input_size: int = field(default=224, metadata={"help": "images input size for backbone"}) 151 | second_input_size: int = field(default=112, metadata={"help": "images input size for discrete vae"}) 152 | train_interpolation: str = field( 153 | default='bicubic', metadata={"help": "Training interpolation (random, bilinear, bicubic)"}) 154 | second_interpolation: str = field( 155 | default='lanczos', metadata={"help": "Interpolation for discrete vae (random, bilinear, bicubic)"}) 156 | imagenet_default_mean_and_std: bool = field(default=False, metadata={"help": ""}) 157 | 158 | 159 | 160 | def main(): 161 | 162 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 163 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 164 | # If we pass only one argument to the script and it's the path to a json file, 165 | # let's parse it to get our arguments. 166 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 167 | else: 168 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 169 | 170 | # Detecting last checkpoint. 171 | last_checkpoint = None 172 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 173 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 174 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 175 | raise ValueError( 176 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 177 | "Use --overwrite_output_dir to overcome." 178 | ) 179 | elif last_checkpoint is not None: 180 | logger.info( 181 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 182 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 183 | ) 184 | 185 | # Setup logging 186 | logging.basicConfig( 187 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 188 | datefmt="%m/%d/%Y %H:%M:%S", 189 | handlers=[logging.StreamHandler(sys.stdout)], 190 | ) 191 | logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN) 192 | 193 | # Log on each process the small summary: 194 | logger.warning( 195 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 196 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 197 | ) 198 | # Set the verbosity to info of the Transformers logger (on main process only): 199 | if is_main_process(training_args.local_rank): 200 | transformers.utils.logging.set_verbosity_info() 201 | transformers.utils.logging.enable_default_handler() 202 | transformers.utils.logging.enable_explicit_format() 203 | logger.info(f"Training/evaluation parameters {training_args}") 204 | 205 | # Set seed before initializing model. 
206 | set_seed(training_args.seed) 207 | 208 | if data_args.dataset_name == 'cord': 209 | import data.cord.cord 210 | datasets = load_dataset(os.path.abspath(data.cord.cord.__file__), cache_dir=model_args.cache_dir) 211 | else: 212 | raise NotImplementedError() 213 | 214 | if training_args.do_train: 215 | column_names = datasets["train"].column_names 216 | features = datasets["train"].features 217 | else: 218 | column_names = datasets["test"].column_names 219 | features = datasets["test"].features 220 | 221 | text_column_name = "words" if "words" in column_names else "tokens" 222 | 223 | label_column_name = ( 224 | f"{data_args.task_name}_tags" if f"{data_args.task_name}_tags" in column_names else column_names[1] 225 | ) 226 | 227 | remove_columns = column_names 228 | 229 | # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the 230 | # unique labels. 231 | def get_label_list(labels): 232 | unique_labels = set() 233 | for label in labels: 234 | unique_labels = unique_labels | set(label) 235 | label_list = list(unique_labels) 236 | label_list.sort() 237 | return label_list 238 | 239 | if isinstance(features[label_column_name].feature, ClassLabel): 240 | label_list = features[label_column_name].feature.names 241 | # No need to convert the labels since they are already ints. 242 | label_to_id = {i: i for i in range(len(label_list))} 243 | else: 244 | label_list = get_label_list(datasets["train"][label_column_name]) 245 | label_to_id = {l: i for i, l in enumerate(label_list)} 246 | num_labels = len(label_list) 247 | 248 | # Load pretrained model and tokenizer 249 | # 250 | # Distributed training: 251 | # The .from_pretrained methods guarantee that only one local process can concurrently 252 | # download model & vocab. 253 | from model.configuration_graphlayoutlm import GraphLayoutLMConfig 254 | config = GraphLayoutLMConfig.from_pretrained( 255 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 256 | num_labels=num_labels, 257 | finetuning_task=data_args.task_name, 258 | cache_dir=model_args.cache_dir, 259 | revision=model_args.model_revision, 260 | input_size=data_args.input_size, 261 | use_auth_token=True if model_args.use_auth_token else None, 262 | ) 263 | from model.tokenization_graphlayoutlm_fast import GraphLayoutLMTokenizerFast 264 | tokenizer = GraphLayoutLMTokenizerFast.from_pretrained( 265 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 266 | tokenizer_file=None, # avoid loading from a cached file of the pre-trained model in another machine 267 | cache_dir=model_args.cache_dir, 268 | use_fast=True, 269 | add_prefix_space=True, 270 | revision=model_args.model_revision, 271 | use_auth_token=True if model_args.use_auth_token else None, 272 | ) 273 | from model.graphlayoutlm import GraphLayoutLMForTokenClassification 274 | model=GraphLayoutLMForTokenClassification.from_pretrained( 275 | model_args.model_name_or_path, 276 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 277 | config=config, 278 | cache_dir=model_args.cache_dir, 279 | revision=model_args.model_revision, 280 | use_auth_token=True if model_args.use_auth_token else None, 281 | ) 282 | 283 | # Tokenizer check: this script requires a fast tokenizer. 284 | if not isinstance(tokenizer, PreTrainedTokenizerFast): 285 | raise ValueError( 286 | "This example script only works for models that have a fast tokenizer. 
Checkout the big table of models " 287 | "at https://huggingface.co/transformers/index.html#bigtable to find the model types that meet this " 288 | "requirement" 289 | ) 290 | 291 | # Preprocessing the dataset 292 | # Padding strategy 293 | padding = "max_length" if data_args.pad_to_max_length else False 294 | 295 | if data_args.visual_embed: 296 | imagenet_default_mean_and_std = data_args.imagenet_default_mean_and_std 297 | mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN 298 | std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD 299 | common_transform = Compose([ 300 | # transforms.ColorJitter(0.4, 0.4, 0.4), 301 | # transforms.RandomHorizontalFlip(p=0.5), 302 | RandomResizedCropAndInterpolationWithTwoPic( 303 | size=data_args.input_size, interpolation=data_args.train_interpolation), 304 | ]) 305 | 306 | patch_transform = transforms.Compose([ 307 | transforms.ToTensor(), 308 | transforms.Normalize( 309 | mean=torch.tensor(mean), 310 | std=torch.tensor(std)) 311 | ]) 312 | 313 | # Tokenize all texts and align the labels with them. 314 | def tokenize_and_align_labels(examples, augmentation=False): 315 | tokenized_inputs = tokenizer( 316 | examples[text_column_name],boxes=examples["bboxes"],word_labels=examples["ner_tags"], 317 | padding=False, 318 | truncation=True, 319 | return_overflowing_tokens=True, 320 | ) 321 | 322 | 323 | labels = [] 324 | bboxes = [] 325 | images = [] 326 | nodes=[] 327 | edges=[] 328 | 329 | for batch_index in range(len(tokenized_inputs["input_ids"])): 330 | word_ids = tokenized_inputs.word_ids(batch_index=batch_index) 331 | org_batch_index = tokenized_inputs["overflow_to_sample_mapping"][batch_index] 332 | 333 | label = examples[label_column_name][org_batch_index] 334 | bbox = examples["bboxes"][org_batch_index] 335 | 336 | node_ids=examples["node_ids"][org_batch_index] 337 | 338 | previous_word_idx = None 339 | label_ids = [] 340 | bbox_inputs = [] 341 | new_node_ids=[] 342 | for word_idx in word_ids: 343 | # Special tokens have a word id that is None. We set the label to -100 so they are automatically 344 | # ignored in the loss function. 345 | if word_idx is None: 346 | label_ids.append(-100) 347 | bbox_inputs.append([0, 0, 0, 0]) 348 | new_node_ids.append(-1) 349 | # We set the label for the first token of each word. 350 | elif word_idx != previous_word_idx: 351 | label_ids.append(label_to_id[label[word_idx]]) 352 | bbox_inputs.append(bbox[word_idx]) 353 | new_node_ids.append(node_ids[word_idx]) 354 | # For the other tokens in a word, we set the label to either the current label or -100, depending on 355 | # the label_all_tokens flag. 
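                    # Illustration with assumed values: a word "Avocado" tagged B-MENU.NM might be
                    # split into sub-tokens ["Av", "ocado"]; only the first sub-token keeps the
                    # B-MENU.NM label id, while the remaining sub-tokens fall into the branch below
                    # and get -100 (or the same label id when --label_all_tokens is set), so by
                    # default each word contributes a single position to the loss.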
356 | else: 357 | label_ids.append(label_to_id[label[word_idx]] if data_args.label_all_tokens else -100) 358 | bbox_inputs.append(bbox[word_idx]) 359 | new_node_ids.append(node_ids[word_idx]) 360 | previous_word_idx = word_idx 361 | labels.append(label_ids) 362 | bboxes.append(bbox_inputs) 363 | 364 | if data_args.visual_embed: 365 | ipath = examples["image_path"][org_batch_index] 366 | img = pil_loader(ipath) 367 | for_patches, _ = common_transform(img, augmentation=augmentation) 368 | patch = patch_transform(for_patches) 369 | images.append(patch) 370 | 371 | new_node_data,new_ids=set_nodes(new_node_ids) 372 | 373 | new_edges_data=set_edges(examples["edges"][org_batch_index],new_ids) 374 | nodes.append(new_node_data) 375 | edges.append(new_edges_data) 376 | 377 | # build graph mask 378 | graph_mask_list=[] 379 | input_len=709 380 | for nodes_data,edges_data in zip(nodes,edges): 381 | edges_len=len(edges_data) 382 | graph_mask= -9e15 * np.ones((input_len,input_len)) 383 | for edge_i in range(edges_len): 384 | edge=edges_data[edge_i] 385 | if edge[0]==-1: 386 | break 387 | a_node_index,b_node_index=edge[0],edge[1] 388 | [a_start,a_end]=nodes_data[a_node_index] 389 | [b_start,b_end]=nodes_data[b_node_index] 390 | graph_mask[a_start:a_end+1,b_start:b_end+1]=0 391 | graph_mask_list.append(graph_mask) 392 | tokenized_inputs["graph_mask"]=graph_mask_list 393 | 394 | 395 | tokenized_inputs["labels"] = labels 396 | tokenized_inputs["bbox"] = bboxes 397 | if data_args.visual_embed: 398 | tokenized_inputs["images"] = images 399 | 400 | return tokenized_inputs 401 | 402 | if training_args.do_train: 403 | if "train" not in datasets: 404 | raise ValueError("--do_train requires a train dataset") 405 | train_dataset = datasets["train"] 406 | if data_args.max_train_samples is not None: 407 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 408 | 409 | 410 | train_dataset = train_dataset.map( 411 | tokenize_and_align_labels, 412 | batched=True, 413 | remove_columns=remove_columns, 414 | num_proc=data_args.preprocessing_num_workers, 415 | load_from_cache_file=not data_args.overwrite_cache, 416 | ) 417 | 418 | 419 | if training_args.do_eval: 420 | validation_name = "test" 421 | if validation_name not in datasets: 422 | raise ValueError("--do_eval requires a validation dataset") 423 | eval_dataset = datasets[validation_name] 424 | if data_args.max_val_samples is not None: 425 | eval_dataset = eval_dataset.select(range(data_args.max_val_samples)) 426 | eval_dataset = eval_dataset.map( 427 | tokenize_and_align_labels, 428 | batched=True, 429 | remove_columns=remove_columns, 430 | num_proc=data_args.preprocessing_num_workers, 431 | load_from_cache_file=not data_args.overwrite_cache, 432 | ) 433 | 434 | if training_args.do_predict: 435 | if "test" not in datasets: 436 | raise ValueError("--do_predict requires a test dataset") 437 | test_dataset = datasets["test"] 438 | if data_args.max_test_samples is not None: 439 | test_dataset = test_dataset.select(range(data_args.max_test_samples)) 440 | test_dataset = test_dataset.map( 441 | tokenize_and_align_labels, 442 | batched=True, 443 | remove_columns=remove_columns, 444 | num_proc=data_args.preprocessing_num_workers, 445 | load_from_cache_file=not data_args.overwrite_cache, 446 | ) 447 | 448 | # Data collator 449 | data_collator = DataCollatorForKeyValueExtraction( 450 | tokenizer, 451 | pad_to_multiple_of=8 if training_args.fp16 else None, 452 | padding=padding, 453 | max_length=512, 454 | ) 455 | 456 | # Metrics 457 | metric = 
load_metric("../metrics/seqeval") 458 | 459 | def compute_metrics(p): 460 | predictions, labels = p 461 | predictions = np.argmax(predictions, axis=2) 462 | 463 | # Remove ignored index (special tokens) 464 | true_predictions = [ 465 | [label_list[p] for (p, l) in zip(prediction, label) if l != -100] 466 | for prediction, label in zip(predictions, labels) 467 | ] 468 | true_labels = [ 469 | [label_list[l] for (p, l) in zip(prediction, label) if l != -100] 470 | for prediction, label in zip(predictions, labels) 471 | ] 472 | 473 | results = metric.compute(predictions=true_predictions, references=true_labels) 474 | if data_args.return_entity_level_metrics: 475 | # Unpack nested dictionaries 476 | final_results = {} 477 | for key, value in results.items(): 478 | if isinstance(value, dict): 479 | for n, v in value.items(): 480 | final_results[f"{key}_{n}"] = v 481 | else: 482 | final_results[key] = value 483 | return final_results 484 | else: 485 | return { 486 | "precision": results["overall_precision"], 487 | "recall": results["overall_recall"], 488 | "f1": results["overall_f1"], 489 | "accuracy": results["overall_accuracy"], 490 | } 491 | 492 | # Initialize our Trainer 493 | trainer = Trainer( 494 | model=model, 495 | args=training_args, 496 | train_dataset=train_dataset if training_args.do_train else None, 497 | eval_dataset=eval_dataset if training_args.do_eval else None, 498 | tokenizer=tokenizer, 499 | data_collator=data_collator, 500 | compute_metrics=compute_metrics, 501 | ) 502 | 503 | # Training 504 | if training_args.do_train: 505 | checkpoint = last_checkpoint if last_checkpoint else None 506 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 507 | metrics = train_result.metrics 508 | trainer.save_model() # Saves the tokenizer too for easy upload 509 | 510 | max_train_samples = ( 511 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 512 | ) 513 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 514 | 515 | trainer.log_metrics("train", metrics) 516 | trainer.save_metrics("train", metrics) 517 | trainer.save_state() 518 | 519 | # Evaluation 520 | if training_args.do_eval: 521 | logger.info("*** Evaluate ***") 522 | 523 | metrics = trainer.evaluate() 524 | 525 | max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset) 526 | metrics["eval_samples"] = min(max_val_samples, len(eval_dataset)) 527 | 528 | trainer.log_metrics("eval", metrics) 529 | trainer.save_metrics("eval", metrics) 530 | 531 | # Predict 532 | if training_args.do_predict: 533 | logger.info("*** Predict ***") 534 | 535 | predictions, labels, metrics = trainer.predict(test_dataset) 536 | predictions = np.argmax(predictions, axis=2) 537 | 538 | # Remove ignored index (special tokens) 539 | true_predictions = [ 540 | [label_list[p] for (p, l) in zip(prediction, label) if l != -100] 541 | for prediction, label in zip(predictions, labels) 542 | ] 543 | 544 | trainer.log_metrics("test", metrics) 545 | trainer.save_metrics("test", metrics) 546 | 547 | # Save predictions 548 | output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt") 549 | if trainer.is_world_process_zero(): 550 | with open(output_test_predictions_file, "w") as writer: 551 | for prediction in true_predictions: 552 | writer.write(" ".join(prediction) + "\n") 553 | 554 | 555 | def _mp_fn(index): 556 | # For xla_spawn (TPUs) 557 | main() 558 | 559 | 560 | if __name__ == "__main__": 561 | main() 562 
| -------------------------------------------------------------------------------- /metrics/seqeval/seqeval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Datasets Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ seqeval metric. """ 15 | 16 | import importlib 17 | from typing import List, Optional, Union 18 | 19 | from seqeval.metrics import accuracy_score, classification_report 20 | 21 | import datasets 22 | 23 | 24 | _CITATION = """\ 25 | @inproceedings{ramshaw-marcus-1995-text, 26 | title = "Text Chunking using Transformation-Based Learning", 27 | author = "Ramshaw, Lance and 28 | Marcus, Mitch", 29 | booktitle = "Third Workshop on Very Large Corpora", 30 | year = "1995", 31 | url = "https://www.aclweb.org/anthology/W95-0107", 32 | } 33 | @misc{seqeval, 34 | title={{seqeval}: A Python framework for sequence labeling evaluation}, 35 | url={https://github.com/chakki-works/seqeval}, 36 | note={Software available from https://github.com/chakki-works/seqeval}, 37 | author={Hiroki Nakayama}, 38 | year={2018}, 39 | } 40 | """ 41 | 42 | _DESCRIPTION = """\ 43 | seqeval is a Python framework for sequence labeling evaluation. 44 | seqeval can evaluate the performance of chunking tasks such as named-entity recognition, part-of-speech tagging, semantic role labeling and so on. 45 | 46 | This is well-tested by using the Perl script conlleval, which can be used for 47 | measuring the performance of a system that has processed the CoNLL-2000 shared task data. 48 | 49 | seqeval supports following formats: 50 | IOB1 51 | IOB2 52 | IOE1 53 | IOE2 54 | IOBES 55 | 56 | See the [README.md] file at https://github.com/chakki-works/seqeval for more information. 57 | """ 58 | 59 | _KWARGS_DESCRIPTION = """ 60 | Produces labelling scores along with its sufficient statistics 61 | from a source against one or more references. 62 | 63 | Args: 64 | predictions: List of List of predicted labels (Estimated targets as returned by a tagger) 65 | references: List of List of reference labels (Ground truth (correct) target values) 66 | suffix: True if the IOB prefix is after type, False otherwise. default: False 67 | scheme: Specify target tagging scheme. Should be one of ["IOB1", "IOB2", "IOE1", "IOE2", "IOBES", "BILOU"]. 68 | default: None 69 | mode: Whether to count correct entity labels with incorrect I/B tags as true positives or not. 70 | If you want to only count exact matches, pass mode="strict". default: None. 71 | sample_weight: Array-like of shape (n_samples,), weights for individual samples. default: None 72 | zero_division: Which value to substitute as a metric value when encountering zero division. Should be on of 0, 1, 73 | "warn". "warn" acts as 0, but the warning is raised. 74 | 75 | Returns: 76 | 'scores': dict. 
Summary of the scores for overall and per type 77 | Overall: 78 | 'accuracy': accuracy, 79 | 'precision': precision, 80 | 'recall': recall, 81 | 'f1': F1 score, also known as balanced F-score or F-measure, 82 | Per type: 83 | 'precision': precision, 84 | 'recall': recall, 85 | 'f1': F1 score, also known as balanced F-score or F-measure 86 | Examples: 87 | 88 | >>> predictions = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']] 89 | >>> references = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']] 90 | >>> seqeval = datasets.load_metric("seqeval") 91 | >>> results = seqeval.compute(predictions=predictions, references=references) 92 | >>> print(list(results.keys())) 93 | ['MISC', 'PER', 'overall_precision', 'overall_recall', 'overall_f1', 'overall_accuracy'] 94 | >>> print(results["overall_f1"]) 95 | 0.5 96 | >>> print(results["PER"]["f1"]) 97 | 1.0 98 | """ 99 | 100 | 101 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 102 | class Seqeval(datasets.Metric): 103 | def _info(self): 104 | return datasets.MetricInfo( 105 | description=_DESCRIPTION, 106 | citation=_CITATION, 107 | homepage="https://github.com/chakki-works/seqeval", 108 | inputs_description=_KWARGS_DESCRIPTION, 109 | features=datasets.Features( 110 | { 111 | "predictions": datasets.Sequence(datasets.Value("string", id="label"), id="sequence"), 112 | "references": datasets.Sequence(datasets.Value("string", id="label"), id="sequence"), 113 | } 114 | ), 115 | codebase_urls=["https://github.com/chakki-works/seqeval"], 116 | reference_urls=["https://github.com/chakki-works/seqeval"], 117 | ) 118 | 119 | def _compute( 120 | self, 121 | predictions, 122 | references, 123 | suffix: bool = False, 124 | scheme: Optional[str] = None, 125 | mode: Optional[str] = None, 126 | sample_weight: Optional[List[int]] = None, 127 | zero_division: Union[str, int] = "warn", 128 | ): 129 | if scheme is not None: 130 | try: 131 | scheme_module = importlib.import_module("seqeval.scheme") 132 | scheme = getattr(scheme_module, scheme) 133 | except AttributeError: 134 | raise ValueError(f"Scheme should be one of [IOB1, IOB2, IOE1, IOE2, IOBES, BILOU], got {scheme}") 135 | report = classification_report( 136 | y_true=references, 137 | y_pred=predictions, 138 | suffix=suffix, 139 | output_dict=True, 140 | scheme=scheme, 141 | mode=mode, 142 | sample_weight=sample_weight, 143 | zero_division=zero_division, 144 | ) 145 | report.pop("macro avg") 146 | report.pop("weighted avg") 147 | overall_score = report.pop("micro avg") 148 | 149 | scores = { 150 | type_name: { 151 | "precision": score["precision"], 152 | "recall": score["recall"], 153 | "f1": score["f1-score"], 154 | "number": score["support"], 155 | } 156 | for type_name, score in report.items() 157 | } 158 | scores["overall_precision"] = overall_score["precision"] 159 | scores["overall_recall"] = overall_score["recall"] 160 | scores["overall_f1"] = overall_score["f1-score"] 161 | scores["overall_accuracy"] = accuracy_score(y_true=references, y_pred=predictions) 162 | 163 | return scores -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /model/basemodel.py: -------------------------------------------------------------------------------- 1 | """The base Model LayoutLMv3. 
""" 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.checkpoint 8 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 9 | 10 | from transformers import apply_chunking_to_forward 11 | from transformers.modeling_outputs import ( 12 | BaseModelOutputWithPastAndCrossAttentions, 13 | BaseModelOutputWithPoolingAndCrossAttentions, 14 | ) 15 | from transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer 16 | from transformers.models.roberta.modeling_roberta import ( 17 | RobertaIntermediate, 18 | RobertaOutput, 19 | RobertaSelfOutput, 20 | ) 21 | from transformers.utils import logging 22 | 23 | from transformers.models.layoutlmv3.configuration_layoutlmv3 import LayoutLMv3Config 24 | from timm.models.layers import to_2tuple 25 | 26 | 27 | logger = logging.get_logger(__name__) 28 | 29 | 30 | class PatchEmbed(nn.Module): 31 | """ Image to Patch Embedding 32 | """ 33 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 34 | super().__init__() 35 | img_size = to_2tuple(img_size) 36 | patch_size = to_2tuple(patch_size) 37 | self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 38 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 39 | # The following variables are used in detection mycheckpointer.py 40 | self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 41 | self.num_patches_w = self.patch_shape[0] 42 | self.num_patches_h = self.patch_shape[1] 43 | 44 | def forward(self, x, position_embedding=None): 45 | x = self.proj(x) 46 | 47 | if position_embedding is not None: 48 | # interpolate the position embedding to the corresponding size 49 | position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1).permute(0, 3, 1, 2) 50 | Hp, Wp = x.shape[2], x.shape[3] 51 | position_embedding = F.interpolate(position_embedding, size=(Hp, Wp), mode='bicubic') 52 | x = x + position_embedding 53 | 54 | x = x.flatten(2).transpose(1, 2) 55 | return x 56 | 57 | class LayoutLMv3Embeddings(nn.Module): 58 | """ 59 | Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. 
60 | """ 61 | 62 | # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__ 63 | def __init__(self, config): 64 | super().__init__() 65 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) 66 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 67 | 68 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 69 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 70 | 71 | # position_ids (1, len position emb) is contiguous in memory and exported when serialized 72 | self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) 73 | 74 | # End copy 75 | self.padding_idx = config.pad_token_id 76 | self.position_embeddings = nn.Embedding( 77 | config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx 78 | ) 79 | 80 | self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size) 81 | self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size) 82 | self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size) 83 | self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size) 84 | 85 | def _calc_spatial_position_embeddings(self, bbox): 86 | try: 87 | assert torch.all(0 <= bbox) and torch.all(bbox <= 1023) 88 | left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0]) 89 | upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1]) 90 | right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2]) 91 | lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3]) 92 | except IndexError as e: 93 | raise IndexError("The :obj:`bbox` coordinate values should be within 0-1000 range.") from e 94 | 95 | h_position_embeddings = self.h_position_embeddings(torch.clip(bbox[:, :, 3] - bbox[:, :, 1], 0, 1023)) 96 | w_position_embeddings = self.w_position_embeddings(torch.clip(bbox[:, :, 2] - bbox[:, :, 0], 0, 1023)) 97 | 98 | # below is the difference between LayoutLMEmbeddingsV2 (torch.cat) and LayoutLMEmbeddingsV1 (add) 99 | spatial_position_embeddings = torch.cat( 100 | [ 101 | left_position_embeddings, 102 | upper_position_embeddings, 103 | right_position_embeddings, 104 | lower_position_embeddings, 105 | h_position_embeddings, 106 | w_position_embeddings, 107 | ], 108 | dim=-1, 109 | ) 110 | return spatial_position_embeddings 111 | 112 | def create_position_ids_from_input_ids(self, input_ids, padding_idx, past_key_values_length=0): 113 | """ 114 | Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols 115 | are ignored. This is modified from fairseq's `utils.make_positions`. 116 | 117 | Args: 118 | x: torch.Tensor x: 119 | 120 | Returns: torch.Tensor 121 | """ 122 | # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. 
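        # Worked example with assumed values: padding_idx = 1 and input_ids = [[0, 9064, 2, 1, 1]]
        # give mask = [[1, 1, 1, 0, 0]] and a cumulative sum of [[1, 2, 3, 3, 3]]; multiplying by
        # the mask zeroes the padded slots and adding padding_idx yields [[2, 3, 4, 1, 1]], so real
        # tokens count up from padding_idx + 1 while padding positions keep padding_idx.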
123 | mask = input_ids.ne(padding_idx).int() 124 | incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask 125 | return incremental_indices.long() + padding_idx 126 | 127 | def forward( 128 | self, 129 | input_ids=None, 130 | bbox=None, 131 | token_type_ids=None, 132 | position_ids=None, 133 | inputs_embeds=None, 134 | past_key_values_length=0, 135 | ): 136 | if position_ids is None: 137 | if input_ids is not None: 138 | # Create the position ids from the input token ids. Any padded tokens remain padded. 139 | position_ids = self.create_position_ids_from_input_ids( 140 | input_ids, self.padding_idx, past_key_values_length).to(input_ids.device) 141 | else: 142 | position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) 143 | 144 | if input_ids is not None: 145 | input_shape = input_ids.size() 146 | else: 147 | input_shape = inputs_embeds.size()[:-1] 148 | 149 | if token_type_ids is None: 150 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) 151 | 152 | if inputs_embeds is None: 153 | inputs_embeds = self.word_embeddings(input_ids) 154 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 155 | 156 | embeddings = inputs_embeds + token_type_embeddings 157 | position_embeddings = self.position_embeddings(position_ids) 158 | embeddings += position_embeddings 159 | 160 | spatial_position_embeddings = self._calc_spatial_position_embeddings(bbox) 161 | 162 | embeddings = embeddings + spatial_position_embeddings 163 | 164 | embeddings = self.LayerNorm(embeddings) 165 | embeddings = self.dropout(embeddings) 166 | return embeddings 167 | 168 | def create_position_ids_from_inputs_embeds(self, inputs_embeds): 169 | """ 170 | We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. 171 | 172 | Args: 173 | inputs_embeds: torch.Tensor≈ 174 | 175 | Returns: torch.Tensor 176 | """ 177 | input_shape = inputs_embeds.size()[:-1] 178 | sequence_length = input_shape[1] 179 | 180 | position_ids = torch.arange( 181 | self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device 182 | ) 183 | return position_ids.unsqueeze(0).expand(input_shape) 184 | 185 | 186 | class LayoutLMv3PreTrainedModel(PreTrainedModel): 187 | """ 188 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 189 | models. 
190 | """ 191 | 192 | config_class = LayoutLMv3Config 193 | base_model_prefix = "layoutlmv3" 194 | 195 | # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights 196 | def _init_weights(self, module): 197 | """Initialize the weights""" 198 | if isinstance(module, nn.Linear): 199 | # Slightly different from the TF version which uses truncated_normal for initialization 200 | # cf https://github.com/pytorch/pytorch/pull/5617 201 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 202 | if module.bias is not None: 203 | module.bias.data.zero_() 204 | elif isinstance(module, nn.Embedding): 205 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 206 | if module.padding_idx is not None: 207 | module.weight.data[module.padding_idx].zero_() 208 | elif isinstance(module, nn.LayerNorm): 209 | module.bias.data.zero_() 210 | module.weight.data.fill_(1.0) 211 | 212 | 213 | class LayoutLMv3SelfAttention(nn.Module): 214 | def __init__(self, config): 215 | super().__init__() 216 | if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): 217 | raise ValueError( 218 | f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " 219 | f"heads ({config.num_attention_heads})" 220 | ) 221 | 222 | self.num_attention_heads = config.num_attention_heads 223 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 224 | self.all_head_size = self.num_attention_heads * self.attention_head_size 225 | 226 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 227 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 228 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 229 | 230 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 231 | self.has_relative_attention_bias = config.has_relative_attention_bias 232 | self.has_spatial_attention_bias = config.has_spatial_attention_bias 233 | 234 | def transpose_for_scores(self, x): 235 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 236 | x = x.view(*new_x_shape) 237 | return x.permute(0, 2, 1, 3) 238 | 239 | def cogview_attn(self, attention_scores, alpha=32): 240 | ''' 241 | https://arxiv.org/pdf/2105.13290.pdf 242 | Section 2.4 Stabilization of training: Precision Bottleneck Relaxation (PB-Relax). 243 | A replacement of the original nn.Softmax(dim=-1)(attention_scores) 244 | Seems the new attention_probs will result in a slower speed and a little bias 245 | Can use torch.allclose(standard_attention_probs, cogview_attention_probs, atol=1e-08) for comparison 246 | The smaller atol (e.g., 1e-08), the better. 
247 | ''' 248 | scaled_attention_scores = attention_scores / alpha 249 | max_value = scaled_attention_scores.amax(dim=(-1)).unsqueeze(-1) 250 | # max_value = scaled_attention_scores.amax(dim=(-2, -1)).unsqueeze(-1).unsqueeze(-1) 251 | new_attention_scores = (scaled_attention_scores - max_value) * alpha 252 | return nn.Softmax(dim=-1)(new_attention_scores) 253 | 254 | def forward( 255 | self, 256 | hidden_states, 257 | attention_mask=None, 258 | head_mask=None, 259 | encoder_hidden_states=None, 260 | encoder_attention_mask=None, 261 | past_key_value=None, 262 | output_attentions=False, 263 | rel_pos=None, 264 | rel_2d_pos=None, 265 | ): 266 | mixed_query_layer = self.query(hidden_states) 267 | 268 | # If this is instantiated as a cross-attention module, the keys 269 | # and values come from an encoder; the attention mask needs to be 270 | # such that the encoder's padding tokens are not attended to. 271 | is_cross_attention = encoder_hidden_states is not None 272 | 273 | if is_cross_attention and past_key_value is not None: 274 | # reuse k,v, cross_attentions 275 | key_layer = past_key_value[0] 276 | value_layer = past_key_value[1] 277 | attention_mask = encoder_attention_mask 278 | elif is_cross_attention: 279 | key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) 280 | value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) 281 | attention_mask = encoder_attention_mask 282 | elif past_key_value is not None: 283 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 284 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 285 | key_layer = torch.cat([past_key_value[0], key_layer], dim=2) 286 | value_layer = torch.cat([past_key_value[1], value_layer], dim=2) 287 | else: 288 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 289 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 290 | 291 | query_layer = self.transpose_for_scores(mixed_query_layer) 292 | 293 | # Take the dot product between "query" and "key" to get the raw attention scores. 294 | # The attention scores QT K/√d could be significantly larger than input elements, and result in overflow. 295 | # Changing the computational order into QT(K/√d) alleviates the problem. (https://arxiv.org/pdf/2105.13290.pdf) 296 | attention_scores = torch.matmul(query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2)) 297 | 298 | if self.has_relative_attention_bias and self.has_spatial_attention_bias: 299 | attention_scores += (rel_pos + rel_2d_pos) / math.sqrt(self.attention_head_size) 300 | elif self.has_relative_attention_bias: 301 | attention_scores += rel_pos / math.sqrt(self.attention_head_size) 302 | 303 | # if self.has_relative_attention_bias: 304 | # attention_scores += rel_pos 305 | # if self.has_spatial_attention_bias: 306 | # attention_scores += rel_2d_pos 307 | 308 | # attention_scores = attention_scores / math.sqrt(self.attention_head_size) 309 | if attention_mask is not None: 310 | # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function) 311 | attention_scores = attention_scores + attention_mask 312 | 313 | # Normalize the attention scores to probabilities. 
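# Note that cogview_attn below is mathematically equivalent to a plain softmax: subtracting the
# per-row maximum of (scores / alpha) and multiplying back by alpha only shifts each row by a
# constant, which softmax ignores, but it keeps the intermediate values small enough to avoid
# fp16 overflow.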
314 | # attention_probs = nn.Softmax(dim=-1)(attention_scores) # comment the line below and use this line for speedup 315 | attention_probs = self.cogview_attn(attention_scores) # to stablize training 316 | # assert torch.allclose(attention_probs, nn.Softmax(dim=-1)(attention_scores), atol=1e-8) 317 | 318 | # This is actually dropping out entire tokens to attend to, which might 319 | # seem a bit unusual, but is taken from the original Transformer paper. 320 | attention_probs = self.dropout(attention_probs) 321 | 322 | # Mask heads if we want to 323 | if head_mask is not None: 324 | attention_probs = attention_probs * head_mask 325 | 326 | context_layer = torch.matmul(attention_probs, value_layer) 327 | 328 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 329 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 330 | context_layer = context_layer.view(*new_context_layer_shape) 331 | 332 | outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) 333 | 334 | return outputs 335 | 336 | 337 | class LayoutLMv3Attention(nn.Module): 338 | def __init__(self, config): 339 | super().__init__() 340 | self.self = LayoutLMv3SelfAttention(config) 341 | self.output = RobertaSelfOutput(config) 342 | self.pruned_heads = set() 343 | 344 | def prune_heads(self, heads): 345 | if len(heads) == 0: 346 | return 347 | heads, index = find_pruneable_heads_and_indices( 348 | heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads 349 | ) 350 | 351 | # Prune linear layers 352 | self.self.query = prune_linear_layer(self.self.query, index) 353 | self.self.key = prune_linear_layer(self.self.key, index) 354 | self.self.value = prune_linear_layer(self.self.value, index) 355 | self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) 356 | 357 | # Update hyper params and store pruned heads 358 | self.self.num_attention_heads = self.self.num_attention_heads - len(heads) 359 | self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads 360 | self.pruned_heads = self.pruned_heads.union(heads) 361 | 362 | def forward( 363 | self, 364 | hidden_states, 365 | attention_mask=None, 366 | head_mask=None, 367 | encoder_hidden_states=None, 368 | encoder_attention_mask=None, 369 | past_key_value=None, 370 | output_attentions=False, 371 | rel_pos=None, 372 | rel_2d_pos=None, 373 | ): 374 | self_outputs = self.self( 375 | hidden_states, 376 | attention_mask, 377 | head_mask, 378 | encoder_hidden_states, 379 | encoder_attention_mask, 380 | past_key_value, 381 | output_attentions, 382 | rel_pos=rel_pos, 383 | rel_2d_pos=rel_2d_pos, 384 | ) 385 | attention_output = self.output(self_outputs[0], hidden_states) 386 | outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them 387 | return outputs 388 | 389 | 390 | class LayoutLMv3Layer(nn.Module): 391 | def __init__(self, config): 392 | super().__init__() 393 | self.chunk_size_feed_forward = config.chunk_size_feed_forward 394 | self.seq_len_dim = 1 395 | self.attention = LayoutLMv3Attention(config) 396 | assert not config.is_decoder and not config.add_cross_attention, \ 397 | "This version do not support decoder. Please refer to RoBERTa for implementation of is_decoder." 
398 | self.intermediate = RobertaIntermediate(config) 399 | self.output = RobertaOutput(config) 400 | 401 | def forward( 402 | self, 403 | hidden_states, 404 | attention_mask=None, 405 | head_mask=None, 406 | encoder_hidden_states=None, 407 | encoder_attention_mask=None, 408 | past_key_value=None, 409 | output_attentions=False, 410 | rel_pos=None, 411 | rel_2d_pos=None, 412 | ): 413 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 414 | self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None 415 | self_attention_outputs = self.attention( 416 | hidden_states, 417 | attention_mask, 418 | head_mask, 419 | output_attentions=output_attentions, 420 | past_key_value=self_attn_past_key_value, 421 | rel_pos=rel_pos, 422 | rel_2d_pos=rel_2d_pos, 423 | ) 424 | attention_output = self_attention_outputs[0] 425 | 426 | outputs = self_attention_outputs[1:] # add self attentions if we output attention weights 427 | 428 | layer_output = apply_chunking_to_forward( 429 | self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output 430 | ) 431 | outputs = (layer_output,) + outputs 432 | 433 | return outputs 434 | 435 | def feed_forward_chunk(self, attention_output): 436 | intermediate_output = self.intermediate(attention_output) 437 | layer_output = self.output(intermediate_output, attention_output) 438 | return layer_output 439 | 440 | 441 | class LayoutLMv3Encoder(nn.Module): 442 | def __init__(self, config, detection=False, out_features=None): 443 | super().__init__() 444 | self.config = config 445 | self.detection = detection 446 | self.layer = nn.ModuleList([LayoutLMv3Layer(config) for _ in range(config.num_hidden_layers)]) 447 | self.gradient_checkpointing = False 448 | 449 | self.has_relative_attention_bias = config.has_relative_attention_bias 450 | self.has_spatial_attention_bias = config.has_spatial_attention_bias 451 | 452 | if self.has_relative_attention_bias: 453 | self.rel_pos_bins = config.rel_pos_bins 454 | self.max_rel_pos = config.max_rel_pos 455 | self.rel_pos_onehot_size = config.rel_pos_bins 456 | self.rel_pos_bias = nn.Linear(self.rel_pos_onehot_size, config.num_attention_heads, bias=False) 457 | 458 | if self.has_spatial_attention_bias: 459 | self.max_rel_2d_pos = config.max_rel_2d_pos 460 | self.rel_2d_pos_bins = config.rel_2d_pos_bins 461 | self.rel_2d_pos_onehot_size = config.rel_2d_pos_bins 462 | self.rel_pos_x_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False) 463 | self.rel_pos_y_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False) 464 | 465 | if self.detection: 466 | self.gradient_checkpointing = True 467 | embed_dim = self.config.hidden_size 468 | self.out_features = out_features 469 | self.out_indices = [int(name[5:]) for name in out_features] 470 | self.fpn1 = nn.Sequential( 471 | nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), 472 | # nn.SyncBatchNorm(embed_dim), 473 | nn.BatchNorm2d(embed_dim), 474 | nn.GELU(), 475 | nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), 476 | ) 477 | 478 | self.fpn2 = nn.Sequential( 479 | nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), 480 | ) 481 | 482 | self.fpn3 = nn.Identity() 483 | 484 | self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2) 485 | self.ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4] 486 | 487 | def relative_position_bucket(self, relative_position, bidirectional=True, num_buckets=32, max_distance=128): 488 | ret 
= 0 489 | if bidirectional: 490 | num_buckets //= 2 491 | ret += (relative_position > 0).long() * num_buckets 492 | n = torch.abs(relative_position) 493 | else: 494 | n = torch.max(-relative_position, torch.zeros_like(relative_position)) 495 | # now n is in the range [0, inf) 496 | 497 | # half of the buckets are for exact increments in positions 498 | max_exact = num_buckets // 2 499 | is_small = n < max_exact 500 | 501 | # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance 502 | val_if_large = max_exact + ( 503 | torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) 504 | ).to(torch.long) 505 | val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) 506 | 507 | ret += torch.where(is_small, n, val_if_large) 508 | return ret 509 | 510 | def _cal_1d_pos_emb(self, hidden_states, position_ids, valid_span): 511 | VISUAL_NUM = 196 + 1 512 | 513 | rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1) 514 | 515 | if valid_span is not None: 516 | # for the text part, if two words are not in the same line, 517 | # set their distance to the max value (position_ids.shape[-1]) 518 | rel_pos_mat[(rel_pos_mat > 0) & (valid_span == False)] = position_ids.shape[1] 519 | rel_pos_mat[(rel_pos_mat < 0) & (valid_span == False)] = -position_ids.shape[1] 520 | 521 | # image-text, minimum distance 522 | rel_pos_mat[:, -VISUAL_NUM:, :-VISUAL_NUM] = 0 523 | rel_pos_mat[:, :-VISUAL_NUM, -VISUAL_NUM:] = 0 524 | 525 | rel_pos = self.relative_position_bucket( 526 | rel_pos_mat, 527 | num_buckets=self.rel_pos_bins, 528 | max_distance=self.max_rel_pos, 529 | ) 530 | rel_pos = F.one_hot(rel_pos, num_classes=self.rel_pos_onehot_size).type_as(hidden_states) 531 | rel_pos = self.rel_pos_bias(rel_pos).permute(0, 3, 1, 2) 532 | rel_pos = rel_pos.contiguous() 533 | return rel_pos 534 | 535 | def _cal_2d_pos_emb(self, hidden_states, bbox): 536 | position_coord_x = bbox[:, :, 0] 537 | position_coord_y = bbox[:, :, 3] 538 | rel_pos_x_2d_mat = position_coord_x.unsqueeze(-2) - position_coord_x.unsqueeze(-1) 539 | rel_pos_y_2d_mat = position_coord_y.unsqueeze(-2) - position_coord_y.unsqueeze(-1) 540 | rel_pos_x = self.relative_position_bucket( 541 | rel_pos_x_2d_mat, 542 | num_buckets=self.rel_2d_pos_bins, 543 | max_distance=self.max_rel_2d_pos, 544 | ) 545 | rel_pos_y = self.relative_position_bucket( 546 | rel_pos_y_2d_mat, 547 | num_buckets=self.rel_2d_pos_bins, 548 | max_distance=self.max_rel_2d_pos, 549 | ) 550 | rel_pos_x = F.one_hot(rel_pos_x, num_classes=self.rel_2d_pos_onehot_size).type_as(hidden_states) 551 | rel_pos_y = F.one_hot(rel_pos_y, num_classes=self.rel_2d_pos_onehot_size).type_as(hidden_states) 552 | rel_pos_x = self.rel_pos_x_bias(rel_pos_x).permute(0, 3, 1, 2) 553 | rel_pos_y = self.rel_pos_y_bias(rel_pos_y).permute(0, 3, 1, 2) 554 | rel_pos_x = rel_pos_x.contiguous() 555 | rel_pos_y = rel_pos_y.contiguous() 556 | rel_2d_pos = rel_pos_x + rel_pos_y 557 | return rel_2d_pos 558 | 559 | def forward( 560 | self, 561 | hidden_states, 562 | bbox=None, 563 | attention_mask=None, 564 | head_mask=None, 565 | encoder_hidden_states=None, 566 | encoder_attention_mask=None, 567 | past_key_values=None, 568 | use_cache=None, 569 | output_attentions=False, 570 | output_hidden_states=False, 571 | return_dict=True, 572 | position_ids=None, 573 | Hp=None, 574 | Wp=None, 575 | valid_span=None, 576 | ): 577 | all_hidden_states = () if output_hidden_states else None 578 | all_self_attentions = () 
if output_attentions else None 579 | all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None 580 | 581 | next_decoder_cache = () if use_cache else None 582 | 583 | rel_pos = self._cal_1d_pos_emb(hidden_states, position_ids, valid_span) if self.has_relative_attention_bias else None 584 | rel_2d_pos = self._cal_2d_pos_emb(hidden_states, bbox) if self.has_spatial_attention_bias else None 585 | 586 | if self.detection: 587 | feat_out = {} 588 | j = 0 589 | 590 | for i, layer_module in enumerate(self.layer): 591 | if output_hidden_states: 592 | all_hidden_states = all_hidden_states + (hidden_states,) 593 | 594 | layer_head_mask = head_mask[i] if head_mask is not None else None 595 | past_key_value = past_key_values[i] if past_key_values is not None else None 596 | 597 | if self.gradient_checkpointing and self.training: 598 | 599 | if use_cache: 600 | logger.warning( 601 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 602 | ) 603 | use_cache = False 604 | 605 | def create_custom_forward(module): 606 | def custom_forward(*inputs): 607 | return module(*inputs) 608 | # return module(*inputs, past_key_value, output_attentions, rel_pos, rel_2d_pos) 609 | # The above line will cause error: 610 | # RuntimeError: Trying to backward through the graph a second time 611 | # (or directly access saved tensors after they have already been freed). 612 | return custom_forward 613 | 614 | layer_outputs = torch.utils.checkpoint.checkpoint( 615 | create_custom_forward(layer_module), 616 | hidden_states, 617 | attention_mask, 618 | layer_head_mask, 619 | encoder_hidden_states, 620 | encoder_attention_mask, 621 | past_key_value, 622 | output_attentions, 623 | rel_pos, 624 | rel_2d_pos 625 | ) 626 | else: 627 | layer_outputs = layer_module( 628 | hidden_states, 629 | attention_mask, 630 | layer_head_mask, 631 | encoder_hidden_states, 632 | encoder_attention_mask, 633 | past_key_value, 634 | output_attentions, 635 | rel_pos=rel_pos, 636 | rel_2d_pos=rel_2d_pos, 637 | ) 638 | 639 | hidden_states = layer_outputs[0] 640 | if use_cache: 641 | next_decoder_cache += (layer_outputs[-1],) 642 | if output_attentions: 643 | all_self_attentions = all_self_attentions + (layer_outputs[1],) 644 | if self.config.add_cross_attention: 645 | all_cross_attentions = all_cross_attentions + (layer_outputs[2],) 646 | 647 | if self.detection and i in self.out_indices: 648 | xp = hidden_states[:, -Hp*Wp:, :].permute(0, 2, 1).reshape(len(hidden_states), -1, Hp, Wp) 649 | feat_out[self.out_features[j]] = self.ops[j](xp.contiguous()) 650 | j += 1 651 | 652 | if self.detection: 653 | return feat_out 654 | 655 | if output_hidden_states: 656 | all_hidden_states = all_hidden_states + (hidden_states,) 657 | 658 | if not return_dict: 659 | return tuple( 660 | v 661 | for v in [ 662 | hidden_states, 663 | next_decoder_cache, 664 | all_hidden_states, 665 | all_self_attentions, 666 | all_cross_attentions, 667 | ] 668 | if v is not None 669 | ) 670 | return BaseModelOutputWithPastAndCrossAttentions( 671 | last_hidden_state=hidden_states, 672 | past_key_values=next_decoder_cache, 673 | hidden_states=all_hidden_states, 674 | attentions=all_self_attentions, 675 | cross_attentions=all_cross_attentions, 676 | ) 677 | 678 | 679 | class LayoutLMv3Model(LayoutLMv3PreTrainedModel): 680 | """ 681 | """ 682 | 683 | _keys_to_ignore_on_load_missing = [r"position_ids"] 684 | 685 | # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta 686 | def 
__init__(self, config, detection=False, out_features=None, image_only=False): 687 | super().__init__(config) 688 | self.config = config 689 | assert not config.is_decoder and not config.add_cross_attention, \ 690 | "This version do not support decoder. Please refer to RoBERTa for implementation of is_decoder." 691 | self.detection = detection 692 | if not self.detection: 693 | self.image_only = False 694 | else: 695 | assert config.visual_embed 696 | self.image_only = image_only 697 | 698 | if not self.image_only: 699 | self.embeddings = LayoutLMv3Embeddings(config) 700 | self.encoder = LayoutLMv3Encoder(config, detection=detection, out_features=out_features) 701 | 702 | if config.visual_embed: 703 | embed_dim = self.config.hidden_size 704 | # use the default pre-training parameters for fine-tuning (e.g., input_size) 705 | # when the input_size is larger in fine-tuning, we will interpolate the position embedding in forward 706 | self.patch_embed = PatchEmbed(embed_dim=embed_dim) 707 | 708 | patch_size = 16 709 | size = int(self.config.input_size / patch_size) 710 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 711 | self.pos_embed = nn.Parameter(torch.zeros(1, size * size + 1, embed_dim)) 712 | self.pos_drop = nn.Dropout(p=0.) 713 | 714 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 715 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 716 | 717 | if self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias: 718 | self._init_visual_bbox(img_size=(size, size)) 719 | 720 | from functools import partial 721 | norm_layer = partial(nn.LayerNorm, eps=1e-6) 722 | self.norm = norm_layer(embed_dim) 723 | 724 | self.init_weights() 725 | 726 | def get_input_embeddings(self): 727 | return self.embeddings.word_embeddings 728 | 729 | def set_input_embeddings(self, value): 730 | self.embeddings.word_embeddings = value 731 | 732 | def _prune_heads(self, heads_to_prune): 733 | """ 734 | Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base 735 | class PreTrainedModel 736 | """ 737 | for layer, heads in heads_to_prune.items(): 738 | self.encoder.layer[layer].attention.prune_heads(heads) 739 | 740 | def _init_visual_bbox(self, img_size=(14, 14), max_len=1000): 741 | visual_bbox_x = torch.div(torch.arange(0, max_len * (img_size[1] + 1), max_len), 742 | img_size[1], rounding_mode='trunc') 743 | visual_bbox_y = torch.div(torch.arange(0, max_len * (img_size[0] + 1), max_len), 744 | img_size[0], rounding_mode='trunc') 745 | visual_bbox = torch.stack( 746 | [ 747 | visual_bbox_x[:-1].repeat(img_size[0], 1), 748 | visual_bbox_y[:-1].repeat(img_size[1], 1).transpose(0, 1), 749 | visual_bbox_x[1:].repeat(img_size[0], 1), 750 | visual_bbox_y[1:].repeat(img_size[1], 1).transpose(0, 1), 751 | ], 752 | dim=-1, 753 | ).view(-1, 4) 754 | 755 | cls_token_box = torch.tensor([[0 + 1, 0 + 1, max_len - 1, max_len - 1]]) 756 | self.visual_bbox = torch.cat([cls_token_box, visual_bbox], dim=0) 757 | 758 | def _calc_visual_bbox(self, device, dtype, bsz): # , img_size=(14, 14), max_len=1000): 759 | visual_bbox = self.visual_bbox.repeat(bsz, 1, 1) 760 | visual_bbox = visual_bbox.to(device).type(dtype) 761 | return visual_bbox 762 | 763 | def forward_image(self, x): 764 | if self.detection: 765 | x = self.patch_embed(x, self.pos_embed[:, 1:, :] if self.pos_embed is not None else None) 766 | else: 767 | x = self.patch_embed(x) 768 | batch_size, seq_len, _ = x.size() 769 | 770 | cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 771 | if self.pos_embed is not None and self.detection: 772 | cls_tokens = cls_tokens + self.pos_embed[:, :1, :] 773 | 774 | x = torch.cat((cls_tokens, x), dim=1) 775 | if self.pos_embed is not None and not self.detection: 776 | x = x + self.pos_embed 777 | x = self.pos_drop(x) 778 | 779 | x = self.norm(x) 780 | return x 781 | 782 | # Copied from transformers.models.bert.modeling_bert.BertModel.forward 783 | def forward( 784 | self, 785 | input_ids=None, 786 | bbox=None, 787 | attention_mask=None, 788 | token_type_ids=None, 789 | valid_span=None, 790 | position_ids=None, 791 | head_mask=None, 792 | inputs_embeds=None, 793 | encoder_hidden_states=None, 794 | encoder_attention_mask=None, 795 | past_key_values=None, 796 | use_cache=None, 797 | output_attentions=None, 798 | output_hidden_states=None, 799 | return_dict=None, 800 | images=None, 801 | ): 802 | r""" 803 | encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 804 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if 805 | the model is configured as a decoder. 806 | encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 807 | Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in 808 | the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: 809 | 810 | - 1 for tokens that are **not masked**, 811 | - 0 for tokens that are **masked**. 812 | past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 813 | Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 
814 | 815 | If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` 816 | (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` 817 | instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. 818 | use_cache (:obj:`bool`, `optional`): 819 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up 820 | decoding (see :obj:`past_key_values`). 821 | """ 822 | 823 | 824 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 825 | 826 | #print("output_attention"+output_attentions.shape) 827 | 828 | output_hidden_states = ( 829 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 830 | ) 831 | 832 | #print("output_hidden_states"+output_hidden_states) 833 | 834 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 835 | 836 | use_cache = False 837 | 838 | # if input_ids is not None and inputs_embeds is not None: 839 | # raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 840 | if input_ids is not None: 841 | input_shape = input_ids.size() 842 | batch_size, seq_length = input_shape 843 | device = input_ids.device 844 | elif inputs_embeds is not None: 845 | input_shape = inputs_embeds.size()[:-1] 846 | batch_size, seq_length = input_shape 847 | device = inputs_embeds.device 848 | elif images is not None: 849 | batch_size = len(images) 850 | device = images.device 851 | else: 852 | raise ValueError("You have to specify either input_ids or inputs_embeds or images") 853 | 854 | if not self.image_only: 855 | # past_key_values_length 856 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 857 | 858 | if attention_mask is None: 859 | attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) 860 | if token_type_ids is None: 861 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) 862 | 863 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] 864 | # ourselves in which case we just need to make it broadcastable to all heads. 
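# get_extended_attention_mask (called further below) turns a padding mask such as [[1, 1, 0]]
# into an additive bias of shape (batch_size, 1, 1, seq_len) that is 0 for visible positions and
# a large negative value for padded ones, so it can simply be added to the raw attention scores.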
865 | # extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) 866 | 867 | encoder_extended_attention_mask = None 868 | 869 | # Prepare head mask if needed 870 | # 1.0 in head_mask indicate we keep the head 871 | # attention_probs has shape bsz x n_heads x N x N 872 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] 873 | # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] 874 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) 875 | 876 | if not self.image_only: 877 | if bbox is None: 878 | bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device) 879 | 880 | embedding_output = self.embeddings( 881 | input_ids=input_ids, 882 | bbox=bbox, 883 | position_ids=position_ids, 884 | token_type_ids=token_type_ids, 885 | inputs_embeds=inputs_embeds, 886 | past_key_values_length=past_key_values_length, 887 | ) 888 | 889 | final_bbox = final_position_ids = None 890 | Hp = Wp = None 891 | if images is not None: 892 | patch_size = 16 893 | Hp, Wp = int(images.shape[2] / patch_size), int(images.shape[3] / patch_size) 894 | visual_emb = self.forward_image(images) 895 | if self.detection: 896 | visual_attention_mask = torch.ones((batch_size, visual_emb.shape[1]), dtype=torch.long, device=device) 897 | if self.image_only: 898 | attention_mask = visual_attention_mask 899 | else: 900 | attention_mask = torch.cat([attention_mask, visual_attention_mask], dim=1) 901 | elif self.image_only: 902 | attention_mask = torch.ones((batch_size, visual_emb.shape[1]), dtype=torch.long, device=device) 903 | 904 | if self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias: 905 | if self.config.has_spatial_attention_bias: 906 | visual_bbox = self._calc_visual_bbox(device, dtype=torch.long, bsz=batch_size) 907 | if self.image_only: 908 | final_bbox = visual_bbox 909 | else: 910 | final_bbox = torch.cat([bbox, visual_bbox], dim=1) 911 | 912 | visual_position_ids = torch.arange(0, visual_emb.shape[1], dtype=torch.long, device=device).repeat( 913 | batch_size, 1) 914 | if self.image_only: 915 | final_position_ids = visual_position_ids 916 | else: 917 | position_ids = torch.arange(0, input_shape[1], device=device).unsqueeze(0) 918 | position_ids = position_ids.expand_as(input_ids) 919 | final_position_ids = torch.cat([position_ids, visual_position_ids], dim=1) 920 | 921 | if self.image_only: 922 | embedding_output = visual_emb 923 | else: 924 | embedding_output = torch.cat([embedding_output, visual_emb], dim=1) 925 | embedding_output = self.LayerNorm(embedding_output) 926 | embedding_output = self.dropout(embedding_output) 927 | elif self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias: 928 | if self.config.has_spatial_attention_bias: 929 | final_bbox = bbox 930 | if self.config.has_relative_attention_bias: 931 | position_ids = self.embeddings.position_ids[:, :input_shape[1]] 932 | position_ids = position_ids.expand_as(input_ids) 933 | final_position_ids = position_ids 934 | 935 | extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, None, device) 936 | 937 | encoder_outputs = self.encoder( 938 | embedding_output, 939 | bbox=final_bbox, 940 | position_ids=final_position_ids, 941 | attention_mask=extended_attention_mask, 942 | head_mask=head_mask, 943 | encoder_hidden_states=encoder_hidden_states, 944 | encoder_attention_mask=encoder_extended_attention_mask, 945 
| past_key_values=past_key_values, 946 | use_cache=use_cache, 947 | output_attentions=output_attentions, 948 | output_hidden_states=output_hidden_states, 949 | return_dict=return_dict, 950 | Hp=Hp, 951 | Wp=Wp, 952 | valid_span=valid_span, 953 | ) 954 | 955 | if self.detection: 956 | return encoder_outputs 957 | 958 | sequence_output = encoder_outputs[0] 959 | pooled_output = None 960 | 961 | if not return_dict: 962 | return (sequence_output, pooled_output) + encoder_outputs[1:] 963 | 964 | return BaseModelOutputWithPoolingAndCrossAttentions( 965 | last_hidden_state=sequence_output, 966 | pooler_output=pooled_output, 967 | past_key_values=encoder_outputs.past_key_values, 968 | hidden_states=encoder_outputs.hidden_states, 969 | attentions=encoder_outputs.attentions, 970 | cross_attentions=encoder_outputs.cross_attentions, 971 | ) 972 | 973 | -------------------------------------------------------------------------------- /model/configuration_graphlayoutlm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from transformers.models.layoutlmv3.configuration_layoutlmv3 import LayoutLMv3Config 3 | 4 | 5 | class GraphLayoutLMConfig(LayoutLMv3Config): 6 | model_type = "graphlayoutlm" 7 | 8 | def __init__( 9 | self, 10 | pad_token_id=1, 11 | bos_token_id=0, 12 | eos_token_id=2, 13 | max_2d_position_embeddings=1024, 14 | coordinate_size=None, 15 | shape_size=None, 16 | has_relative_attention_bias=False, 17 | rel_pos_bins=32, 18 | max_rel_pos=128, 19 | has_spatial_attention_bias=False, 20 | rel_2d_pos_bins=64, 21 | max_rel_2d_pos=256, 22 | visual_embed=True, 23 | mim=False, 24 | wpa_task=False, 25 | discrete_vae_weight_path='', 26 | discrete_vae_type='dall-e', 27 | input_size=224, 28 | second_input_size=112, 29 | device='cuda', 30 | **kwargs 31 | ): 32 | super().__init__( 33 | pad_token_id=pad_token_id, 34 | bos_token_id=bos_token_id, 35 | eos_token_id=eos_token_id, 36 | max_2d_position_embeddings = max_2d_position_embeddings, 37 | coordinate_size = coordinate_size, 38 | shape_size = shape_size, 39 | has_relative_attention_bias = has_relative_attention_bias, 40 | rel_pos_bins = rel_pos_bins, 41 | max_rel_pos = max_rel_pos, 42 | has_spatial_attention_bias = has_spatial_attention_bias, 43 | rel_2d_pos_bins = rel_2d_pos_bins, 44 | max_rel_2d_pos = max_rel_2d_pos, 45 | visual_embed = visual_embed, 46 | mim = mim, 47 | wpa_task = wpa_task, 48 | discrete_vae_weight_path = discrete_vae_weight_path, 49 | discrete_vae_type = discrete_vae_type, 50 | input_size = input_size, 51 | second_input_size = second_input_size, 52 | device = device, 53 | **kwargs) 54 | -------------------------------------------------------------------------------- /model/graphlayoutlm.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import CrossEntropyLoss 5 | from transformers.modeling_utils import PreTrainedModel 6 | from transformers.modeling_outputs import ( 7 | BaseModelOutputWithPoolingAndCrossAttentions, 8 | TokenClassifierOutput, 9 | ) 10 | 11 | from model.basemodel import LayoutLMv3Model 12 | from model.configuration_graphlayoutlm import GraphLayoutLMConfig 13 | 14 | 15 | class GraphLayoutLMPreTrainedModel(PreTrainedModel): 16 | config_class = GraphLayoutLMConfig 17 | base_model_prefix = "graphlayoutlm" 18 | 19 | # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights 20 | def _init_weights(self, module): 21 | """Initialize the 
weights""" 22 | if isinstance(module, nn.Linear): 23 | # Slightly different from the TF version which uses truncated_normal for initialization 24 | # cf https://github.com/pytorch/pytorch/pull/5617 25 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 26 | if module.bias is not None: 27 | module.bias.data.zero_() 28 | elif isinstance(module, nn.Embedding): 29 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 30 | if module.padding_idx is not None: 31 | module.weight.data[module.padding_idx].zero_() 32 | elif isinstance(module, nn.LayerNorm): 33 | module.bias.data.zero_() 34 | module.weight.data.fill_(1.0) 35 | 36 | 37 | class GraphAttentionLayer(nn.Module): 38 | def __init__(self,config): 39 | super(GraphAttentionLayer, self).__init__() 40 | self.num_attention_heads = int(config.num_attention_heads/2) 41 | self.attention_head_size = int(config.hidden_size / self.num_attention_heads) 42 | self.all_head_size = self.num_attention_heads * self.attention_head_size 43 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 44 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 45 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 46 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 47 | self.final = nn.Linear(config.hidden_size, self.all_head_size) 48 | 49 | def transpose_for_scores(self, x): 50 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 51 | x = x.view(*new_x_shape) 52 | return x.permute(0, 2, 1, 3) 53 | 54 | def forward( 55 | self, 56 | seq_inputs, 57 | graph_mask, 58 | ): 59 | mixed_query_layer = self.query(seq_inputs) 60 | 61 | key_layer = self.transpose_for_scores(self.key(seq_inputs)) 62 | value_layer = self.transpose_for_scores(self.value(seq_inputs)) 63 | query_layer = self.transpose_for_scores(mixed_query_layer) 64 | 65 | 66 | attention_scores = torch.matmul(query_layer , key_layer.transpose(-1, -2))/ math.sqrt(self.attention_head_size) 67 | 68 | attention_scores = attention_scores+graph_mask.unsqueeze(1).repeat(1,self.num_attention_heads,1,1) 69 | 70 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 71 | 72 | attention_probs = self.dropout(attention_probs) 73 | 74 | context_layer = torch.matmul(attention_probs, value_layer) 75 | 76 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 77 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 78 | context_layer = context_layer.view(*new_context_layer_shape) 79 | 80 | outputs = self.final(context_layer) 81 | 82 | return outputs 83 | 84 | class SubLayerConnection(nn.Module): 85 | def __init__(self,config): 86 | super(SubLayerConnection,self).__init__() 87 | self.norm = nn.LayerNorm(config.hidden_size, eps=1e-05) 88 | self.dropout=nn.Dropout(p=config.hidden_dropout_prob) 89 | self.size=config.hidden_size 90 | 91 | def forward(self,x,graph_mask,sublayer): 92 | return x+self.dropout(sublayer(self.norm(x),graph_mask)) 93 | 94 | 95 | class GraphLayoutLM(GraphLayoutLMPreTrainedModel): 96 | def __init__(self,config): 97 | super().__init__(config) 98 | self.model_base=LayoutLMv3Model(config) 99 | self.graph_attention_layer=GraphAttentionLayer(config) 100 | self.sublayer=SubLayerConnection(config) 101 | self.init_weights() 102 | 103 | 104 | def forward( 105 | self, 106 | input_ids=None, 107 | bbox=None, 108 | attention_mask=None, 109 | token_type_ids=None, 110 | position_ids=None, 111 | head_mask=None, 112 | inputs_embeds=None, 113 | output_attentions=None, 114 | 
output_hidden_states=None, 115 | return_dict=None, 116 | images=None, 117 | valid_span=None, 118 | graph_mask=None, 119 | ): 120 | outputs = self.model_base( 121 | input_ids, 122 | bbox=bbox, 123 | attention_mask=attention_mask, 124 | token_type_ids=token_type_ids, 125 | position_ids=position_ids, 126 | head_mask=head_mask, 127 | inputs_embeds=inputs_embeds, 128 | output_attentions=output_attentions, 129 | output_hidden_states=output_hidden_states, 130 | return_dict=return_dict, 131 | images=images, 132 | valid_span=valid_span, 133 | ) 134 | sequence_output=self.sublayer(outputs[0],graph_mask,self.graph_attention_layer) 135 | 136 | if not return_dict: 137 | return (sequence_output, outputs[1]) 138 | 139 | return BaseModelOutputWithPoolingAndCrossAttentions( 140 | last_hidden_state=sequence_output, 141 | pooler_output=outputs.pooler_output, 142 | past_key_values=outputs.past_key_values, 143 | hidden_states=outputs.hidden_states, 144 | attentions=outputs.attentions, 145 | cross_attentions=outputs.cross_attentions, 146 | ) 147 | 148 | 149 | class GraphLayoutLMClassificationHead(nn.Module): 150 | """ 151 | Head for sentence-level classification tasks. 152 | Reference: RobertaClassificationHead 153 | """ 154 | 155 | def __init__(self, config, pool_feature=False): 156 | super().__init__() 157 | self.pool_feature = pool_feature 158 | if pool_feature: 159 | self.dense = nn.Linear(config.hidden_size*3, config.hidden_size) 160 | else: 161 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 162 | classifier_dropout = ( 163 | config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob 164 | ) 165 | self.dropout = nn.Dropout(classifier_dropout) 166 | self.out_proj = nn.Linear(config.hidden_size, config.num_labels) 167 | 168 | def forward(self, x): 169 | # x = features[:, 0, :] # take token (equiv. to [CLS]) 170 | x = self.dropout(x) 171 | x = self.dense(x) 172 | x = torch.tanh(x) 173 | x = self.dropout(x) 174 | x = self.out_proj(x) 175 | return x 176 | 177 | 178 | class GraphLayoutLMForTokenClassification(GraphLayoutLMPreTrainedModel): 179 | _keys_to_ignore_on_load_unexpected = [r"pooler"] 180 | _keys_to_ignore_on_load_missing = [r"position_ids"] 181 | 182 | def __init__(self, config): 183 | super().__init__(config) 184 | self.num_labels = config.num_labels 185 | self.graphlayoutlm = GraphLayoutLM(config) 186 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 187 | if config.num_labels < 10: 188 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 189 | else: 190 | self.classifier = GraphLayoutLMClassificationHead(config, pool_feature=False) 191 | 192 | self.init_weights() 193 | 194 | 195 | def forward( 196 | self, 197 | input_ids=None, 198 | bbox=None, 199 | attention_mask=None, 200 | token_type_ids=None, 201 | position_ids=None, 202 | valid_span=None, 203 | head_mask=None, 204 | inputs_embeds=None, 205 | labels=None, 206 | output_attentions=None, 207 | output_hidden_states=None, 208 | return_dict=None, 209 | images=None, 210 | graph_mask=None, 211 | ): 212 | r""" 213 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 214 | Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - 215 | 1]``. 
216 | """ 217 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 218 | 219 | outputs = self.graphlayoutlm( 220 | input_ids, 221 | bbox=bbox, 222 | attention_mask=attention_mask, 223 | token_type_ids=token_type_ids, 224 | position_ids=position_ids, 225 | head_mask=head_mask, 226 | inputs_embeds=inputs_embeds, 227 | output_attentions=output_attentions, 228 | output_hidden_states=output_hidden_states, 229 | return_dict=return_dict, 230 | images=images, 231 | valid_span=valid_span, 232 | graph_mask=graph_mask, 233 | ) 234 | 235 | sequence_output = outputs[0] 236 | 237 | sequence_output = self.dropout(sequence_output) 238 | logits = self.classifier(sequence_output) 239 | 240 | loss = None 241 | if labels is not None: 242 | loss_fct = CrossEntropyLoss() 243 | # Only keep active parts of the loss 244 | if attention_mask is not None: 245 | active_loss = attention_mask.view(-1) == 1 246 | active_logits = logits.view(-1, self.num_labels) 247 | active_labels = torch.where( 248 | active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) 249 | ) 250 | loss = loss_fct(active_logits, active_labels) 251 | else: 252 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 253 | 254 | if not return_dict: 255 | output = (logits,) + outputs[2:] 256 | return ((loss,) + output) if loss is not None else output 257 | 258 | return TokenClassifierOutput( 259 | loss=loss, 260 | logits=logits, 261 | hidden_states=outputs.hidden_states, 262 | attentions=outputs.attentions, 263 | ) -------------------------------------------------------------------------------- /model/tokenization_graphlayoutlm_fast.py: -------------------------------------------------------------------------------- 1 | from transformers.models.layoutlmv3.tokenization_layoutlmv3_fast import LayoutLMv3TokenizerFast 2 | from transformers.utils import logging 3 | 4 | 5 | logger = logging.get_logger(__name__) 6 | 7 | class GraphLayoutLMTokenizerFast(LayoutLMv3TokenizerFast): 8 | pass 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.18.0 2 | setuptools==58.0.4 3 | datasets==2.3.2 4 | transformers==4.28.1 5 | timm==0.4.12 6 | seqeval==1.2.2 7 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Line-Kite/GraphLayoutLM/8a39a54557b318f90070710c3a868c650cdf6af2/utils/__init__.py -------------------------------------------------------------------------------- /utils/graph_builder_uitls.py: -------------------------------------------------------------------------------- 1 | 2 | import json 3 | 4 | OPPOSITE={"left":"right", 5 | "right":"left", 6 | "up":"down", 7 | "down":"up", 8 | "parent":"child", 9 | "child":"parent"} 10 | 11 | class TreeNode: 12 | def __init__(self,id,box): 13 | self.id=id 14 | self.box=box 15 | self.parent=None 16 | self.children=[] 17 | 18 | def json_loader(path): 19 | with open(path, "r", encoding="utf8") as f: 20 | data = json.load(f) 21 | return data 22 | 23 | def json_saver(data,path): 24 | json_data = json.dumps(data, ensure_ascii=False) 25 | with open(path, "w", encoding='utf-8') as f: 26 | f.write(json_data) 27 | 28 | def posotion_judge(box_a,box_b): 29 | center=[box_a[0]+box_a[2],box_a[1]+box_a[3]] 30 | box_b2=[2*i for i in box_b] 31 | if 
box_b2[2]<center[0]: 32 | if box_b2[3]<center[1]: 33 | return "up-left" 34 | elif box_b2[1]<=center[1]<=box_b2[3]: 35 | return "left" 36 | elif box_b2[1]>center[1]: 37 | return "down-left" 38 | elif box_b2[0]<=center[0]<=box_b2[2]: 39 | if box_b2[3]<center[1]: 40 | return "up" 41 | elif box_b2[1]<=center[1]<=box_b2[3]: 42 | return "same" # placeholder label for the fully overlapping case 43 | elif box_b2[1]>center[1]: 44 | return "down" 45 | elif box_b2[0]>center[0]: 46 | if box_b2[3] < center[1]: 47 | return "up-right" 48 | elif box_b2[1] <= center[1] <= box_b2[3]: 49 | return "right" 50 | elif box_b2[1] > center[1]: 51 | return "down-right" 52 | 53 | def node_box_update(node_p,node_c): 54 | if node_p.box[0]>node_c.box[0]: 55 | node_p.box[0]=node_c.box[0] 56 | if node_p.box[1]>node_c.box[1]: 57 | node_p.box[1]=node_c.box[1] 58 | if node_p.box[2]<node_c.box[2]: 59 | node_p.box[2]=node_c.box[2] 60 | if node_p.box[3]<node_c.box[3]: 61 | node_p.box[3]=node_c.box[3] -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 144 | >>> transforms.Compose([ 145 | >>> transforms.CenterCrop(10), 146 | >>> transforms.PILToTensor(), 147 | >>> transforms.ConvertImageDtype(torch.float), 148 | >>> ]) 149 | 150 | .. note:: 151 | In order to script the transformations, please use ``torch.nn.Sequential`` as below. 152 | 153 | >>> transforms = torch.nn.Sequential( 154 | >>> transforms.CenterCrop(10), 155 | >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 156 | >>> ) 157 | >>> scripted_transforms = torch.jit.script(transforms) 158 | 159 | Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require 160 | `lambda` functions or ``PIL.Image``. 161 | 162 | """ 163 | 164 | def __init__(self, transforms): 165 | self.transforms = transforms 166 | 167 | def __call__(self, img, augmentation=False, box=None): 168 | for t in self.transforms: 169 | img = t(img, augmentation, box) 170 | return img 171 | 172 | 173 | class RandomResizedCropAndInterpolationWithTwoPic: 174 | """Crop the given PIL Image to random size and aspect ratio with random interpolation. 175 | A crop of random size (default: of 0.08 to 1.0) of the original size and a random 176 | aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop 177 | is finally resized to given size. 178 | This is popularly used to train the Inception networks. 179 | Args: 180 | size: expected output size of each edge 181 | scale: range of size of the origin size cropped 182 | ratio: range of aspect ratio of the origin aspect ratio cropped 183 | interpolation: Default: PIL.Image.BILINEAR 184 | """ 185 | 186 | def __init__(self, size, second_size=None, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), 187 | interpolation='bilinear', second_interpolation='lanczos'): 188 | if isinstance(size, tuple): 189 | self.size = size 190 | else: 191 | self.size = (size, size) 192 | if second_size is not None: 193 | if isinstance(second_size, tuple): 194 | self.second_size = second_size 195 | else: 196 | self.second_size = (second_size, second_size) 197 | else: 198 | self.second_size = None 199 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): 200 | warnings.warn("range should be of kind (min, max)") 201 | 202 | self.interpolation = _pil_interp(interpolation) 203 | self.second_interpolation = _pil_interp(second_interpolation) 204 | self.scale = scale 205 | self.ratio = ratio 206 | 207 | @staticmethod 208 | def get_params(img, scale, ratio): 209 | """Get parameters for ``crop`` for a random sized crop. 210 | Args: 211 | img (PIL Image): Image to be cropped. 212 | scale (tuple): range of size of the origin size cropped 213 | ratio (tuple): range of aspect ratio of the origin aspect ratio cropped 214 | Returns: 215 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random 216 | sized crop.
217 | """ 218 | area = img.size[0] * img.size[1] 219 | 220 | for attempt in range(10): 221 | target_area = random.uniform(*scale) * area 222 | log_ratio = (math.log(ratio[0]), math.log(ratio[1])) 223 | aspect_ratio = math.exp(random.uniform(*log_ratio)) 224 | 225 | w = int(round(math.sqrt(target_area * aspect_ratio))) 226 | h = int(round(math.sqrt(target_area / aspect_ratio))) 227 | 228 | if w <= img.size[0] and h <= img.size[1]: 229 | i = random.randint(0, img.size[1] - h) 230 | j = random.randint(0, img.size[0] - w) 231 | return i, j, h, w 232 | 233 | # Fallback to central crop 234 | in_ratio = img.size[0] / img.size[1] 235 | if in_ratio < min(ratio): 236 | w = img.size[0] 237 | h = int(round(w / min(ratio))) 238 | elif in_ratio > max(ratio): 239 | h = img.size[1] 240 | w = int(round(h * max(ratio))) 241 | else: # whole image 242 | w = img.size[0] 243 | h = img.size[1] 244 | i = (img.size[1] - h) // 2 245 | j = (img.size[0] - w) // 2 246 | return i, j, h, w 247 | 248 | def __call__(self, img, augmentation=False, box=None): 249 | """ 250 | Args: 251 | img (PIL Image): Image to be cropped and resized. 252 | Returns: 253 | PIL Image: Randomly cropped and resized image. 254 | """ 255 | if augmentation: 256 | i, j, h, w = self.get_params(img, self.scale, self.ratio) 257 | img = F.crop(img, i, j, h, w) 258 | # img, box = crop(img, i, j, h, w, box) 259 | img = F.resize(img, self.size, self.interpolation) 260 | second_img = F.resize(img, self.second_size, self.second_interpolation) \ 261 | if self.second_size is not None else None 262 | return img, second_img 263 | 264 | def __repr__(self): 265 | if isinstance(self.interpolation, (tuple, list)): 266 | interpolate_str = ' '.join([_pil_interpolation_to_str[x] for x in self.interpolation]) 267 | else: 268 | interpolate_str = _pil_interpolation_to_str[self.interpolation] 269 | format_string = self.__class__.__name__ + '(size={0}'.format(self.size) 270 | format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) 271 | format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) 272 | format_string += ', interpolation={0}'.format(interpolate_str) 273 | if self.second_size is not None: 274 | format_string += ', second_size={0}'.format(self.second_size) 275 | format_string += ', second_interpolation={0}'.format(_pil_interpolation_to_str[self.second_interpolation]) 276 | format_string += ')' 277 | return format_string 278 | 279 | 280 | def pil_loader(path: str) -> Image.Image: 281 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 282 | with open(path, 'rb') as f: 283 | img = Image.open(f) 284 | return img.convert('RGB') 285 | --------------------------------------------------------------------------------