├── requirements.txt ├── models ├── config.json └── sample_data.json ├── LICENSE ├── README.md ├── demo.ipynb ├── create_dataset.py ├── tokenization.py ├── modeling_layoutlm.py ├── utils_docvqa.py └── run_docvqa.py /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | editdistance 3 | pillow 4 | opencv-python 5 | torch 6 | torchvision 7 | tensorboardX 8 | transformers 9 | tensorflow-gpu 10 | -------------------------------------------------------------------------------- /models/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LayoutLMForTokenClassification" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_2d_position_embeddings": 1024, 13 | "max_position_embeddings": 512, 14 | "model_type": "bert", 15 | "num_attention_heads": 12, 16 | "num_hidden_layers": 12, 17 | "output_past": true, 18 | "pad_token_id": 0, 19 | "type_vocab_size": 2, 20 | "vocab_size": 30522 21 | } 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Anisha Gunjal 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Document Visual Question Answering (DocVQA) 2 | This repo hosts the basic functional code for our approach entitled [HyperDQA](https://rrc.cvc.uab.es/?ch=17&com=evaluation&view=method_info&task=1&m=75548) in the [Document Visual Question Answering](https://rrc.cvc.uab.es/?ch=17) competition hosted as a part of [Workshop on Text and Documents in Deep Learning Era](https://cvpr2020text.wordpress.com) at [CVPR2020](http://cvpr2020.thecvf.com). Our approach stands at position 4 on the [Leaderboard](https://rrc.cvc.uab.es/?ch=17&com=evaluation&task=1). 3 | 4 | Read more about our approach in this [blogpost](https://medium.com/@anishagunjal7/document-visual-question-answering-e6090f3bddee)! 
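## Approach at a glance
DocVQA is handled here as extractive question answering over the OCR tokens: LayoutLM encodes each word together with its bounding box (normalized to a 0-1000 grid), and a two-label token-classification head scores every token as a candidate answer start or end (see `modeling_layoutlm.py` and `demo.ipynb`). The snippet below is a minimal, illustrative forward pass with dummy tensors and randomly initialised weights, assuming the older `transformers` API this repo targets (`modeling_layoutlm.py` imports from `transformers.modeling_bert`); it is a sketch for orientation, not part of the training or demo code.

```python
import torch
from transformers import BertConfig
from modeling_layoutlm import LayoutLMForTokenClassification

# Build the model from the config shipped in ./models (weights are random here;
# load the fine-tuned pytorch_model.bin as in demo.ipynb for real predictions).
config = BertConfig.from_pretrained("./models/config.json", num_labels=2)
model = LayoutLMForTokenClassification(config)
model.eval()

seq_len = 512  # illustrative sequence length: [CLS] question [SEP] document words [SEP]
input_ids = torch.zeros(1, seq_len, dtype=torch.long)    # wordpiece ids
bbox = torch.zeros(1, seq_len, 4, dtype=torch.long)      # one [x1, y1, x2, y2] box per token, 0-1000 scale
attention_mask = torch.ones(1, seq_len, dtype=torch.long)
token_type_ids = torch.zeros(1, seq_len, dtype=torch.long)

with torch.no_grad():
    start_logits, end_logits = model(
        input_ids,
        bbox,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
    )[:2]

# The predicted answer span is argmax(start_logits) .. argmax(end_logits), as in demo.ipynb.
print(start_logits.shape, end_logits.shape)  # torch.Size([1, 512]) each
```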
5 | 6 | ## Installation 7 | ### Python 3 Virtual Environment (Recommended) 8 | 1) Clone the repository 9 | ``` 10 | git clone https://github.com/anisha2102/docvqa.git 11 | ``` 12 | 13 | 2) Install libraries 14 | ``` 15 | pip install -r requirements.txt 16 | ``` 17 | 18 | ## Downloads 19 | 1) Download the dataset 20 | The Task 1 dataset can be downloaded from the Downloads section of the competition [website](https://rrc.cvc.uab.es/?ch=17). 21 | It consists of document images and their corresponding OCR transcriptions (the format of the prepared training JSON is sketched at the end of this README). 22 | 23 | 2) Download the pretrained model 24 | Download the pretrained LayoutLM-Base, Uncased model from [here](https://github.com/microsoft/unilm/tree/master/layoutlm). 25 | ## Prepare dataset 26 | ``` 27 | python create_dataset.py \ 28 | <ocr_folder> \ 29 | <documents_folder> \ 30 | <train_v1.0.json> \ 31 | <out_train_json> \ 32 | <out_val_json> 33 | ``` 34 | ## Train the model 35 | ``` 36 | CUDA_VISIBLE_DEVICES=0 python run_docvqa.py \ 37 | --data_dir <data_dir> \ 38 | --model_type layoutlm \ 39 | --model_name_or_path <model_dir> \ #example ./models/layoutlm-base-uncased 40 | --do_lower_case \ 41 | --max_seq_length 512 \ 42 | --do_train \ 43 | --num_train_epochs 15 \ 44 | --logging_steps 500 \ 45 | --evaluate_during_training \ 46 | --save_steps 500 \ 47 | --do_eval \ 48 | --output_dir <output_dir> \ 49 | --per_gpu_train_batch_size 8 \ 50 | --overwrite_output_dir \ 51 | --cache_dir <path_to_repo>/models \ 52 | --skip_match_answers \ 53 | --val_json <out_val_json> \ 54 | --train_json <out_train_json> 55 | ``` 56 | ## Model Checkpoints 57 | Download the pytorch_model.bin file from the link below and copy it into the models folder. 58 | [Google Drive Link](https://drive.google.com/file/d/1W4E06nb-tDcjKVN9iCjjk0b_3EyHkqVr/view?usp=sharing) 59 | 60 | ## Demo 61 | Try out the demo on a sample datapoint with demo.ipynb. 62 | 63 | ## Acknowledgements 64 | The code and pretrained models are based on [LayoutLM](https://github.com/microsoft/unilm/tree/master/layoutlm) and [HuggingFace Transformers](https://github.com/huggingface/transformers). Many thanks for their amazing open source contributions.
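## Appendix: prepared dataset format (sketch)
Each entry in the train/val JSON written by `create_dataset.py` pairs the lower-cased OCR words of one document (in the order emitted by the OCR) with their 0-1000-normalized boxes, and stores every answer as start/end indices into that word list. The abridged sketch below mirrors `models/sample_data.json`, the full example used by the demo; the ellipses stand for the remaining words and boxes and are not literal file content.

```python
# Abridged sketch of one entry in the list that create_dataset.py writes out;
# see models/sample_data.json for the complete, real example this is taken from.
example = {
    "image_id": "xnbl0037_1",
    "context": ["confidential", "..", "..", "rjrt", ...],    # lower-cased OCR words
    "boxes": [[344, 14, 586, 64], [637, 15, 648, 27], ...],  # one [x1, y1, x2, y2] box per word, 0-1000 scale
    "qas": [
        {
            "qid": "xnbl0037_1-2",
            "question": "what is the contact person name mentioned in letter?",
            "answer": [
                # answer_start/answer_end index into "context": words 20-21 are "p." and "carter"
                {"text": "p. carter", "answer_start": 20, "answer_end": 21},
            ],
        },
    ],
}
```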
65 | -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "from modeling_layoutlm import LayoutLMForTokenClassification\n", 11 | "from transformers import (\n", 12 | " BertConfig,\n", 13 | " BertTokenizer,\n", 14 | ")\n", 15 | "from utils_docvqa import (\n", 16 | " read_docvqa_examples,\n", 17 | " convert_examples_to_features)\n", 18 | "from torch.utils.data import DataLoader, SequentialSampler, TensorDataset\n", 19 | "from transformers.data.processors.squad import SquadResult\n", 20 | "from tqdm import tqdm\n", 21 | "import numpy as np" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "MODEL_FOLDER = \"./models/\"\n", 31 | "SAMPLE_DATA = \"./models/sample_data.json\"\n", 32 | "LABELS = [\"start\",\"end\"]\n", 33 | "pad_token_label_id=-100\n", 34 | "labels = [\"start\",\"end\"]\n", 35 | "max_seq_length = 512\n", 36 | "max_query_length = 64\n", 37 | "doc_stride = 128" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 3, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "device = torch.device(\"cuda:0\")\n", 47 | "torch.cuda.set_device(device)\n", 48 | "model_class = LayoutLMForTokenClassification\n", 49 | "config_class = BertConfig\n", 50 | "tokenizer_class = BertTokenizer\n", 51 | "config = config_class.from_pretrained(MODEL_FOLDER,num_labels=2,cache_dir=None)\n", 52 | "model = model_class.from_pretrained(MODEL_FOLDER)\n", 53 | "tokenizer = tokenizer_class.from_pretrained(MODEL_FOLDER,do_lower_case=True)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 4, 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "name": "stdout", 63 | "output_type": "stream", 64 | "text": [ 65 | "WARNING:tensorflow:From /mnt/anisha/code/docvqa/utils_docvqa.py:95: The name tf.gfile.Open is deprecated. 
Please use tf.io.gfile.GFile instead.\n", 66 | "\n" 67 | ] 68 | } 69 | ], 70 | "source": [ 71 | "examples = read_docvqa_examples(SAMPLE_DATA, is_training=False)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "features = convert_examples_to_features(\n", 81 | " examples=examples,\n", 82 | " label_list=labels,\n", 83 | " tokenizer=tokenizer,\n", 84 | " max_seq_length=max_seq_length,\n", 85 | " doc_stride=doc_stride,\n", 86 | " max_query_length=max_query_length,\n", 87 | " is_training=False,\n", 88 | " pad_token_label_id=pad_token_label_id)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 6, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)\n", 98 | "all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)\n", 99 | "all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)\n", 100 | "all_bboxes = torch.tensor([f.boxes for f in features], dtype=torch.long)\n", 101 | "all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)\n", 102 | "\n", 103 | "eval_dataset = TensorDataset(\n", 104 | " all_input_ids, all_input_mask, all_segment_ids,all_bboxes,all_example_index)\n", 105 | "eval_batch_size = 1\n", 106 | "eval_sampler = (\n", 107 | " SequentialSampler(eval_dataset))\n", 108 | "\n", 109 | "eval_dataloader = DataLoader(\n", 110 | " eval_dataset, sampler=eval_sampler, batch_size=eval_batch_size\n", 111 | " )" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 7, 117 | "metadata": { 118 | "scrolled": false 119 | }, 120 | "outputs": [ 121 | { 122 | "name": "stderr", 123 | "output_type": "stream", 124 | "text": [ 125 | "Evaluating: 100%|██████████| 1/1 [00:00<00:00, 32.54it/s]" 126 | ] 127 | }, 128 | { 129 | "name": "stdout", 130 | "output_type": "stream", 131 | "text": [ 132 | "what is the contact person name mentioned in letter ?\n", 133 | "maura payne\n" 134 | ] 135 | }, 136 | { 137 | "name": "stderr", 138 | "output_type": "stream", 139 | "text": [ 140 | "\n" 141 | ] 142 | } 143 | ], 144 | "source": [ 145 | "model.to(device)\n", 146 | "all_results = []\n", 147 | "\n", 148 | "def to_list(tensor):\n", 149 | " return tensor.detach().cpu().tolist()\n", 150 | "\n", 151 | "for batch in tqdm(eval_dataloader, desc=\"Evaluating\"):\n", 152 | " model.eval()\n", 153 | " batch = tuple(t.to(device) for t in batch)\n", 154 | " with torch.no_grad():\n", 155 | " inputs = {\n", 156 | " \"input_ids\": batch[0],\n", 157 | " \"attention_mask\": batch[1],\n", 158 | " }\n", 159 | " inputs[\"bbox\"] = batch[3]\n", 160 | " inputs[\"token_type_ids\"] = (batch[2])\n", 161 | " outputs = model(**inputs)\n", 162 | " example_indices = batch[4]\n", 163 | "\n", 164 | " for i, example_index in enumerate(example_indices):\n", 165 | " eval_feature = features[example_index.item()]\n", 166 | " unique_id = int(eval_feature.unique_id)\n", 167 | "\n", 168 | " output = [to_list(output[i]) for output in outputs]\n", 169 | "\n", 170 | " start_logits, end_logits = output\n", 171 | " result = SquadResult(unique_id, start_logits, end_logits)\n", 172 | " all_results.append(result)\n", 173 | "predictions_json = {}\n", 174 | "assert len(all_results)==len(features)\n", 175 | "for i in range(len(all_results)):\n", 176 | " start_index = np.argmax(all_results[i].start_logits)\n", 177 | " end_index = np.argmax(all_results[i].end_logits)\n", 178 | " 
pred_answer = features[i].tokens[start_index:end_index+1]\n", 179 | " pred_answer = ' '.join([x for x in pred_answer])\n", 180 | " pred_text = pred_answer.replace(' ##', '')\n", 181 | " question = features[i].tokens[1:features[i].tokens.index('[SEP]')]\n", 182 | " question_text = ' '.join([x for x in question])\n", 183 | " question_text = question_text.replace(' ##', '')\n", 184 | " print(question_text)\n", 185 | " print(pred_text)\n", 186 | " \n", 187 | " " 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [] 196 | } 197 | ], 198 | "metadata": { 199 | "kernelspec": { 200 | "display_name": "Python 3", 201 | "language": "python", 202 | "name": "python3" 203 | }, 204 | "language_info": { 205 | "codemirror_mode": { 206 | "name": "ipython", 207 | "version": 3 208 | }, 209 | "file_extension": ".py", 210 | "mimetype": "text/x-python", 211 | "name": "python", 212 | "nbconvert_exporter": "python", 213 | "pygments_lexer": "ipython3", 214 | "version": "3.6.9" 215 | } 216 | }, 217 | "nbformat": 4, 218 | "nbformat_minor": 4 219 | } 220 | -------------------------------------------------------------------------------- /create_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import glob 3 | import cv2 4 | import PIL.Image 5 | from tqdm import tqdm 6 | import editdistance 7 | import argparse 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('ocr_folder', help='Path to folder containing OCR annotations') 11 | parser.add_argument('documents_folder', help='Path to folder containing document images') 12 | parser.add_argument('train_v1_json', help='Path to train_v1.0.json') 13 | parser.add_argument('out_train_json') 14 | parser.add_argument('out_val_json') 15 | 16 | args = parser.parse_args() 17 | 18 | def bbox_string(box, width, length): 19 | return [ 20 | int(1000 * (box[0] / width)), 21 | int(1000 * (box[1] / length)), 22 | int(1000 * (box[2] / width)), 23 | int(1000 * (box[3] / length)) 24 | ] 25 | 26 | def clean_text(text): 27 | replace_chars = ',.;:()-/$%&*' 28 | for j in replace_chars: 29 | if text is not None: 30 | text = text.replace(j,'') 31 | return text 32 | 33 | def harsh_find(answer_tokens, words): 34 | answer_raw = ''.join(answer_tokens) 35 | answer = ' '.join(answer_tokens) 36 | if len(answer_tokens)==1: 37 | for (ind,w) in enumerate(words): 38 | dist=0 if len(answer)<5 else 1 39 | if editdistance.eval(answer,w)<=dist: 40 | start_index=end_index=ind 41 | return start_index,end_index,w 42 | for (ind,w) in enumerate(words): 43 | if answer_raw.startswith(w): #Looks like words are split 44 | for inc in range(1,30): 45 | if ind+inc>=len(words): 46 | break 47 | w=w+words[ind+inc] 48 | if len(answer_raw)>=5: 49 | dist=1 50 | else: 51 | dist=0 52 | start_index=ind 53 | end_index=ind+inc 54 | ext_list = words[start_index:end_index+1] 55 | extracted_answer = ' '.join(ext_list) 56 | 57 | if editdistance.eval(answer.replace(' ',''),extracted_answer.replace(' ',''))<=dist: 58 | return start_index,end_index,extracted_answer 59 | return reverse_harsh_find(answer_tokens, words) 60 | 61 | 62 | def reverse_harsh_find(answer_tokens, words): 63 | answer_raw = ''.join(answer_tokens) 64 | answer = ''.join(answer_tokens) 65 | for (ind,w) in enumerate(words): 66 | if answer_raw.endswith(w): #Looks like words are split 67 | for inc in range(1,30): 68 | if ind-inc<0: 69 | break 70 | w=words[ind-inc]+w 71 | if len(answer_raw)>=15: 72 | dist=3 73 | 
elif len(answer_raw)>=5: 74 | dist=1 75 | else: 76 | dist=0 77 | start_index=ind-inc 78 | end_index=ind 79 | ext_list = words[start_index:end_index+1] 80 | extracted_answer = ' '.join(ext_list) 81 | 82 | if editdistance.eval(answer.replace(' ',''),extracted_answer.replace(' ',''))<=dist: 83 | return start_index,end_index,extracted_answer 84 | return None,None,None 85 | 86 | def get_answer_indices(ques_id,words, answer): 87 | count = 0 88 | answer_tokens = answer.split() 89 | end_index = None 90 | start_index = None 91 | words = [clean_text(x) for x in words] 92 | answer_tokens = [clean_text(x) for x in answer_tokens] 93 | answer = ' '.join(answer_tokens) 94 | 95 | if answer_tokens[0] in words: 96 | start_index = words.index(answer_tokens[0]) 97 | if answer_tokens[-1] in words: 98 | end_index = words.index(answer_tokens[-1]) 99 | if start_index is not None and end_index is not None: 100 | if start_index > end_index: 101 | if answer_tokens[-1] in words[start_index:]: 102 | end_index = words[start_index:].index(answer_tokens[-1]) 103 | end_index+=start_index 104 | else: 105 | #Last try 106 | start_index,end_index,extracted_answer = harsh_find(answer_tokens,words) 107 | return start_index,end_index,extracted_answer 108 | 109 | 110 | assert start_index<=end_index 111 | extracted_answer = ' '.join(words[start_index:end_index+1]) 112 | if answer.replace(' ','')!=extracted_answer.replace(' ',''): 113 | start_index,end_index,extracted_answer = harsh_find(answer_tokens,words) 114 | return start_index,end_index,extracted_answer 115 | else: 116 | return start_index, end_index, extracted_answer 117 | 118 | return None,None,None 119 | else: 120 | answer_raw = ''.join(answer_tokens) 121 | start_index,end_index,extracted_answer = harsh_find(answer_tokens,words) 122 | return start_index,end_index,extracted_answer 123 | 124 | def find_candidate_lines(ocr_json,ans_json): 125 | pass 126 | 127 | 128 | data = [] 129 | ocr_files = glob.glob(args.ocr_folder+"/*") 130 | ocr_files = [x.split('.')[0] for x in ocr_files] 131 | dict_img_qa = json.load(open(args.train_v1_json)) 132 | found = 0 133 | nf = [] 134 | not_found = 0 135 | img_id_covered = [] 136 | 137 | for datapt in tqdm(dict_img_qa["data"]): 138 | img_id = datapt["image"].split('/')[-1].split('.')[0] 139 | if img_id in img_id_covered: 140 | continue 141 | else: 142 | img_id_covered.append(img_id) 143 | img_qs = [] 144 | questionId = [] 145 | img_as = [] 146 | 147 | for d in dict_img_qa["data"]: 148 | id_im = d["image"].split('/')[-1].split('.')[0] 149 | if id_im==img_id: 150 | img_qs.append(d["question"]) 151 | questionId.append(d["questionId"]) 152 | img_as.append(d["answers"][0]) 153 | 154 | example = {} 155 | example["image_id"] = img_id 156 | example["qas"] = [] 157 | words = [] 158 | boxes = [] 159 | boxes_norm = [] 160 | line_indices = [] 161 | lines_array = [] 162 | 163 | ocr_file = glob.glob(args.ocr_folder+"/"+img_id+'.json') 164 | img_file = glob.glob(args.documents_folder+"/"+img_id+'.png') 165 | img = cv2.imread(img_file[0]) 166 | length, width = img.shape[:2] 167 | ocr_json = json.load(open(ocr_file[0])) 168 | 169 | assert len(ocr_file)==1 170 | assert len(img_file)==1 171 | 172 | #Added boxes and context to the example 173 | for obj in ocr_json['recognitionResults']: 174 | lines = obj['lines'] 175 | idx = 0 176 | for line in lines: 177 | lines_array.append(line['text']) 178 | for word in line['words']: 179 | words.append(word['text'].lower()) 180 | line_indices.append(idx) 181 | x1,y1,x2,y2,x3,y3,x4,y4 = word['boundingBox'] 182 | new_x1 = 
min([x1,x2,x3,x4]) 183 | new_x2 = max([x1,x2,x3,x4]) 184 | new_y1 = min([y1,y2,y3,y4]) 185 | new_y2 = max([y1,y2,y3,y4]) 186 | boxes.append([new_x1,new_y1,new_x2,new_y2]) 187 | box_norm = bbox_string([new_x1,new_y1,new_x2,new_y2], width, length) 188 | assert new_x2>=new_x1 189 | assert new_y2>=new_y1 190 | assert box_norm[2]>=box_norm[0] 191 | assert box_norm[3]>=box_norm[1] 192 | 193 | boxes_norm.append(box_norm) 194 | idx+=1 195 | example["context"] = words 196 | example["boxes"] = boxes_norm 197 | 198 | assert len(example["context"]) == len(example["boxes"]) 199 | assert len(example["context"]) == len(line_indices) 200 | 201 | 202 | ques_counter = 1 203 | for qid in range(len(img_qs)): 204 | ques = img_qs[qid] 205 | ans = img_as[qid] 206 | ques_json = {} 207 | ques_json['qid'] = img_id+'-'+str(ques_counter) 208 | ques_counter+=1 209 | ques_json["question"] = ques.lower() 210 | ques_json["answer"] = [] 211 | ans_json = {} 212 | ans_json["text"] = ans.lower() 213 | ques_json["answer"].append(ans_json) 214 | for ans_index in range(len(ques_json["answer"])): 215 | start_index, end_index, extracted_answer = get_answer_indices(ques_json['qid'],example["context"],ques_json["answer"][ans_index]["text"]) 216 | replace_chars =',.;:()-/$%&*' 217 | ans=ans.lower() 218 | extracted_answer = clean_text(extracted_answer) 219 | ans = clean_text(ans) 220 | dist = editdistance.eval(extracted_answer.replace(' ',''),ans.replace(' ','')) if extracted_answer!=None else 1000 221 | if dist>5: 222 | start_index=None 223 | if start_index is not None: 224 | break 225 | if start_index is None or len(extracted_answer)>150 or extracted_answer=="": 226 | nf.append(img_id) 227 | not_found+=1 228 | start_index=None 229 | end_index=None 230 | continue 231 | else: 232 | found+=1 233 | ans_json["answer_start"] = start_index 234 | ans_json["answer_end"] = end_index 235 | example["qas"].append(ques_json) 236 | data.append(example) 237 | 238 | val_count=1 239 | new_val = [] 240 | new_train = [] 241 | 242 | 243 | for i in tqdm(data): 244 | img_id = i['image_id'] 245 | if val_count<=1000: 246 | new_val.append(i) 247 | val_count+=1 248 | else: 249 | new_train.append(i) 250 | 251 | 252 | print("LEN VAL",len(new_val)) 253 | print("LEN TRAIN",len(new_train)) 254 | 255 | with open(args.out_train_json, "w") as fp: 256 | json.dump(new_train,fp) 257 | with open(args.out_val_json, "w") as fp: 258 | json.dump(new_val,fp) 259 | 260 | print("Answers found",found) 261 | print("Answers not found",not_found) 262 | -------------------------------------------------------------------------------- /models/sample_data.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "boxes": [ 4 | [ 5 | 344, 6 | 14, 7 | 586, 8 | 64 9 | ], 10 | [ 11 | 637, 12 | 15, 13 | 648, 14 | 27 15 | ], 16 | [ 17 | 650, 18 | 15, 19 | 663, 20 | 26 21 | ], 22 | [ 23 | 412, 24 | 100, 25 | 459, 26 | 115 27 | ], 28 | [ 29 | 467, 30 | 100, 31 | 490, 32 | 115 33 | ], 34 | [ 35 | 500, 36 | 100, 37 | 587, 38 | 117 39 | ], 40 | [ 41 | 152, 42 | 145, 43 | 195, 44 | 159 45 | ], 46 | [ 47 | 197, 48 | 145, 49 | 202, 50 | 159 51 | ], 52 | [ 53 | 235, 54 | 127, 55 | 328, 56 | 163 57 | ], 58 | [ 59 | 251, 60 | 163, 61 | 308, 62 | 203 63 | ], 64 | [ 65 | 322, 66 | 163, 67 | 379, 68 | 203 69 | ], 70 | [ 71 | 152, 72 | 204, 73 | 238, 74 | 219 75 | ], 76 | [ 77 | 248, 78 | 203, 79 | 326, 80 | 220 81 | ], 82 | [ 83 | 336, 84 | 202, 85 | 388, 86 | 221 87 | ], 88 | [ 89 | 405, 90 | 196, 91 | 451, 92 | 229 93 | ], 94 | [ 95 | 458, 96 | 194, 97 | 
573, 98 | 225 99 | ], 100 | [ 101 | 153, 102 | 233, 103 | 185, 104 | 249 105 | ], 106 | [ 107 | 194, 108 | 233, 109 | 272, 110 | 249 111 | ], 112 | [ 113 | 283, 114 | 233, 115 | 313, 116 | 249 117 | ], 118 | [ 119 | 152, 120 | 261, 121 | 243, 122 | 279 123 | ], 124 | [ 125 | 262, 126 | 256, 127 | 290, 128 | 285 129 | ], 130 | [ 131 | 295, 132 | 253, 133 | 389, 134 | 284 135 | ], 136 | [ 137 | 152, 138 | 323, 139 | 205, 140 | 337 141 | ], 142 | [ 143 | 217, 144 | 322, 145 | 237, 146 | 337 147 | ], 148 | [ 149 | 617, 150 | 352, 151 | 704, 152 | 369 153 | ], 154 | [ 155 | 800, 156 | 351, 157 | 845, 158 | 370 159 | ], 160 | [ 161 | 151, 162 | 382, 163 | 205, 164 | 398 165 | ], 166 | [ 167 | 215, 168 | 383, 169 | 282, 170 | 398 171 | ], 172 | [ 173 | 621, 174 | 371, 175 | 666, 176 | 398 177 | ], 178 | [ 179 | 149, 180 | 412, 181 | 205, 182 | 426 183 | ], 184 | [ 185 | 216, 186 | 413, 187 | 269, 188 | 427 189 | ], 190 | [ 191 | 149, 192 | 440, 193 | 207, 194 | 456 195 | ], 196 | [ 197 | 215, 198 | 440, 199 | 279, 200 | 456 201 | ], 202 | [ 203 | 143, 204 | 464, 205 | 188, 206 | 488 207 | ], 208 | [ 209 | 209, 210 | 461, 211 | 332, 212 | 487 213 | ], 214 | [ 215 | 145, 216 | 490, 217 | 218, 218 | 518 219 | ], 220 | [ 221 | 230, 222 | 486, 223 | 322, 224 | 514 225 | ], 226 | [ 227 | 146, 228 | 518, 229 | 184, 230 | 546 231 | ], 232 | [ 233 | 198, 234 | 512, 235 | 304, 236 | 542 237 | ], 238 | [ 239 | 151, 240 | 546, 241 | 197, 242 | 575 243 | ], 244 | [ 245 | 218, 246 | 537, 247 | 297, 248 | 570 249 | ], 250 | [ 251 | 611, 252 | 547, 253 | 664, 254 | 584 255 | ], 256 | [ 257 | 234, 258 | 619, 259 | 302, 260 | 636 261 | ], 262 | [ 263 | 311, 264 | 619, 265 | 335, 266 | 636 267 | ], 268 | [ 269 | 343, 270 | 619, 271 | 400, 272 | 636 273 | ], 274 | [ 275 | 408, 276 | 619, 277 | 489, 278 | 636 279 | ], 280 | [ 281 | 495, 282 | 619, 283 | 532, 284 | 636 285 | ], 286 | [ 287 | 541, 288 | 619, 289 | 564, 290 | 637 291 | ], 292 | [ 293 | 571, 294 | 619, 295 | 661, 296 | 637 297 | ], 298 | [ 299 | 669, 300 | 618, 301 | 755, 302 | 637 303 | ], 304 | [ 305 | 874, 306 | 758, 307 | 900, 308 | 804 309 | ], 310 | [ 311 | 867, 312 | 810, 313 | 891, 314 | 848 315 | ], 316 | [ 317 | 882, 318 | 946, 319 | 887, 320 | 955 321 | ], 322 | [ 323 | 887, 324 | 947, 325 | 896, 326 | 957 327 | ], 328 | [ 329 | 223, 330 | 980, 331 | 290, 332 | 996 333 | ], 334 | [ 335 | 293, 336 | 979, 337 | 775, 338 | 996 339 | ] 340 | ], 341 | "context": [ 342 | "confidential", 343 | "..", 344 | "..", 345 | "rjrt", 346 | "pr", 347 | "approval", 348 | "date", 349 | ":", 350 | "1/8/13", 351 | "r", 352 | "alas", 353 | "proposed", 354 | "release", 355 | "date:", 356 | "for", 357 | "response", 358 | "for", 359 | "release", 360 | "to:", 361 | "contact:", 362 | "p.", 363 | "carter", 364 | "route", 365 | "to", 366 | "initials", 367 | "pate", 368 | "peggy", 369 | "carter", 370 | "ac", 371 | "maura", 372 | "payne", 373 | "david", 374 | "fishel", 375 | "tom", 376 | "griscom", 377 | "diane", 378 | "barrows", 379 | "ed", 380 | "blackmer", 381 | "tow", 382 | "rucker", 383 | "tr", 384 | "return", 385 | "to", 386 | "peggy", 387 | "carter,", 388 | "pr,", 389 | "16", 390 | "reynolds", 391 | "building", 392 | "51142", 393 | "3977", 394 | ".", 395 | ".", 396 | "source:", 397 | "https://www.industrydocuments.ucsf.edu/docs/xnb10037" 398 | ], 399 | "image_id": "xnbl0037_1", 400 | "qas": [ 401 | { 402 | "answer": [ 403 | { 404 | "answer_end": 21, 405 | "answer_start": 20, 406 | "text": "p. 
carter" 407 | } 408 | ], 409 | "qid": "xnbl0037_1-2", 410 | "question": "what is the contact person name mentioned in letter?" 411 | } 412 | ] 413 | } 414 | ] 415 | -------------------------------------------------------------------------------- /tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import re 23 | import unicodedata 24 | import six 25 | import tensorflow as tf 26 | 27 | 28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 29 | """Checks whether the casing config is consistent with the checkpoint name.""" 30 | 31 | # The casing has to be passed in by the user and there is no explicit check 32 | # as to whether it matches the checkpoint. The casing information probably 33 | # should have been stored in the bert_config.json file, but it's not, so 34 | # we have to heuristically detect it to validate. 35 | 36 | if not init_checkpoint: 37 | return 38 | 39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 40 | if m is None: 41 | return 42 | 43 | model_name = m.group(1) 44 | 45 | lower_models = [ 46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 48 | ] 49 | 50 | cased_models = [ 51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 52 | "multi_cased_L-12_H-768_A-12" 53 | ] 54 | 55 | is_bad_config = False 56 | if model_name in lower_models and not do_lower_case: 57 | is_bad_config = True 58 | actual_flag = "False" 59 | case_name = "lowercased" 60 | opposite_flag = "True" 61 | 62 | if model_name in cased_models and do_lower_case: 63 | is_bad_config = True 64 | actual_flag = "True" 65 | case_name = "cased" 66 | opposite_flag = "False" 67 | 68 | if is_bad_config: 69 | raise ValueError( 70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 71 | "However, `%s` seems to be a %s model, so you " 72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 73 | "how the model was pre-training. If this error is wrong, please " 74 | "just comment out this check." 
% (actual_flag, init_checkpoint, 75 | model_name, case_name, opposite_flag)) 76 | 77 | 78 | def convert_to_unicode(text): 79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 80 | if six.PY3: 81 | if isinstance(text, str): 82 | return text 83 | elif isinstance(text, bytes): 84 | return text.decode("utf-8", "ignore") 85 | else: 86 | raise ValueError("Unsupported string type: %s" % (type(text))) 87 | elif six.PY2: 88 | if isinstance(text, str): 89 | return text.decode("utf-8", "ignore") 90 | elif isinstance(text, unicode): 91 | return text 92 | else: 93 | raise ValueError("Unsupported string type: %s" % (type(text))) 94 | else: 95 | raise ValueError("Not running on Python2 or Python 3?") 96 | 97 | 98 | def printable_text(text): 99 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 100 | 101 | # These functions want `str` for both Python2 and Python3, but in one case 102 | # it's a Unicode string and in the other it's a byte string. 103 | if six.PY3: 104 | if isinstance(text, str): 105 | return text 106 | elif isinstance(text, bytes): 107 | return text.decode("utf-8", "ignore") 108 | else: 109 | raise ValueError("Unsupported string type: %s" % (type(text))) 110 | elif six.PY2: 111 | if isinstance(text, str): 112 | return text 113 | elif isinstance(text, unicode): 114 | return text.encode("utf-8") 115 | else: 116 | raise ValueError("Unsupported string type: %s" % (type(text))) 117 | else: 118 | raise ValueError("Not running on Python2 or Python 3?") 119 | 120 | 121 | def load_vocab(vocab_file): 122 | """Loads a vocabulary file into a dictionary.""" 123 | vocab = collections.OrderedDict() 124 | index = 0 125 | with tf.gfile.GFile(vocab_file, "r") as reader: 126 | while True: 127 | token = convert_to_unicode(reader.readline()) 128 | if not token: 129 | break 130 | token = token.strip() 131 | vocab[token] = index 132 | index += 1 133 | return vocab 134 | 135 | 136 | def convert_by_vocab(vocab, items): 137 | """Converts a sequence of [tokens|ids] using the vocab.""" 138 | output = [] 139 | for item in items: 140 | output.append(vocab[item]) 141 | return output 142 | 143 | 144 | def convert_tokens_to_ids(vocab, tokens): 145 | return convert_by_vocab(vocab, tokens) 146 | 147 | 148 | def convert_ids_to_tokens(inv_vocab, ids): 149 | return convert_by_vocab(inv_vocab, ids) 150 | 151 | 152 | def whitespace_tokenize(text): 153 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 154 | text = text.strip() 155 | if not text: 156 | return [] 157 | tokens = text.split() 158 | return tokens 159 | 160 | 161 | class FullTokenizer(object): 162 | """Runs end-to-end tokenziation.""" 163 | 164 | def __init__(self, vocab_file, do_lower_case=True): 165 | self.vocab = load_vocab(vocab_file) 166 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 167 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 168 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 169 | 170 | def tokenize(self, text): 171 | split_tokens = [] 172 | for token in self.basic_tokenizer.tokenize(text): 173 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 174 | split_tokens.append(sub_token) 175 | 176 | return split_tokens 177 | 178 | def convert_tokens_to_ids(self, tokens): 179 | return convert_by_vocab(self.vocab, tokens) 180 | 181 | def convert_ids_to_tokens(self, ids): 182 | return convert_by_vocab(self.inv_vocab, ids) 183 | 184 | 185 | class BasicTokenizer(object): 186 | """Runs basic tokenization (punctuation splitting, lower 
casing, etc.).""" 187 | 188 | def __init__(self, do_lower_case=True): 189 | """Constructs a BasicTokenizer. 190 | 191 | Args: 192 | do_lower_case: Whether to lower case the input. 193 | """ 194 | self.do_lower_case = do_lower_case 195 | 196 | def tokenize(self, text): 197 | """Tokenizes a piece of text.""" 198 | text = convert_to_unicode(text) 199 | text = self._clean_text(text) 200 | 201 | # This was added on November 1st, 2018 for the multilingual and Chinese 202 | # models. This is also applied to the English models now, but it doesn't 203 | # matter since the English models were not trained on any Chinese data 204 | # and generally don't have any Chinese data in them (there are Chinese 205 | # characters in the vocabulary because Wikipedia does have some Chinese 206 | # words in the English Wikipedia.). 207 | text = self._tokenize_chinese_chars(text) 208 | 209 | orig_tokens = whitespace_tokenize(text) 210 | split_tokens = [] 211 | for token in orig_tokens: 212 | if self.do_lower_case: 213 | token = token.lower() 214 | token = self._run_strip_accents(token) 215 | split_tokens.extend(self._run_split_on_punc(token)) 216 | 217 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 218 | return output_tokens 219 | 220 | def _run_strip_accents(self, text): 221 | """Strips accents from a piece of text.""" 222 | text = unicodedata.normalize("NFD", text) 223 | output = [] 224 | for char in text: 225 | cat = unicodedata.category(char) 226 | if cat == "Mn": 227 | continue 228 | output.append(char) 229 | return "".join(output) 230 | 231 | def _run_split_on_punc(self, text): 232 | """Splits punctuation on a piece of text.""" 233 | chars = list(text) 234 | i = 0 235 | start_new_word = True 236 | output = [] 237 | while i < len(chars): 238 | char = chars[i] 239 | if _is_punctuation(char): 240 | output.append([char]) 241 | start_new_word = True 242 | else: 243 | if start_new_word: 244 | output.append([]) 245 | start_new_word = False 246 | output[-1].append(char) 247 | i += 1 248 | 249 | return ["".join(x) for x in output] 250 | 251 | def _tokenize_chinese_chars(self, text): 252 | """Adds whitespace around any CJK character.""" 253 | output = [] 254 | for char in text: 255 | cp = ord(char) 256 | if self._is_chinese_char(cp): 257 | output.append(" ") 258 | output.append(char) 259 | output.append(" ") 260 | else: 261 | output.append(char) 262 | return "".join(output) 263 | 264 | def _is_chinese_char(self, cp): 265 | """Checks whether CP is the codepoint of a CJK character.""" 266 | # This defines a "chinese character" as anything in the CJK Unicode block: 267 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 268 | # 269 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 270 | # despite its name. The modern Korean Hangul alphabet is a different block, 271 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 272 | # space-separated words, so they are not treated specially and handled 273 | # like the all of the other languages. 
274 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 275 | (cp >= 0x3400 and cp <= 0x4DBF) or # 276 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 277 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 278 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 279 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 280 | (cp >= 0xF900 and cp <= 0xFAFF) or # 281 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 282 | return True 283 | 284 | return False 285 | 286 | def _clean_text(self, text): 287 | """Performs invalid character removal and whitespace cleanup on text.""" 288 | output = [] 289 | for char in text: 290 | cp = ord(char) 291 | if cp == 0 or cp == 0xfffd or _is_control(char): 292 | continue 293 | if _is_whitespace(char): 294 | output.append(" ") 295 | else: 296 | output.append(char) 297 | return "".join(output) 298 | 299 | 300 | class WordpieceTokenizer(object): 301 | """Runs WordPiece tokenziation.""" 302 | 303 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 304 | self.vocab = vocab 305 | self.unk_token = unk_token 306 | self.max_input_chars_per_word = max_input_chars_per_word 307 | 308 | def tokenize(self, text): 309 | """Tokenizes a piece of text into its word pieces. 310 | 311 | This uses a greedy longest-match-first algorithm to perform tokenization 312 | using the given vocabulary. 313 | 314 | For example: 315 | input = "unaffable" 316 | output = ["un", "##aff", "##able"] 317 | 318 | Args: 319 | text: A single token or whitespace separated tokens. This should have 320 | already been passed through `BasicTokenizer. 321 | 322 | Returns: 323 | A list of wordpiece tokens. 324 | """ 325 | 326 | text = convert_to_unicode(text) 327 | 328 | output_tokens = [] 329 | for token in whitespace_tokenize(text): 330 | chars = list(token) 331 | if len(chars) > self.max_input_chars_per_word: 332 | output_tokens.append(self.unk_token) 333 | continue 334 | 335 | is_bad = False 336 | start = 0 337 | sub_tokens = [] 338 | while start < len(chars): 339 | end = len(chars) 340 | cur_substr = None 341 | while start < end: 342 | substr = "".join(chars[start:end]) 343 | if start > 0: 344 | substr = "##" + substr 345 | if substr in self.vocab: 346 | cur_substr = substr 347 | break 348 | end -= 1 349 | if cur_substr is None: 350 | is_bad = True 351 | break 352 | sub_tokens.append(cur_substr) 353 | start = end 354 | 355 | if is_bad: 356 | output_tokens.append(self.unk_token) 357 | else: 358 | output_tokens.extend(sub_tokens) 359 | return output_tokens 360 | 361 | 362 | def _is_whitespace(char): 363 | """Checks whether `chars` is a whitespace character.""" 364 | # \t, \n, and \r are technically contorl characters but we treat them 365 | # as whitespace since they are generally considered as such. 366 | if char == " " or char == "\t" or char == "\n" or char == "\r": 367 | return True 368 | cat = unicodedata.category(char) 369 | if cat == "Zs": 370 | return True 371 | return False 372 | 373 | 374 | def _is_control(char): 375 | """Checks whether `chars` is a control character.""" 376 | # These are technically control characters but we count them as whitespace 377 | # characters. 378 | if char == "\t" or char == "\n" or char == "\r": 379 | return False 380 | cat = unicodedata.category(char) 381 | if cat in ("Cc", "Cf"): 382 | return True 383 | return False 384 | 385 | 386 | def _is_punctuation(char): 387 | """Checks whether `chars` is a punctuation character.""" 388 | cp = ord(char) 389 | # We treat all non-letter/number ASCII as punctuation. 
390 | # Characters such as "^", "$", and "`" are not in the Unicode 391 | # Punctuation class but we treat them as punctuation anyways, for 392 | # consistency. 393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 395 | return True 396 | cat = unicodedata.category(char) 397 | if cat.startswith("P"): 398 | return True 399 | return False 400 | -------------------------------------------------------------------------------- /modeling_layoutlm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import CrossEntropyLoss, MSELoss 4 | from transformers import BertPreTrainedModel 5 | from transformers.modeling_bert import ( 6 | BertEncoder, 7 | BertPooler, 8 | BertLayerNorm, 9 | ) 10 | 11 | 12 | class LayoutLMEmbeddings(nn.Module): 13 | """Construct the embeddings from word, position and token_type embeddings. 14 | """ 15 | 16 | def __init__(self, config): 17 | super(LayoutLMEmbeddings, self).__init__() 18 | #print("Word Embedding",config.vocab_size, config.hidden_size) 19 | self.word_embeddings = nn.Embedding( 20 | config.vocab_size, config.hidden_size, padding_idx=0 21 | ) 22 | self.position_embeddings = nn.Embedding( 23 | config.max_position_embeddings, config.hidden_size 24 | ) 25 | self.x_position_embeddings = nn.Embedding( 26 | config.max_2d_position_embeddings, config.hidden_size 27 | ) 28 | self.y_position_embeddings = nn.Embedding( 29 | config.max_2d_position_embeddings, config.hidden_size 30 | ) 31 | self.h_position_embeddings = nn.Embedding( 32 | config.max_2d_position_embeddings, config.hidden_size 33 | ) 34 | self.w_position_embeddings = nn.Embedding( 35 | config.max_2d_position_embeddings, config.hidden_size 36 | ) 37 | self.token_type_embeddings = nn.Embedding( 38 | config.type_vocab_size, config.hidden_size 39 | ) 40 | 41 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 42 | # any TensorFlow checkpoint file 43 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 44 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 45 | 46 | def forward( 47 | self, input_ids, bbox, token_type_ids=None, position_ids=None, 48 | ): 49 | seq_length = input_ids.size(1) 50 | if position_ids is None: 51 | position_ids = torch.arange( 52 | seq_length, dtype=torch.long, device=input_ids.device 53 | ) 54 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 55 | if token_type_ids is None: 56 | token_type_ids = torch.zeros_like(input_ids) 57 | 58 | words_embeddings = self.word_embeddings(input_ids) 59 | position_embeddings = self.position_embeddings(position_ids) 60 | left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0]) 61 | upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1]) 62 | right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2]) 63 | lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3]) 64 | h_position_embeddings = self.h_position_embeddings( 65 | bbox[:, :, 3] - bbox[:, :, 1] 66 | ) 67 | w_position_embeddings = self.w_position_embeddings( 68 | bbox[:, :, 2] - bbox[:, :, 0] 69 | ) 70 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 71 | #import pdb; pdb.set_trace() 72 | embeddings = ( 73 | words_embeddings 74 | + position_embeddings 75 | + left_position_embeddings 76 | + upper_position_embeddings 77 | + right_position_embeddings 78 | + lower_position_embeddings 79 
| + h_position_embeddings 80 | + w_position_embeddings 81 | + token_type_embeddings 82 | ) 83 | #print("Before lYERNORMEmbeddings[i].shape") 84 | #for emb in embeddings: 85 | # print(emb.shape) 86 | embeddings = self.LayerNorm(embeddings) 87 | #print("Embeddings[i].shape") 88 | #for emb in embeddings: 89 | # print(emb.shape) 90 | embeddings = self.dropout(embeddings) 91 | return embeddings 92 | 93 | 94 | class LayoutLMModel(BertPreTrainedModel): 95 | r""" 96 | Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: 97 | **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)`` 98 | Sequence of hidden-states at the output of the last layer of the model. 99 | **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)`` 100 | Last layer hidden-state of the first token of the sequence (classification token) 101 | further processed by a Linear layer and a Tanh activation function. The Linear 102 | layer weights are trained from the next sentence prediction (classification) 103 | objective during Bert pretraining. This output is usually *not* a good summary 104 | of the semantic content of the input, you're often better with averaging or pooling 105 | the sequence of hidden-states for the whole input sequence. 106 | **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) 107 | list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) 108 | of shape ``(batch_size, sequence_length, hidden_size)``: 109 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 110 | **attentions**: (`optional`, returned when ``config.output_attentions=True``) 111 | list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: 112 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. 113 | 114 | """ 115 | 116 | def __init__(self, config): 117 | super(LayoutLMModel, self).__init__(config) 118 | 119 | self.embeddings = LayoutLMEmbeddings(config) 120 | self.encoder = BertEncoder(config) 121 | self.pooler = BertPooler(config) 122 | 123 | self.init_weights() 124 | 125 | def _resize_token_embeddings(self, new_num_tokens): 126 | old_embeddings = self.embeddings.word_embeddings 127 | new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens) 128 | self.embeddings.word_embeddings = new_embeddings 129 | return self.embeddings.word_embeddings 130 | 131 | def _prune_heads(self, heads_to_prune): 132 | """ Prunes heads of the model. 133 | heads_to_prune: dict of {layer_num: list of heads to prune in this layer} 134 | See base class PreTrainedModel 135 | """ 136 | for layer, heads in heads_to_prune.items(): 137 | self.encoder.layer[layer].attention.prune_heads(heads) 138 | 139 | def forward( 140 | self, 141 | input_ids, 142 | bbox, 143 | attention_mask=None, 144 | token_type_ids=None, 145 | position_ids=None, 146 | head_mask=None, 147 | inputs_embeds=None, 148 | encoder_hidden_states=None, 149 | encoder_attention_mask=None, 150 | ): 151 | if attention_mask is None: 152 | attention_mask = torch.ones_like(input_ids) 153 | if token_type_ids is None: 154 | token_type_ids = torch.zeros_like(input_ids) 155 | 156 | # We create a 3D attention mask from a 2D tensor mask. 
157 | # Sizes are [batch_size, 1, 1, to_seq_length] 158 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 159 | # this attention mask is more simple than the triangular masking of causal attention 160 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 161 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 162 | 163 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 164 | # masked positions, this operation will create a tensor which is 0.0 for 165 | # positions we want to attend and -10000.0 for masked positions. 166 | # Since we are adding it to the raw scores before the softmax, this is 167 | # effectively the same as removing these entirely. 168 | extended_attention_mask = extended_attention_mask.to( 169 | dtype=next(self.parameters()).dtype 170 | ) # fp16 compatibility 171 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 172 | 173 | # Prepare head mask if needed 174 | # 1.0 in head_mask indicate we keep the head 175 | # attention_probs has shape bsz x n_heads x N x N 176 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] 177 | # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] 178 | if head_mask is not None: 179 | if head_mask.dim() == 1: 180 | head_mask = ( 181 | head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) 182 | ) 183 | head_mask = head_mask.expand( 184 | self.config.num_hidden_layers, -1, -1, -1, -1 185 | ) 186 | elif head_mask.dim() == 2: 187 | head_mask = ( 188 | head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) 189 | ) # We can specify head_mask for each layer 190 | head_mask = head_mask.to( 191 | dtype=next(self.parameters()).dtype 192 | ) # switch to fload if need + fp16 compatibility 193 | else: 194 | head_mask = [None] * self.config.num_hidden_layers 195 | 196 | embedding_output = self.embeddings( 197 | input_ids, bbox, position_ids=position_ids, token_type_ids=token_type_ids, 198 | ) 199 | #print("embedding_output",embedding_output.shape) 200 | encoder_outputs = self.encoder( 201 | embedding_output, extended_attention_mask, head_mask=head_mask 202 | ) 203 | #print("encoder_outputs") 204 | #print([x.shape for x in encoder_outputs]) 205 | sequence_output = encoder_outputs[0] 206 | pooled_output = self.pooler(sequence_output) 207 | #print("sequence_output=encoder_outputs[0]",sequence_output.shape) 208 | #print("pooled_output",pooled_output.shape) 209 | #print("encoder_outputs",encoder_outputs) 210 | outputs = (sequence_output, pooled_output) + encoder_outputs[ 211 | 1: 212 | ] # add hidden_states and attentions if they are here 213 | #print("Final outputs",outputs[0].shape,outputs[1].shape) 214 | return outputs # sequence_output, pooled_output, (hidden_states), (attentions) 215 | 216 | 217 | class LayoutLMForTokenClassification(BertPreTrainedModel): 218 | r""" 219 | **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: 220 | Labels for computing the token classification loss. 221 | Indices should be in ``[0, ..., config.num_labels - 1]``. 222 | 223 | Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: 224 | **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: 225 | Classification loss. 226 | **scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)`` 227 | Classification scores (before SoftMax). 
228 | **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) 229 | list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) 230 | of shape ``(batch_size, sequence_length, hidden_size)``: 231 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 232 | **attentions**: (`optional`, returned when ``config.output_attentions=True``) 233 | list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: 234 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. 235 | 236 | """ 237 | 238 | def __init__(self, config): 239 | super(LayoutLMForTokenClassification, self).__init__(config) 240 | self.num_labels = config.num_labels 241 | self.bert = LayoutLMModel(config) 242 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 243 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 244 | 245 | self.init_weights() 246 | 247 | def forward( 248 | self, 249 | input_ids, 250 | bbox, 251 | attention_mask=None, 252 | token_type_ids=None, 253 | position_ids=None, 254 | head_mask=None, 255 | inputs_embeds=None, 256 | start_positions=None, 257 | end_positions=None, 258 | ): 259 | 260 | #print("Input IDs",input_ids.shape) 261 | #print("BBox",bbox.shape) 262 | outputs = self.bert( 263 | input_ids=input_ids, 264 | bbox=bbox, 265 | attention_mask=attention_mask, 266 | token_type_ids=token_type_ids, 267 | position_ids=position_ids, 268 | head_mask=head_mask, 269 | ) 270 | #print("LayoutLMModel Model Output:",outputs[0].shape,) 271 | sequence_output = outputs[0] 272 | sequence_output = self.dropout(sequence_output) 273 | logits = self.classifier(sequence_output) 274 | 275 | start_logits, end_logits = logits.split(1, dim=-1) 276 | start_logits = start_logits.squeeze(-1) 277 | end_logits = end_logits.squeeze(-1) 278 | outputs = (start_logits, end_logits,) + outputs[2:] 279 | 280 | if start_positions is not None and end_positions is not None: 281 | # If we are on multi-GPU, split add a dimension 282 | if len(start_positions.size()) > 1: 283 | start_positions = start_positions.squeeze(-1) 284 | if len(end_positions.size()) > 1: 285 | end_positions = end_positions.squeeze(-1) 286 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 287 | ignored_index = start_logits.size(1) 288 | start_positions.clamp_(0, ignored_index) 289 | end_positions.clamp_(0, ignored_index) 290 | 291 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 292 | start_loss = loss_fct(start_logits, start_positions) 293 | end_loss = loss_fct(end_logits, end_positions) 294 | total_loss = (start_loss + end_loss) / 2 295 | outputs = (total_loss,) + outputs 296 | ''' 297 | #print("logits= self.classifier(sequence_output)",logits.shape) 298 | outputs = (logits,) + outputs[ 299 | 2: 300 | ] # add hidden states and attention if they are here 301 | if labels is not None: 302 | #print('label',labels) 303 | #print('logits',logits) 304 | #print('label',labels.shape) 305 | #print('logits',logits.shape) 306 | loss_fct = CrossEntropyLoss() 307 | # Only keep active parts of the loss 308 | if attention_mask is not None: 309 | active_loss = attention_mask.view(-1) == 1 310 | active_logits = logits.view(-1, self.num_labels)[active_loss] 311 | active_labels = labels.view(-1)[active_loss] 312 | loss = loss_fct(active_logits, active_labels) 313 | else: 314 | loss = loss_fct(logits.view(-1, 
self.num_labels), labels.view(-1)) 315 | outputs = (loss,) + outputs 316 | ''' 317 | return outputs # (loss), scores, (hidden_states), (attentions) 318 | 319 | 320 | class LayoutLMForSequenceClassification(BertPreTrainedModel): 321 | r""" 322 | **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: 323 | Labels for computing the sequence classification/regression loss. 324 | Indices should be in ``[0, ..., config.num_labels - 1]``. 325 | If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss), 326 | If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy). 327 | 328 | Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: 329 | **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: 330 | Classification (or regression if config.num_labels==1) loss. 331 | **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)`` 332 | Classification (or regression if config.num_labels==1) scores (before SoftMax). 333 | **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) 334 | list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) 335 | of shape ``(batch_size, sequence_length, hidden_size)``: 336 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 337 | **attentions**: (`optional`, returned when ``config.output_attentions=True``) 338 | list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: 339 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. 340 | 341 | """ 342 | 343 | def __init__(self, config): 344 | super(LayoutLMForSequenceClassification, self).__init__(config) 345 | self.num_labels = config.num_labels 346 | 347 | self.bert = LayoutLMModel(config) 348 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 349 | self.classifier = nn.Linear(config.hidden_size, self.config.num_labels) 350 | 351 | self.init_weights() 352 | 353 | def forward( 354 | self, 355 | input_ids, 356 | bbox, 357 | attention_mask=None, 358 | token_type_ids=None, 359 | position_ids=None, 360 | head_mask=None, 361 | inputs_embeds=None, 362 | labels=None, 363 | ): 364 | 365 | outputs = self.bert( 366 | input_ids=input_ids, 367 | bbox=bbox, 368 | attention_mask=attention_mask, 369 | token_type_ids=token_type_ids, 370 | position_ids=position_ids, 371 | head_mask=head_mask, 372 | ) 373 | 374 | pooled_output = outputs[1] 375 | print("pooled_output",pooled_output.shape) 376 | pooled_output = self.dropout(pooled_output) 377 | logits = self.classifier(pooled_output) 378 | print("logits",logits.shape) 379 | outputs = (logits,) + outputs[ 380 | 2: 381 | ] # add hidden states and attention if they are here 382 | 383 | if labels is not None: 384 | if self.num_labels == 1: 385 | # We are doing regression 386 | loss_fct = MSELoss() 387 | loss = loss_fct(logits.view(-1), labels.view(-1)) 388 | else: 389 | loss_fct = CrossEntropyLoss() 390 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 391 | outputs = (loss,) + outputs 392 | 393 | return outputs # (loss), logits, (hidden_states), (attentions) 394 | -------------------------------------------------------------------------------- /utils_docvqa.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 
The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Named entity recognition fine-tuning: utilities to work with CoNLL-2003 task. """ 17 | from __future__ import absolute_import, division, print_function 18 | import warnings 19 | warnings.filterwarnings('ignore') 20 | 21 | import logging 22 | import os 23 | from io import open 24 | import tensorflow as tf 25 | import json 26 | import sys 27 | import collections 28 | import six 29 | import tokenization 30 | logger = logging.getLogger(__name__) 31 | 32 | 33 | 34 | 35 | class DocvqaExample(object): 36 | """A single training/test example for token classification.""" 37 | 38 | 39 | def __init__(self, 40 | qas_id, 41 | question_text, 42 | doc_tokens, 43 | orig_answer_text=None, 44 | start_position=None, 45 | end_position=None, 46 | is_impossible=False, 47 | boxes = []): 48 | self.qas_id = qas_id 49 | self.question_text = question_text 50 | self.doc_tokens = doc_tokens 51 | self.orig_answer_text = orig_answer_text 52 | self.start_position = start_position 53 | self.end_position = end_position 54 | self.is_impossible = is_impossible 55 | self.boxes = boxes 56 | 57 | 58 | class InputFeatures(object): 59 | """A single set of features of data.""" 60 | 61 | 62 | def __init__(self, 63 | unique_id, 64 | qas_id, 65 | example_index, 66 | doc_span_index, 67 | tokens, 68 | token_to_orig_map, 69 | token_is_max_context, 70 | input_ids, 71 | input_mask, 72 | segment_ids, 73 | start_positions=None, 74 | end_positions=None, 75 | is_impossible=None, 76 | boxes = None, 77 | p_mask =None): 78 | self.unique_id = unique_id 79 | self.qas_id = qas_id 80 | self.example_index = example_index 81 | self.doc_span_index = doc_span_index 82 | self.tokens = tokens 83 | self.token_to_orig_map = token_to_orig_map 84 | self.token_is_max_context = token_is_max_context 85 | self.input_ids = input_ids 86 | self.input_mask = input_mask 87 | self.segment_ids = segment_ids 88 | self.start_positions = start_positions 89 | self.end_positions = end_positions 90 | self.is_impossible = is_impossible 91 | self.boxes = boxes 92 | self.p_mask = p_mask 93 | def read_docvqa_examples(input_file, is_training, skip_match_answers=True): 94 | """Read a SQuAD json file into a list of SquadExample.""" 95 | with tf.gfile.Open(input_file, "r") as reader: 96 | input_data = json.load(reader) 97 | 98 | def is_whitespace(c): 99 | if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: 100 | return True 101 | return False 102 | count_match = 0 103 | count_nomatch = 0 104 | 105 | examples = [] 106 | for paragraph in input_data: 107 | image_id = paragraph["image_id"] 108 | paragraph_text = paragraph["context"] 109 | boxes = paragraph["boxes"] 110 | doc_tokens = paragraph["context"] 111 | for qa in paragraph["qas"]: 112 | qas_id = qa["qid"] 113 | question_text = qa["question"] 114 | start_position = None 115 | end_position 
= None 116 | orig_answer_text = None 117 | is_impossible = False 118 | answer = qa["answer"][0] 119 | orig_answer_text = answer["text"] 120 | if is_training: 121 | if not is_impossible: 122 | answer = qa["answer"][0] 123 | orig_answer_text = answer["text"] 124 | # Only add answers where the text can be exactly recovered from the 125 | # document. If this CAN'T happen it's likely due to weird Unicode 126 | # stuff so we will just skip the example. 127 | # 128 | # Note that this means for training mode, every example is NOT 129 | # guaranteed to be preserved. 130 | start_position = qa["answer"][0]["answer_start"] 131 | end_position = qa["answer"][0]["answer_end"] 132 | actual_text = " ".join( 133 | doc_tokens[start_position:(end_position + 1)]) 134 | cleaned_answer_text = " ".join( 135 | tokenization.whitespace_tokenize(orig_answer_text)) 136 | if not skip_match_answers: 137 | if actual_text.find(cleaned_answer_text) == -1: 138 | tf.logging.warning("Could not find answer: '%s' vs. '%s'", 139 | actual_text, cleaned_answer_text) 140 | count_nomatch+=1 141 | continue 142 | count_match+=1 143 | else: 144 | start_position = -1 145 | end_position = -1 146 | orig_answer_text = "" 147 | 148 | example =DocvqaExample( 149 | qas_id=qas_id, 150 | question_text=question_text, 151 | doc_tokens=doc_tokens, 152 | orig_answer_text=orig_answer_text, 153 | start_position=start_position, 154 | end_position=end_position, 155 | is_impossible=is_impossible, 156 | boxes=boxes) 157 | examples.append(example) 158 | return examples 159 | 160 | 161 | def convert_examples_to_features(examples,label_list, tokenizer, max_seq_length, 162 | doc_stride, max_query_length, is_training, 163 | pad_token_label_id=-100): 164 | """Loads a data file into a list of `InputBatch`s.""" 165 | 166 | unique_id = 1000000000 167 | features = [] 168 | label_map = {label: i for i, label in enumerate(label_list)} 169 | query_label_ids = [] 170 | for (example_index, example) in enumerate(examples): 171 | query_tokens = tokenizer.tokenize(example.question_text) 172 | if len(query_tokens) > max_query_length: 173 | query_tokens = query_tokens[0:max_query_length] 174 | query_label_ids=[0]+[pad_token_label_id] * (len(query_tokens) - 1) 175 | 176 | tok_to_orig_index = [] 177 | orig_to_tok_index = [] 178 | all_doc_tokens = [] 179 | all_doc_boxes_tokens = [] 180 | cls_token_box=[0, 0, 0, 0] 181 | sep_token_box=[1000, 1000, 1000, 1000] 182 | pad_token_box=[0, 0, 0, 0] 183 | ques_token_box=[0, 0, 0, 0] 184 | all_label_ids = [] 185 | for (i, token) in enumerate(example.doc_tokens): 186 | orig_to_tok_index.append(len(all_doc_tokens)) 187 | sub_tokens = tokenizer.tokenize(token) 188 | box = example.boxes[i] 189 | if i == example.start_position: 190 | lab = 1 191 | elif i == example.end_position: 192 | lab = 2 193 | else: 194 | lab = 0 195 | p = [lab] + [pad_token_label_id] * (len(sub_tokens) - 1) 196 | all_label_ids+=p 197 | for sub_token in sub_tokens: 198 | tok_to_orig_index.append(i) 199 | all_doc_tokens.append(sub_token) 200 | all_doc_boxes_tokens.append(box) 201 | #p = [lab] + [pad_token_label_id] * (len(sub_tokens) - 1) 202 | #all_label_ids+=p 203 | 204 | 205 | 206 | tok_start_position = None 207 | tok_end_position = None 208 | if is_training and example.is_impossible: 209 | tok_start_position = -1 210 | tok_end_position = -1 211 | if is_training and not example.is_impossible: 212 | tok_start_position = orig_to_tok_index[example.start_position] 213 | if example.end_position < len(example.doc_tokens) - 1: 214 | tok_end_position = 
orig_to_tok_index[example.end_position + 1] - 1 215 | else: 216 | tok_end_position = len(all_doc_tokens) - 1 217 | (tok_start_position, tok_end_position) = _improve_answer_span( 218 | all_doc_tokens, tok_start_position, tok_end_position, tokenizer, 219 | example.orig_answer_text) 220 | 221 | # The -3 accounts for [CLS], [SEP] and [SEP] 222 | max_tokens_for_doc = max_seq_length - len(query_tokens) - 3 223 | 224 | # We can have documents that are longer than the maximum sequence length. 225 | # To deal with this we do a sliding window approach, where we take chunks 226 | # of the up to our max length with a stride of `doc_stride`. 227 | _DocSpan = collections.namedtuple( # pylint: disable=invalid-name 228 | "DocSpan", ["start", "length"]) 229 | doc_spans = [] 230 | start_offset = 0 231 | while start_offset < len(all_doc_tokens): 232 | length = len(all_doc_tokens) - start_offset 233 | if length > max_tokens_for_doc: 234 | length = max_tokens_for_doc 235 | doc_spans.append(_DocSpan(start=start_offset, length=length)) 236 | if start_offset + length == len(all_doc_tokens): 237 | break 238 | start_offset += min(length, doc_stride) 239 | 240 | #TODO Remove later 241 | #if len(doc_spans)>1: 242 | # continue 243 | 244 | for (doc_span_index, doc_span) in enumerate(doc_spans): 245 | tokens = [] 246 | boxes_tokens = [] 247 | label_ids = [] 248 | token_to_orig_map = {} 249 | token_is_max_context = {} 250 | segment_ids = [] 251 | p_mask = [] 252 | tokens.append("[CLS]") 253 | p_mask.append(0) 254 | boxes_tokens.append(cls_token_box) 255 | segment_ids.append(0) 256 | label_ids.append(0) 257 | for token in query_tokens: 258 | tokens.append(token) 259 | boxes_tokens.append(ques_token_box) 260 | segment_ids.append(0) 261 | p_mask.append(1) 262 | label_ids=label_ids+query_label_ids 263 | tokens.append("[SEP]") 264 | p_mask.append(1) 265 | boxes_tokens.append(sep_token_box) 266 | segment_ids.append(0) 267 | label_ids.append(0) 268 | for i in range(doc_span.length): 269 | split_token_index = doc_span.start + i 270 | token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index] 271 | 272 | is_max_context = _check_is_max_context(doc_spans, doc_span_index, 273 | split_token_index) 274 | token_is_max_context[len(tokens)] = is_max_context 275 | tokens.append(all_doc_tokens[split_token_index]) 276 | label_ids.append(all_label_ids[split_token_index]) 277 | boxes_tokens.append(all_doc_boxes_tokens[split_token_index]) 278 | segment_ids.append(1) 279 | p_mask.append(0) 280 | tokens.append("[SEP]") 281 | p_mask.append(1) 282 | boxes_tokens.append(sep_token_box) 283 | segment_ids.append(1) 284 | label_ids.append(pad_token_label_id) 285 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 286 | 287 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 288 | # tokens are attended to. 289 | input_mask = [1] * len(input_ids) 290 | # Zero-pad up to the sequence length. 
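# The six parallel lists built above -- input_ids, input_mask, segment_ids,
# boxes_tokens, label_ids and p_mask -- must all end up exactly max_seq_length
# long. The loop below pads them in lockstep (ids, mask and segment ids with 0,
# boxes with pad_token_box, labels with pad_token_label_id, p_mask with 1), and
# the assertions that follow check that invariant.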
291 | while len(input_ids) < max_seq_length: 292 | input_ids.append(0) 293 | input_mask.append(0) 294 | segment_ids.append(0) 295 | boxes_tokens.append(pad_token_box) 296 | label_ids.append(pad_token_label_id) 297 | p_mask.append(1) 298 | 299 | assert len(input_ids) == max_seq_length 300 | assert len(input_mask) == max_seq_length 301 | assert len(segment_ids) == max_seq_length 302 | assert len(boxes_tokens) == max_seq_length 303 | assert len(label_ids) == max_seq_length 304 | assert len(p_mask) == max_seq_length 305 | 306 | start_position = None 307 | end_position = None 308 | if is_training and not example.is_impossible: 309 | # For training, if our document chunk does not contain an annotation 310 | # we throw it out, since there is nothing to predict. 311 | doc_start = doc_span.start 312 | doc_end = doc_span.start + doc_span.length - 1 313 | out_of_span = False 314 | if not (tok_start_position >= doc_start and 315 | tok_end_position <= doc_end): 316 | out_of_span = True 317 | if out_of_span: 318 | start_position = 0 319 | end_position = 0 320 | else: 321 | doc_offset = len(query_tokens) + 2 322 | start_position = tok_start_position - doc_start + doc_offset 323 | end_position = tok_end_position - doc_start + doc_offset 324 | 325 | if is_training and example.is_impossible: 326 | start_position = 0 327 | end_position = 0 328 | #label_ids = [-1]*max_seq_length 329 | #if is_training and (start_position!=0 and end_position!=0): 330 | # label_ids[start_position]=0 331 | # label_ids[end_position]=1 332 | 333 | if example_index < 20: 334 | tf.logging.info("*** Example ***") 335 | tf.logging.info("unique_id: %s" % (unique_id)) 336 | tf.logging.info("example_index: %s" % (example_index)) 337 | tf.logging.info("doc_span_index: %s" % (doc_span_index)) 338 | tf.logging.info("tokens: %s" % " ".join( 339 | [tokenization.printable_text(x) for x in tokens])) 340 | tf.logging.info("token_to_orig_map: %s" % " ".join( 341 | ["%d:%d" % (x, y) for (x, y) in six.iteritems(token_to_orig_map)])) 342 | tf.logging.info("token_is_max_context: %s" % " ".join([ 343 | "%d:%s" % (x, y) for (x, y) in six.iteritems(token_is_max_context) 344 | ])) 345 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 346 | tf.logging.info( 347 | "input_mask: %s" % " ".join([str(x) for x in input_mask])) 348 | tf.logging.info( 349 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 350 | if is_training and example.is_impossible: 351 | tf.logging.info("impossible example") 352 | if is_training and not example.is_impossible: 353 | answer_text = " ".join(tokens[start_position:(end_position + 1)]) 354 | tf.logging.info("start_position: %d" % (start_position)) 355 | tf.logging.info("end_position: %d" % (end_position)) 356 | tf.logging.info( 357 | "answer: %s" % (tokenization.printable_text(answer_text))) 358 | feature = InputFeatures( 359 | unique_id=unique_id, 360 | qas_id=example.qas_id, 361 | example_index=example_index, 362 | doc_span_index=doc_span_index, 363 | tokens=tokens, 364 | token_to_orig_map=token_to_orig_map, 365 | token_is_max_context=token_is_max_context, 366 | input_ids=input_ids, 367 | input_mask=input_mask, 368 | segment_ids=segment_ids, 369 | start_positions=start_position, 370 | end_positions=end_position, 371 | is_impossible=example.is_impossible, 372 | boxes=boxes_tokens, 373 | p_mask = p_mask, 374 | ) 375 | features.append(feature) 376 | ''' 377 | print(feature) 378 | print('unique_id',feature.unique_id) 379 | print('tokens',feature.tokens) 380 | 
print('example_index',feature.example_index) 381 | print('input_ids',feature.input_ids) 382 | print('segment_ids',feature.segment_ids) 383 | print('doc_span_index',feature.doc_span_index) 384 | print('token_to_orig_map',feature.token_to_orig_map) 385 | print('token_is_max_contex',feature.token_is_max_context) 386 | print('input_mask',feature.input_mask) 387 | print('start_position',feature.start_position) 388 | print('end_position',feature.end_position) 389 | print('is_impossible',feature.is_impossible)''' 390 | # Run callback 391 | #output_fn(feature) 392 | 393 | unique_id += 1 394 | return features 395 | 396 | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, 397 | orig_answer_text): 398 | """Returns tokenized answer spans that better match the annotated answer.""" 399 | 400 | # The SQuAD annotations are character based. We first project them to 401 | # whitespace-tokenized words. But then after WordPiece tokenization, we can 402 | # often find a "better match". For example: 403 | # 404 | # Question: What year was John Smith born? 405 | # Context: The leader was John Smith (1895-1943). 406 | # Answer: 1895 407 | # 408 | # The original whitespace-tokenized answer will be "(1895-1943).". However 409 | # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match 410 | # the exact answer, 1895. 411 | # 412 | # However, this is not always possible. Consider the following: 413 | # 414 | # Question: What country is the top exporter of electornics? 415 | # Context: The Japanese electronics industry is the lagest in the world. 416 | # Answer: Japan 417 | # 418 | # In this case, the annotator chose "Japan" as a character sub-span of 419 | # the word "Japanese". Since our WordPiece tokenizer does not split 420 | # "Japanese", we just use "Japanese" as the annotation. This is fairly rare 421 | # in SQuAD, but does happen. 422 | tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text)) 423 | for new_start in range(input_start, input_end + 1): 424 | for new_end in range(input_end, new_start - 1, -1): 425 | text_span = " ".join(doc_tokens[new_start:(new_end + 1)]) 426 | if text_span == tok_answer_text: 427 | return (new_start, new_end) 428 | 429 | return (input_start, input_end) 430 | 431 | 432 | def _check_is_max_context(doc_spans, cur_span_index, position): 433 | """Check if this is the 'max context' doc span for the token.""" 434 | 435 | # Because of the sliding window approach taken to scoring documents, a single 436 | # token can appear in multiple documents. E.g. 437 | # Doc: the man went to the store and bought a gallon of milk 438 | # Span A: the man went to the 439 | # Span B: to the store and bought 440 | # Span C: and bought a gallon of 441 | # ... 442 | # 443 | # Now the word 'bought' will have two scores from spans B and C. We only 444 | # want to consider the score with "maximum context", which we define as 445 | # the *minimum* of its left and right context (the *sum* of left and 446 | # right context will always be the same, of course). 447 | # 448 | # In the example the maximum context for 'bought' would be span C since 449 | # it has 1 left context and 3 right context, while span B has 4 left context 450 | # and 0 right context. 
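# Concretely, the loop below scores each candidate span as
#   min(num_left_context, num_right_context) + 0.01 * doc_span.length.
# For 'bought' (token 7) that gives min(4, 0) + 0.01 * 5 = 0.05 in span B
# and min(1, 3) + 0.01 * 5 = 1.05 in span C, so span C is chosen as the
# max-context span for that token.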
451 | best_score = None 452 | best_span_index = None 453 | for (span_index, doc_span) in enumerate(doc_spans): 454 | end = doc_span.start + doc_span.length - 1 455 | if position < doc_span.start: 456 | continue 457 | if position > end: 458 | continue 459 | num_left_context = position - doc_span.start 460 | num_right_context = end - position 461 | score = min(num_left_context, num_right_context) + 0.01 * doc_span.length 462 | if best_score is None or score > best_score: 463 | best_score = score 464 | best_span_index = span_index 465 | 466 | return cur_span_index == best_span_index 467 | -------------------------------------------------------------------------------- /run_docvqa.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Fine-tuning the library models for named entity recognition on CoNLL-2003 (Bert or Roberta). """ 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | import argparse 21 | import glob 22 | import logging 23 | import os 24 | import shutil 25 | import random 26 | import json 27 | import timeit 28 | 29 | import numpy as np 30 | import torch 31 | 32 | from tensorboardX import SummaryWriter 33 | from torch.nn import CrossEntropyLoss 34 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset 35 | from torch.utils.data.distributed import DistributedSampler 36 | from tqdm import tqdm, trange 37 | from utils_docvqa import ( 38 | read_docvqa_examples, 39 | convert_examples_to_features) 40 | import sys 41 | from modeling_layoutlm import LayoutLMForTokenClassification 42 | from transformers import AdamW, get_linear_schedule_with_warmup 43 | from transformers import ( 44 | WEIGHTS_NAME, 45 | BertConfig, 46 | BertForTokenClassification, 47 | BertTokenizer, 48 | ) 49 | from transformers import RobertaConfig, RobertaForTokenClassification, RobertaTokenizer 50 | from transformers import ( 51 | DistilBertConfig, 52 | DistilBertForTokenClassification, 53 | DistilBertTokenizer, 54 | ) 55 | from transformers.data.metrics.squad_metrics import ( 56 | compute_predictions_log_probs, 57 | compute_predictions_logits, 58 | squad_evaluate, 59 | ) 60 | from transformers.data.processors.squad import SquadResult, SquadV1Processor, SquadV2Processor 61 | 62 | logger = logging.getLogger(__name__) 63 | 64 | 65 | 66 | MODEL_CLASSES = { 67 | "layoutlm": (BertConfig, LayoutLMForTokenClassification, BertTokenizer), 68 | } 69 | 70 | 71 | def set_seed(args): 72 | random.seed(args.seed) 73 | np.random.seed(args.seed) 74 | torch.manual_seed(args.seed) 75 | if args.n_gpu > 0: 76 | torch.cuda.manual_seed_all(args.seed) 77 | 78 | 79 | def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id): 80 | """ Train the model """ 81 | if 
args.local_rank in [-1, 0]: 82 | tb_writer = SummaryWriter() 83 | 84 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 85 | train_sampler = ( 86 | RandomSampler(train_dataset) 87 | if args.local_rank == -1 88 | else DistributedSampler(train_dataset) 89 | ) 90 | train_dataloader = DataLoader( 91 | train_dataset, sampler=train_sampler, batch_size=args.train_batch_size 92 | ) 93 | 94 | if args.max_steps > 0: 95 | t_total = args.max_steps 96 | args.num_train_epochs = ( 97 | args.max_steps 98 | // (len(train_dataloader) // args.gradient_accumulation_steps) 99 | + 1 100 | ) 101 | else: 102 | t_total = ( 103 | len(train_dataloader) 104 | // args.gradient_accumulation_steps 105 | * args.num_train_epochs 106 | ) 107 | 108 | # Prepare optimizer and schedule (linear warmup and decay) 109 | no_decay = ["bias", "LayerNorm.weight"] 110 | optimizer_grouped_parameters = [ 111 | { 112 | "params": [ 113 | p 114 | for n, p in model.named_parameters() 115 | if not any(nd in n for nd in no_decay) 116 | ], 117 | "weight_decay": args.weight_decay, 118 | }, 119 | { 120 | "params": [ 121 | p 122 | for n, p in model.named_parameters() 123 | if any(nd in n for nd in no_decay) 124 | ], 125 | "weight_decay": 0.0, 126 | }, 127 | ] 128 | optimizer = AdamW( 129 | optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon 130 | ) 131 | scheduler = get_linear_schedule_with_warmup( 132 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total 133 | ) 134 | if args.fp16: 135 | try: 136 | from apex import amp 137 | except ImportError: 138 | raise ImportError( 139 | "Please install apex from https://www.github.com/nvidia/apex to use fp16 training." 140 | ) 141 | model, optimizer = amp.initialize( 142 | model, optimizer, opt_level=args.fp16_opt_level 143 | ) 144 | 145 | # multi-gpu training (should be after apex fp16 initialization) 146 | if args.n_gpu > 1: 147 | model = torch.nn.DataParallel(model) 148 | 149 | # Distributed training (should be after apex fp16 initialization) 150 | if args.local_rank != -1: 151 | model = torch.nn.parallel.DistributedDataParallel( 152 | model, 153 | device_ids=[args.local_rank], 154 | output_device=args.local_rank, 155 | find_unused_parameters=True, 156 | ) 157 | 158 | # Train! 159 | logger.info("***** Running training *****") 160 | logger.info(" Num examples = %d", len(train_dataset)) 161 | logger.info(" Num Epochs = %d", args.num_train_epochs) 162 | logger.info( 163 | " Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size 164 | ) 165 | logger.info( 166 | " Total train batch size (w. 
parallel, distributed & accumulation) = %d", 167 | args.train_batch_size 168 | * args.gradient_accumulation_steps 169 | * (torch.distributed.get_world_size() if args.local_rank != -1 else 1), 170 | ) 171 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 172 | logger.info(" Total optimization steps = %d", t_total) 173 | 174 | global_step = 0 175 | tr_loss, logging_loss = 0.0, 0.0 176 | model.zero_grad() 177 | train_iterator = trange( 178 | int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0] 179 | ) 180 | set_seed(args) # Added here for reproductibility (even between python 2 and 3) 181 | for _ in train_iterator: 182 | epoch_iterator = tqdm( 183 | train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0] 184 | ) 185 | for step, batch in enumerate(epoch_iterator): 186 | model.train() 187 | if args.model_type != "layoutlm": 188 | batch = batch[:4] 189 | batch = tuple(t.to(args.device) for t in batch) 190 | inputs = { 191 | "input_ids": batch[0], 192 | "attention_mask": batch[1], 193 | "start_positions": batch[3], 194 | "end_positions":batch[4], 195 | } 196 | 197 | if args.model_type == "layoutlm": 198 | inputs["bbox"] = batch[5] 199 | if args.model_type != "distilbert": 200 | inputs["token_type_ids"] = ( 201 | batch[2] 202 | if args.model_type in ["bert", "xlnet", "layoutlm"] 203 | else None 204 | ) # XLM and RoBERTa don"t use segment_ids 205 | outputs = model(**inputs) 206 | loss = outputs[0] 207 | # model outputs are always tuple in pytorch-transformers (see doc) 208 | 209 | if args.n_gpu > 1: 210 | loss = loss.mean() # mean() to average on multi-gpu parallel training 211 | if args.gradient_accumulation_steps > 1: 212 | loss = loss / args.gradient_accumulation_steps 213 | 214 | if args.fp16: 215 | with amp.scale_loss(loss, optimizer) as scaled_loss: 216 | scaled_loss.backward() 217 | else: 218 | loss.backward() 219 | 220 | tr_loss += loss.item() 221 | if (step + 1) % args.gradient_accumulation_steps == 0: 222 | if args.fp16: 223 | torch.nn.utils.clip_grad_norm_( 224 | amp.master_params(optimizer), args.max_grad_norm 225 | ) 226 | else: 227 | torch.nn.utils.clip_grad_norm_( 228 | model.parameters(), args.max_grad_norm 229 | ) 230 | 231 | scheduler.step() # Update learning rate schedule 232 | optimizer.step() 233 | model.zero_grad() 234 | global_step += 1 235 | 236 | if ( 237 | args.local_rank in [-1, 0] 238 | and args.logging_steps > 0 239 | and global_step % args.logging_steps == 0 240 | ): 241 | # Log metrics 242 | if ( 243 | args.local_rank == -1 and args.evaluate_during_training 244 | ): # Only evaluate when single GPU otherwise metrics may not average well 245 | results = evaluate( 246 | args, 247 | model, 248 | tokenizer, 249 | labels, 250 | pad_token_label_id, 251 | mode="dev", 252 | ) 253 | for key, value in results.items(): 254 | tb_writer.add_scalar( 255 | "eval_{}".format(key), value, global_step 256 | ) 257 | tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step) 258 | tb_writer.add_scalar( 259 | "loss", 260 | (tr_loss - logging_loss) / args.logging_steps, 261 | global_step, 262 | ) 263 | logging_loss = tr_loss 264 | 265 | if ( 266 | args.local_rank in [-1, 0] 267 | and args.save_steps > 0 268 | and global_step % args.save_steps == 0 269 | ): 270 | # Save model checkpoint 271 | output_dir = os.path.join( 272 | args.output_dir, "checkpoint-{}".format(global_step) 273 | ) 274 | if not os.path.exists(output_dir): 275 | os.makedirs(output_dir) 276 | model_to_save = ( 277 | model.module if 
hasattr(model, "module") else model 278 | ) # Take care of distributed/parallel training 279 | model_to_save.save_pretrained(output_dir) 280 | torch.save(args, os.path.join(output_dir, "training_args.bin")) 281 | logger.info("Saving model checkpoint to %s", output_dir) 282 | 283 | if args.max_steps > 0 and global_step > args.max_steps: 284 | epoch_iterator.close() 285 | break 286 | if args.max_steps > 0 and global_step > args.max_steps: 287 | train_iterator.close() 288 | break 289 | 290 | if args.local_rank in [-1, 0]: 291 | tb_writer.close() 292 | 293 | return global_step, tr_loss / global_step 294 | 295 | def to_list(tensor): 296 | return tensor.detach().cpu().tolist() 297 | def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""): 298 | eval_dataset,features, examples = load_and_cache_examples( 299 | args, tokenizer, labels, pad_token_label_id, mode=mode 300 | ) 301 | 302 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 303 | # Note that DistributedSampler samples randomly 304 | eval_sampler = ( 305 | SequentialSampler(eval_dataset) 306 | if args.local_rank == -1 307 | else DistributedSampler(eval_dataset) 308 | ) 309 | eval_dataloader = DataLoader( 310 | eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size 311 | ) 312 | 313 | # Eval! 314 | logger.info("***** Running evaluation %s *****", prefix) 315 | logger.info(" Num examples = %d", len(eval_dataset)) 316 | logger.info(" Batch size = %d", args.eval_batch_size) 317 | eval_loss = 0.0 318 | nb_eval_steps = 0 319 | preds = None 320 | out_label_ids = None 321 | all_results = [] 322 | start_time = timeit.default_timer() 323 | model.eval() 324 | for batch in tqdm(eval_dataloader, desc="Evaluating"): 325 | batch = tuple(t.to(args.device) for t in batch) 326 | with torch.no_grad(): 327 | inputs = { 328 | "input_ids": batch[0], 329 | "attention_mask": batch[1], 330 | } 331 | inputs["bbox"] = batch[5] 332 | inputs["token_type_ids"] = (batch[6]) 333 | outputs = model(**inputs) 334 | example_indices = batch[7] 335 | for i, example_index in enumerate(example_indices): 336 | eval_feature = features[example_index.item()] 337 | unique_id = int(eval_feature.unique_id) 338 | 339 | output = [to_list(output[i]) for output in outputs] 340 | 341 | start_logits, end_logits = output 342 | result = SquadResult(unique_id, start_logits, end_logits) 343 | all_results.append(result) 344 | evalTime = timeit.default_timer() - start_time 345 | logger.info(" Evaluation done in total %f secs (%f sec per example)", evalTime, evalTime / len(eval_dataset)) 346 | output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix)) 347 | output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix)) 348 | output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix)) 349 | predictions = compute_predictions_logits( 350 | examples, 351 | features, 352 | all_results, 353 | 20, 354 | 30, 355 | args.do_lower_case, 356 | output_prediction_file, 357 | output_nbest_file, 358 | output_null_log_odds_file, 359 | True, 360 | True, 361 | 0.0, 362 | tokenizer, 363 | ) 364 | 365 | # Compute the F1 and exact scores. 
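# squad_evaluate aggregates the per-question predictions into a dict of
# metrics (e.g. "exact" and "f1"); during training these values are logged
# to TensorBoard as eval_<metric> by the caller.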
366 | results = squad_evaluate(examples, predictions) 367 | return results 368 | 369 | def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode): 370 | if args.local_rank not in [-1, 0] and not evaluate: 371 | torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache 372 | 373 | # Load data features from cache or dataset file 374 | cached_features_file = os.path.join( 375 | args.data_dir, 376 | "cached_{}_{}_{}".format( 377 | mode, 378 | list(filter(None, args.model_name_or_path.split("/"))).pop(), 379 | str(args.max_seq_length), 380 | ), 381 | ) 382 | if mode == 'train': 383 | train_examples = read_docvqa_examples(args.train_json,is_training=True,skip_match_answers=args.skip_match_answers) 384 | logger.info("Loading train json from %s", args.train_json) 385 | else: 386 | train_examples = read_docvqa_examples(args.val_json,is_training=True,skip_match_answers=args.skip_match_answers) 387 | logger.info("Loading val json from %s", args.val_json) 388 | if os.path.exists(cached_features_file) and not args.overwrite_cache: 389 | logger.info("Loading features from cached file %s", cached_features_file) 390 | features = torch.load(cached_features_file) 391 | else: 392 | logger.info("Creating features from dataset file at %s", args.data_dir) 393 | if mode == 'train': 394 | train_examples = read_docvqa_examples(args.train_json,is_training=True,skip_match_answers=args.skip_match_answers) 395 | 396 | else: 397 | train_examples = read_docvqa_examples(args.val_json,is_training=True,skip_match_answers=args.skip_match_answers) 398 | max_query_length = 64 399 | doc_stride = args.doc_stride 400 | features = convert_examples_to_features( 401 | examples=train_examples, 402 | label_list=labels, 403 | tokenizer=tokenizer, 404 | max_seq_length=args.max_seq_length, 405 | doc_stride=doc_stride, 406 | max_query_length=max_query_length, 407 | is_training=True, 408 | pad_token_label_id=pad_token_label_id) 409 | print("Features generated",mode) 410 | if args.local_rank in [-1, 0]: 411 | logger.info("Saving features into cached file %s", cached_features_file) 412 | torch.save(features, cached_features_file) 413 | 414 | if args.local_rank == 0 and not evaluate: 415 | torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache 416 | 417 | # Convert to Tensors and build dataset 418 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 419 | all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) 420 | all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) 421 | all_bboxes = torch.tensor([f.boxes for f in features], dtype=torch.long) 422 | all_start_positions = torch.tensor([f.start_positions for f in features], dtype=torch.long) 423 | all_end_positions = torch.tensor([f.end_positions for f in features], dtype=torch.long) 424 | all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.long) 425 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) 426 | dataset = TensorDataset( 427 | all_input_ids, all_input_mask, all_segment_ids, all_start_positions, all_end_positions, all_bboxes, all_p_mask,all_example_index 428 | ) 429 | return dataset, features, train_examples 430 | 431 | 432 | def main(): 433 | parser = argparse.ArgumentParser() 434 | 435 | ## Required parameters 436 | parser.add_argument( 437 | "--data_dir", 438 | 
default=None, 439 | type=str, 440 | required=True, 441 | help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.", 442 | ) 443 | parser.add_argument( 444 | "--model_type", 445 | default=None, 446 | type=str, 447 | required=True, 448 | help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()), 449 | ) 450 | parser.add_argument( 451 | "--model_name_or_path", 452 | default=None, 453 | type=str, 454 | required=True, 455 | help="Path to pre-trained model ", 456 | ) 457 | parser.add_argument( 458 | "--output_dir", 459 | default=None, 460 | type=str, 461 | required=True, 462 | help="The output directory where the model predictions and checkpoints will be written.", 463 | ) 464 | 465 | ## Other parameters 466 | parser.add_argument( 467 | "--config_name", 468 | default="", 469 | type=str, 470 | help="Pretrained config name or path if not the same as model_name", 471 | ) 472 | parser.add_argument( 473 | "--tokenizer_name", 474 | default="", 475 | type=str, 476 | help="Pretrained tokenizer name or path if not the same as model_name", 477 | ) 478 | parser.add_argument( 479 | "--cache_dir", 480 | default="", 481 | type=str, 482 | help="Where do you want to store the pre-trained models downloaded from s3", 483 | ) 484 | parser.add_argument( 485 | "--max_seq_length", 486 | default=128, 487 | type=int, 488 | help="The maximum total input sequence length after tokenization. Sequences longer " 489 | "than this will be truncated, sequences shorter will be padded.", 490 | ) 491 | parser.add_argument( 492 | "--doc_stride", 493 | default=128, 494 | type=int, 495 | help="Stride for documents with tokens more than max_seq_len", 496 | ) 497 | parser.add_argument( 498 | "--do_train", action="store_true", help="Whether to run training." 499 | ) 500 | parser.add_argument( 501 | "--do_eval", action="store_true", help="Whether to run eval on the dev set." 502 | ) 503 | parser.add_argument( 504 | "--do_predict", 505 | action="store_true", 506 | help="Whether to run predictions on the test set.", 507 | ) 508 | parser.add_argument( 509 | "--evaluate_during_training", 510 | action="store_true", 511 | help="Whether to run evaluation during training at each logging step.", 512 | ) 513 | parser.add_argument( 514 | "--do_lower_case", 515 | action="store_true", 516 | help="Set this flag if you are using an uncased model.", 517 | ) 518 | 519 | parser.add_argument( 520 | "--per_gpu_train_batch_size", 521 | default=8, 522 | type=int, 523 | help="Batch size per GPU/CPU for training.", 524 | ) 525 | parser.add_argument( 526 | "--per_gpu_eval_batch_size", 527 | default=8, 528 | type=int, 529 | help="Batch size per GPU/CPU for evaluation.", 530 | ) 531 | parser.add_argument( 532 | "--gradient_accumulation_steps", 533 | type=int, 534 | default=1, 535 | help="Number of updates steps to accumulate before performing a backward/update pass.", 536 | ) 537 | parser.add_argument( 538 | "--learning_rate", 539 | default=5e-5, 540 | type=float, 541 | help="The initial learning rate for Adam.", 542 | ) 543 | parser.add_argument( 544 | "--weight_decay", default=0.0, type=float, help="Weight decay if we apply some." 545 | ) 546 | parser.add_argument( 547 | "--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer." 548 | ) 549 | parser.add_argument( 550 | "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." 
551 | ) 552 | parser.add_argument( 553 | "--num_train_epochs", 554 | default=3.0, 555 | type=float, 556 | help="Total number of training epochs to perform.", 557 | ) 558 | parser.add_argument( 559 | "--max_steps", 560 | default=-1, 561 | type=int, 562 | help="If > 0: set total number of training steps to perform. Override num_train_epochs.", 563 | ) 564 | parser.add_argument( 565 | "--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps." 566 | ) 567 | 568 | parser.add_argument( 569 | "--logging_steps", type=int, default=50, help="Log every X updates steps." 570 | ) 571 | parser.add_argument( 572 | "--save_steps", 573 | type=int, 574 | default=50, 575 | help="Save checkpoint every X updates steps.", 576 | ) 577 | parser.add_argument( 578 | "--eval_all_checkpoints", 579 | action="store_true", 580 | help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number", 581 | ) 582 | parser.add_argument( 583 | "--no_cuda", action="store_true", help="Avoid using CUDA when available" 584 | ) 585 | parser.add_argument( 586 | "--overwrite_output_dir", 587 | action="store_true", 588 | help="Overwrite the content of the output directory", 589 | ) 590 | parser.add_argument( 591 | "--overwrite_cache", 592 | action="store_true", 593 | help="Overwrite the cached training and evaluation sets", 594 | ) 595 | parser.add_argument( 596 | "--seed", type=int, default=42, help="random seed for initialization" 597 | ) 598 | 599 | parser.add_argument( 600 | "--fp16", 601 | action="store_true", 602 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", 603 | ) 604 | parser.add_argument( 605 | "--fp16_opt_level", 606 | type=str, 607 | default="O1", 608 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 609 | "See details at https://nvidia.github.io/apex/amp.html", 610 | ) 611 | parser.add_argument( 612 | "--local_rank", 613 | type=int, 614 | default=-1, 615 | help="For distributed training: local_rank", 616 | ) 617 | parser.add_argument( 618 | "--server_ip", type=str, default="", help="For distant debugging." 619 | ) 620 | parser.add_argument( 621 | "--server_port", type=str, default="", help="For distant debugging." 622 | ) 623 | parser.add_argument( 624 | "--train_json", type=str 625 | ) 626 | parser.add_argument( 627 | "--val_json", type=str 628 | ) 629 | parser.add_argument( 630 | "--skip_match_answers", 631 | action="store_true", 632 | help="Whether to match OCR start and end index to groundtruth answers", 633 | ) 634 | args = parser.parse_args() 635 | 636 | if ( 637 | os.path.exists(args.output_dir) 638 | and os.listdir(args.output_dir) 639 | and args.do_train 640 | and not args.overwrite_output_dir 641 | ): 642 | raise ValueError( 643 | "Output directory ({}) already exists and is not empty. 
Use --overwrite_output_dir to overcome.".format( 644 | args.output_dir 645 | ) 646 | ) 647 | 648 | # Setup distant debugging if needed 649 | if args.server_ip and args.server_port: 650 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 651 | import ptvsd 652 | 653 | print("Waiting for debugger attach") 654 | ptvsd.enable_attach( 655 | address=(args.server_ip, args.server_port), redirect_output=True 656 | ) 657 | ptvsd.wait_for_attach() 658 | 659 | # Setup CUDA, GPU & distributed training 660 | if args.local_rank == -1 or args.no_cuda: 661 | device = torch.device( 662 | "cuda:0" if torch.cuda.is_available() and not args.no_cuda else "cpu" 663 | ) 664 | torch.cuda.set_device(device) 665 | args.n_gpu = torch.cuda.device_count() 666 | else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 667 | torch.cuda.set_device(args.local_rank) 668 | device = torch.device("cuda", args.local_rank) 669 | torch.distributed.init_process_group(backend="nccl") 670 | args.n_gpu = 1 671 | args.device = device 672 | 673 | if args.overwrite_output_dir and os.path.exists(args.output_dir): 674 | shutil.rmtree(args.output_dir) 675 | if not os.path.exists(args.output_dir): 676 | os.makedirs(args.output_dir) 677 | # Setup logging 678 | logging.basicConfig( 679 | filename=os.path.join(args.output_dir, "train.log"), 680 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 681 | datefmt="%m/%d/%Y %H:%M:%S", 682 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN, 683 | ) 684 | logger.warning( 685 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 686 | args.local_rank, 687 | device, 688 | args.n_gpu, 689 | bool(args.local_rank != -1), 690 | args.fp16, 691 | ) 692 | 693 | # Set seed 694 | set_seed(args) 695 | 696 | # Prepare CONLL-2003 task 697 | labels = ["start","end"] 698 | num_labels = len(labels) 699 | # Use cross entropy ignore index as padding label id so that only real label ids contribute to the loss later 700 | pad_token_label_id = CrossEntropyLoss().ignore_index 701 | 702 | # Load pretrained model and tokenizer 703 | if args.local_rank not in [-1, 0]: 704 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab 705 | 706 | args.model_type = args.model_type.lower() 707 | print("ARGS",args) 708 | config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] 709 | print("Config_name",args.config_name) 710 | 711 | config = config_class.from_pretrained( 712 | args.config_name if args.config_name else args.model_name_or_path, 713 | num_labels=num_labels, 714 | cache_dir=args.cache_dir if args.cache_dir else None, 715 | ) 716 | tokenizer = tokenizer_class.from_pretrained( 717 | args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, 718 | do_lower_case=args.do_lower_case, 719 | cache_dir=args.cache_dir if args.cache_dir else None, 720 | ) 721 | model = model_class.from_pretrained( 722 | args.model_name_or_path, 723 | from_tf=bool(".ckpt" in args.model_name_or_path), 724 | config=config, 725 | cache_dir=args.cache_dir if args.cache_dir else None, 726 | ) 727 | 728 | if args.local_rank == 0: 729 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab 730 | 731 | model.to(args.device) 732 | 733 | logger.info("Training/evaluation parameters %s", args) 734 | 735 | # Training 736 | if args.do_train: 737 | 
print("tokenizer",tokenizer) 738 | print("pad_token_label_id",pad_token_label_id) 739 | print("labels",labels) 740 | train_dataset,_,_ = load_and_cache_examples( 741 | args, tokenizer, labels, pad_token_label_id, mode="train" 742 | ) 743 | global_step, tr_loss = train( 744 | args, train_dataset, model, tokenizer, labels, pad_token_label_id 745 | ) 746 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) 747 | # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained() 748 | if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): 749 | # Create output directory if needed 750 | if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: 751 | os.makedirs(args.output_dir) 752 | 753 | logger.info("Saving model checkpoint to %s", args.output_dir) 754 | # Save a trained model, configuration and tokenizer using `save_pretrained()`. 755 | # They can then be reloaded using `from_pretrained()` 756 | model_to_save = ( 757 | model.module if hasattr(model, "module") else model 758 | ) # Take care of distributed/parallel training 759 | model_to_save.save_pretrained(args.output_dir) 760 | tokenizer.save_pretrained(args.output_dir) 761 | 762 | # Good practice: save your training arguments together with the trained model 763 | torch.save(args, os.path.join(args.output_dir, "training_args.bin")) 764 | 765 | # Evaluation 766 | results = {} 767 | if args.do_eval and args.local_rank in [-1, 0]: 768 | tokenizer = tokenizer_class.from_pretrained( 769 | 'data_docvqa_train_test', do_lower_case=args.do_lower_case 770 | ) 771 | checkpoints = [args.output_dir] 772 | if args.eval_all_checkpoints: 773 | checkpoints = list( 774 | os.path.dirname(c) 775 | for c in sorted( 776 | glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True) 777 | ) 778 | ) 779 | logging.getLogger("pytorch_transformers.modeling_utils").setLevel( 780 | logging.WARN 781 | ) # Reduce logging 782 | logger.info("Evaluate the following checkpoints: %s", checkpoints) 783 | for checkpoint in checkpoints: 784 | global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else "" 785 | model = model_class.from_pretrained(checkpoint) 786 | model.to(args.device) 787 | result = evaluate( 788 | args, 789 | model, 790 | tokenizer, 791 | labels, 792 | pad_token_label_id, 793 | mode="dev", 794 | prefix=global_step, 795 | ) 796 | if global_step: 797 | result = {"{}_{}".format(global_step, k): v for k, v in result.items()} 798 | results.update(result) 799 | output_eval_file = os.path.join(args.output_dir, "eval_results.txt") 800 | with open(output_eval_file, "w") as writer: 801 | for key in sorted(results.keys()): 802 | writer.write("{} = {}\n".format(key, str(results[key]))) 803 | 804 | if args.do_predict and args.local_rank in [-1, 0]: 805 | tokenizer = tokenizer_class.from_pretrained( 806 | args.output_dir, do_lower_case=args.do_lower_case 807 | ) 808 | model = model_class.from_pretrained(args.output_dir) 809 | model.to(args.device) 810 | result, predictions = evaluate( 811 | args, model, tokenizer, labels, pad_token_label_id, mode="test" 812 | ) 813 | with open('./tmp.json','w') as fp: 814 | json.dump(predictions,fp) 815 | # Save results 816 | output_test_results_file = os.path.join(args.output_dir, "test_results.txt") 817 | with open(output_test_results_file, "w") as writer: 818 | for key in sorted(result.keys()): 819 | writer.write("{} = {}\n".format(key, str(result[key]))) 820 | # Save predictions 821 | output_test_predictions_file = 
os.path.join( 822 | args.output_dir, "test_predictions.txt" 823 | ) 824 | with open(output_test_predictions_file, "w") as writer: 825 | with open(os.path.join(args.data_dir, "test.txt"), "r") as f: 826 | example_id = 0 827 | for line in f: 828 | if line.startswith("-DOCSTART-") or line == "" or line == "\n": 829 | writer.write(line) 830 | if not predictions[example_id]: 831 | example_id += 1 832 | elif predictions[example_id]: 833 | output_line = ( 834 | line.split()[0] 835 | + " " 836 | + predictions[example_id].pop(0) 837 | + "\n" 838 | ) 839 | writer.write(output_line) 840 | else: 841 | logger.warning( 842 | "Maximum sequence length exceeded: No prediction for '%s'.", 843 | line.split()[0], 844 | ) 845 | 846 | return results 847 | 848 | 849 | if __name__ == "__main__": 850 | main() 851 | --------------------------------------------------------------------------------
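For reference, below is a minimal, hypothetical sketch of how the two utilities from utils_docvqa.py can be driven directly, mirroring what load_and_cache_examples in run_docvqa.py does internally. The dataset path is a placeholder for a JSON file produced by create_dataset.py, and "bert-base-uncased" stands in for the vocabulary shipped with the pretrained LayoutLM model.

```
# Hypothetical usage sketch -- the file path and tokenizer name are placeholders.
from transformers import BertTokenizer
from utils_docvqa import read_docvqa_examples, convert_examples_to_features

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)

# Read question/answer examples from a dataset file made by create_dataset.py.
examples = read_docvqa_examples("train_v1.0.json", is_training=True,
                                skip_match_answers=True)

# Convert them to fixed-length features; the label list, sequence length,
# stride and pad label id mirror the values used in run_docvqa.py.
features = convert_examples_to_features(
    examples=examples,
    label_list=["start", "end"],
    tokenizer=tokenizer,
    max_seq_length=512,
    doc_stride=128,
    max_query_length=64,
    is_training=True,
    pad_token_label_id=-100,
)
print(len(features), "features built from", len(examples), "examples")
```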