├── 00017.PNG
├── README.md
├── approach.pptx
├── convert_to_csv.py
├── data
│   ├── All_X152.yaml
│   └── configs
│       └── Base-RCNN-FPN.yaml
├── imgs
│   ├── 00017.PNG
│   ├── Finding_Tables.png
│   ├── Subex.jpg
│   ├── output1.jpg
│   ├── output2.jpeg
│   └── output_image.png
├── models
│   ├── __pycache__
│   │   ├── model_layout.cpython-38.pyc
│   │   └── table_location_predictor.cpython-38.pyc
│   ├── model_layout.py
│   ├── table_location_predictor.py
│   └── utils.py
├── output_images
│   ├── cropped_table0.png
│   ├── output_image.png
│   ├── processed_image.png
│   └── tables_detected.png
├── pre
│   ├── __pycache__
│   │   └── preprocess.cpython-38.pyc
│   ├── prep.sh
│   ├── preprocess.py
│   └── readme.md
├── requirements.txt
├── run_inference.py
├── training-procedure
│   └── subex_final.ipynb
└── utils
    └── utils.py
/00017.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nehamjain10/Finding_Tables/34a653da754bb8fb4019eefcdf23c2dfacbb1e5e/00017.PNG
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Subex Hackathon

> An end-to-end deep learning approach for table detection and structure recognition from invoice documents

## Results: 1st Place out of 150+ participants


## 1. Introduction
Finding Tables is an automatic table recognition method for interpreting tabular data in document images. We present an improved, end-to-end deep-learning approach that solves both table detection and structure recognition, using a model pretrained on TableBank and fine-tuned on our invoice data. Finding Tables uses a Faster R-CNN high-resolution network to detect table regions. For structure recognition we propose a novel approach that leverages state-of-the-art NLP methods: LayoutLM, a BERT-based model, processes the text in the image and maps it into question-answer pairs, which we then transform into JSON files.
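
The inference code in this repository draws the predicted labels onto the table image; the sketch below illustrates one way the word-level LayoutLM predictions could then be grouped into the question-answer pairs and JSON output described above. The `words`/`labels` inputs and the pairing heuristic are illustrative assumptions, not the exact post-processing used in the submission.

```python
import json

def predictions_to_json(words, labels, out_path="result.json"):
    """Pair each QUESTION span with the ANSWER span that follows it.

    words  : OCR tokens in reading order (illustrative assumption)
    labels : LayoutLM predictions such as "B-QUESTION", "I-ANSWER", "O"
    """
    pairs, question, answer = [], [], []
    for word, label in zip(words, labels):
        tag = label.split("-")[-1].upper()     # QUESTION / ANSWER / HEADER / O
        if tag == "QUESTION":
            if answer:                         # previous pair is complete, flush it
                pairs.append((" ".join(question), " ".join(answer)))
                question, answer = [], []
            question.append(word)
        elif tag == "ANSWER" and question:
            answer.append(word)
    if question and answer:
        pairs.append((" ".join(question), " ".join(answer)))
    with open(out_path, "w") as fp:
        json.dump(dict(pairs), fp, indent=2)
    return dict(pairs)
```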

## 2. Setting it all up

Setting up LayoutLM:

    git clone -b remove_torch_save https://github.com/NielsRogge/unilm.git
    pip install ./unilm/layoutlm
    git clone https://github.com/huggingface/transformers.git
    pip install ./transformers

The code was developed with the following library dependencies (pinned in `requirements.txt`); install Detectron2 and the remaining packages with:

    pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu101/torch1.7/index.html
    pip install -r requirements.txt

Please download the weights for Detectron2 and LayoutLM and keep them in the `data` folder:

Detectron2 weights: Detectron_finetuned_model_weights

LayoutLM weights: LayoutLM weights

## 3. Running inference

To test custom images with our model, go inside the repository folder and run:

    python run_inference.py 00017.PNG

where `00017.PNG` is the path of the image file you want to process.
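
If you would rather drive the pipeline from Python than from the command line, the sketch below mirrors what `run_inference.py` does internally. It assumes the fine-tuned weights have been downloaded to the default paths under `data/`; only `max_seq_length` is read from the args object by the feature converter.

```python
# Minimal sketch of the Finding Tables pipeline, mirroring run_inference.py.
# Assumes data/model_detectronV2.pth and data/model_layoutLM.pt exist (see section 2).
from types import SimpleNamespace

from pre.preprocess import preprocess_image
from models.table_location_predictor import Table_Detection
from models.model_layout import LayoutLM

# 1. Clean the scanned invoice (stray-line removal + sharpening).
preprocess_image("00017.PNG")            # writes output_images/processed_image.png

# 2. Detect and crop tables with the fine-tuned Detectron2 model.
tables = Table_Detection("output_images/processed_image.png",
                         "data/model_detectronV2.pth")

# 3. Run LayoutLM token classification on every cropped table.
layoutlm_args = SimpleNamespace(max_seq_length=512)
for table_path in tables:
    model = LayoutLM(table_path, "data/model_layoutLM.pt",
                     "microsoft/layoutlm-base-uncased")
    model.setup_data(layoutlm_args)
    annotated = model.inference()        # PIL image with the predicted labels drawn on it
```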

## 4. Examples

Original Image:

Detected Tables:

Example from Structure Recognition:

If you are having trouble getting it to work, please feel free to contact us or raise an issue.

Neham (nehamjain2002@gmail.com) & Tanay (dixittanay@gmail.com)
--------------------------------------------------------------------------------
/approach.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nehamjain10/Finding_Tables/34a653da754bb8fb4019eefcdf23c2dfacbb1e5e/approach.pptx
--------------------------------------------------------------------------------
/convert_to_csv.py:
--------------------------------------------------------------------------------
# Helper function to convert an extracted table image into a csv file
# Doesn't work that great (just a backup method)


import csv
import random
import os
from os import listdir
from xml.etree import ElementTree
import cv2
import glob
from PIL import Image
from random import randrange
import numpy as np
import pandas as pd
import pytesseract
from pytesseract import Output


def convert(image_path):
    # read the table image and binarise it before running OCR
    img = cv2.imread(image_path)
    gray_image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    threshold_img = cv2.threshold(gray_image, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
    kernel = np.ones((1, 1), np.uint8)
    threshold_img = cv2.dilate(threshold_img, kernel, iterations=1)
    threshold_img = cv2.erode(threshold_img, kernel, iterations=1)

    # configuring parameters for tesseract
    custom_config = r'--oem 3 --psm 6'

    # now feeding the image to tesseract
    details = pytesseract.image_to_data(threshold_img, output_type=Output.DICT, config=custom_config)
    print(details.keys())

    total_boxes = len(details['text'])
    for sequence_number in range(total_boxes):
        if int(details['conf'][sequence_number]) > 5:
            (x, y, w, h) = (details['left'][sequence_number], details['top'][sequence_number],
                            details['width'][sequence_number], details['height'][sequence_number])
            threshold_img = cv2.rectangle(threshold_img, (x, y), (x + w, y + h), (0, 255, 0), 2)

    # feed the boxed image to tesseract again and group the recognised words into rows
    details = pytesseract.image_to_data(threshold_img, output_type=Output.DICT, config=custom_config)
    parse_text = []
    word_list = []
    last_word = ''
    for word in details['text']:
        if word != '':
            word_list.append(word)
            last_word = word
        if (last_word != '' and word == '') or (word == details['text'][-1]):
            parse_text.append(word_list)
            word_list = []

    # saving the extracted text output to a txt file
    with open('result.txt', 'w', newline="") as file:
        csv.writer(file, delimiter=" ").writerows(parse_text)

    # reading the txt file into a dataframe to convert it to a csv file
    df = pd.read_csv("result.txt", delimiter=" ")
    df.to_csv('result.csv')
--------------------------------------------------------------------------------
/data/All_X152.yaml:
--------------------------------------------------------------------------------
_BASE_: "configs/Base-RCNN-FPN.yaml"
MODEL:
  WEIGHTS: "X-152-32x8d-IN5k.pkl"
  MASK_ON: False
  PIXEL_STD: [57.375, 57.120, 58.395]
  RESNETS:
    STRIDE_IN_1X1: False
    NUM_GROUPS: 32
    WIDTH_PER_GROUP: 8
    DEPTH: 152
  ROI_HEADS:
    NUM_CLASSES: 1
SOLVER:
  STEPS: (88000, 99000)
  MAX_ITER: 110000
  IMS_PER_BATCH: 24
  BASE_LR: 0.03
DATASETS:
  TRAIN: ("tablebank_word_train", "tablebank_latex_train")
  TEST: ("tablebank_word_val", "tablebank_latex_val")
DATALOADER:
  NUM_WORKERS: 2
OUTPUT_DIR: "output/X152/All_X152"
--------------------------------------------------------------------------------
/data/configs/Base-RCNN-FPN.yaml:
--------------------------------------------------------------------------------
MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  BACKBONE:
    NAME: "build_resnet_fpn_backbone"
  RESNETS:
    OUT_FEATURES: ["res2", "res3", "res4", "res5"]
  FPN:
    IN_FEATURES: ["res2", "res3", "res4", "res5"]
  ANCHOR_GENERATOR:
    SIZES: [[32], [64], [128], [256], [512]]  # One size for each in feature map
    ASPECT_RATIOS: [[0.5, 1.0, 2.0]]  # Three aspect ratios (same for all in feature maps)
  RPN:
    IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"]
    PRE_NMS_TOPK_TRAIN: 2000  # Per FPN level
    PRE_NMS_TOPK_TEST: 1000  # Per FPN level
    # Detectron1 uses 2000 proposals per-batch,
    # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue)
    # which is approximately 1000 proposals per-image since the default batch size for FPN is 2.
    POST_NMS_TOPK_TRAIN: 1000
    POST_NMS_TOPK_TEST: 1000
  ROI_HEADS:
    NAME: "StandardROIHeads"
    IN_FEATURES: ["p2", "p3", "p4", "p5"]
  ROI_BOX_HEAD:
    NAME: "FastRCNNConvFCHead"
    NUM_FC: 2
    POOLER_RESOLUTION: 7
  ROI_MASK_HEAD:
    NAME: "MaskRCNNConvUpsampleHead"
    NUM_CONV: 4
    POOLER_RESOLUTION: 14
SOLVER:
  IMS_PER_BATCH: 16
  BASE_LR: 0.02
  STEPS: (60000, 80000)
  MAX_ITER: 90000
INPUT:
  MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
VERSION: 2
--------------------------------------------------------------------------------
/imgs/00017.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nehamjain10/Finding_Tables/34a653da754bb8fb4019eefcdf23c2dfacbb1e5e/imgs/00017.PNG
--------------------------------------------------------------------------------
/imgs/Finding_Tables.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nehamjain10/Finding_Tables/34a653da754bb8fb4019eefcdf23c2dfacbb1e5e/imgs/Finding_Tables.png
--------------------------------------------------------------------------------
/imgs/Subex.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nehamjain10/Finding_Tables/34a653da754bb8fb4019eefcdf23c2dfacbb1e5e/imgs/Subex.jpg
--------------------------------------------------------------------------------
/imgs/output1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nehamjain10/Finding_Tables/34a653da754bb8fb4019eefcdf23c2dfacbb1e5e/imgs/output1.jpg
--------------------------------------------------------------------------------
/imgs/output2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nehamjain10/Finding_Tables/34a653da754bb8fb4019eefcdf23c2dfacbb1e5e/imgs/output2.jpeg
--------------------------------------------------------------------------------
/imgs/output_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nehamjain10/Finding_Tables/34a653da754bb8fb4019eefcdf23c2dfacbb1e5e/imgs/output_image.png
--------------------------------------------------------------------------------
/models/__pycache__/model_layout.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nehamjain10/Finding_Tables/34a653da754bb8fb4019eefcdf23c2dfacbb1e5e/models/__pycache__/model_layout.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/table_location_predictor.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nehamjain10/Finding_Tables/34a653da754bb8fb4019eefcdf23c2dfacbb1e5e/models/__pycache__/table_location_predictor.cpython-38.pyc
--------------------------------------------------------------------------------
/models/model_layout.py:
--------------------------------------------------------------------------------
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import pytesseract
from transformers import LayoutLMForTokenClassification, LayoutLMTokenizer, LayoutLMConfig
import torch


def openImage(path):
    image = Image.open(path)
    return image


def getOCRdata(image):
    width, height = image.size
    w_scale = 1000 / width
    h_scale = 1000 / height

    ocr_df = pytesseract.image_to_data(image, output_type='data.frame')

    ocr_df = ocr_df.dropna().assign(left_scaled=ocr_df.left * w_scale,
                                    width_scaled=ocr_df.width * w_scale,
                                    top_scaled=ocr_df.top * h_scale,
                                    height_scaled=ocr_df.height * h_scale,
                                    right_scaled=lambda x: x.left_scaled + x.width_scaled,
                                    bottom_scaled=lambda x: x.top_scaled + x.height_scaled)

    float_cols = ocr_df.select_dtypes('float').columns
    ocr_df[float_cols] = ocr_df[float_cols].round(0).astype(int)
    ocr_df = ocr_df.replace(r'^\s*$', np.nan, regex=True)
    ocr_df = ocr_df.dropna().reset_index(drop=True)

    return ocr_df, width, height


def normalize_box(box, width, height):
    return [
        int(1000 * (box[0] / width)),
        int(1000 * (box[1] / height)),
        int(1000 * (box[2] / width)),
        int(1000 * (box[3] / height)),
    ]


def getBoxes(df, width, height):
    words = list(df.text)
    coordinates = df[['left', 'top', 'width', 'height', 'text']]
    actual_boxes = []
    named_boxes = []
    for idx, row in coordinates.iterrows():
        x, y, w, h, text = tuple(row)  # the row comes in (left, top, width, height) format
        actual_box = [x, y, x + w, y + h]  # we turn it into (left, top, left+width, top+height) to get the actual box
        text_box = actual_box + [text]
        actual_boxes.append(actual_box)
        named_boxes.append(text_box)

    boxes = []
    for box in actual_boxes:
        boxes.append(normalize_box(box, width, height))
    return boxes, words, actual_boxes


def iob_to_label(label):
    # helper used by LayoutLM.inference(): map the "O" tag to "other" and keep the
    # IOBES-prefixed labels as-is so that they match the keys of label2color below
    return label if label != 'O' else 'other'


def convert_example_to_features(image, words, boxes, actual_boxes, tokenizer, args, cls_token_box=[0, 0, 0, 0],
                                sep_token_box=[1000, 1000, 1000, 1000],
                                pad_token_box=[0, 0, 0, 0]):
    width, height = image.size

    tokens = []
    token_boxes = []
    actual_bboxes = []  # we use an extra b because actual_boxes is already used
    token_actual_boxes = []
    for word, box, actual_bbox in zip(words, boxes, actual_boxes):
        word_tokens = tokenizer.tokenize(word)
        tokens.extend(word_tokens)
        token_boxes.extend([box] * len(word_tokens))
        actual_bboxes.extend([actual_bbox] * len(word_tokens))
        token_actual_boxes.extend([actual_bbox] * len(word_tokens))

    # Truncation: account for [CLS] and [SEP] with "- 2".
    special_tokens_count = 2
    if len(tokens) > args.max_seq_length - special_tokens_count:
        tokens = tokens[: (args.max_seq_length - special_tokens_count)]
        token_boxes = token_boxes[: (args.max_seq_length - special_tokens_count)]
        actual_bboxes = actual_bboxes[: (args.max_seq_length - special_tokens_count)]
        token_actual_boxes = token_actual_boxes[: (args.max_seq_length - special_tokens_count)]

    # add [SEP] token, with corresponding token boxes and actual boxes
    tokens += [tokenizer.sep_token]
    token_boxes += [sep_token_box]
    actual_bboxes += [[0, 0, width, height]]
    token_actual_boxes += [[0, 0, width, height]]

    segment_ids = [0] * len(tokens)

    # next: [CLS] token
    tokens = [tokenizer.cls_token] + tokens
    token_boxes = [cls_token_box] + token_boxes
    actual_bboxes = [[0, 0, width, height]] + actual_bboxes
    token_actual_boxes = [[0, 0, width, height]] + token_actual_boxes
    segment_ids = [1] + segment_ids

    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    # The mask has 1 for real tokens and 0 for padding tokens. Only real
    # tokens are attended to.
    input_mask = [1] * len(input_ids)

    # Zero-pad up to the sequence length.
    padding_length = args.max_seq_length - len(input_ids)
    input_ids += [tokenizer.pad_token_id] * padding_length
    input_mask += [0] * padding_length
    segment_ids += [tokenizer.pad_token_id] * padding_length
    token_boxes += [pad_token_box] * padding_length
    token_actual_boxes += [pad_token_box] * padding_length

    assert len(input_ids) == args.max_seq_length
    assert len(input_mask) == args.max_seq_length
    assert len(segment_ids) == args.max_seq_length
    assert len(token_boxes) == args.max_seq_length
    assert len(token_actual_boxes) == args.max_seq_length

    return input_ids, input_mask, segment_ids, token_boxes, token_actual_boxes


class LayoutLM(torch.nn.Module):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __init__(self, image_path, model_path, config_path, num_labels=13, args=None):
        super(LayoutLM, self).__init__()
        self.image = openImage(image_path)
        self.args = args
        self.tokenizer = LayoutLMTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")

        config = LayoutLMConfig.from_pretrained(config_path)
        self.model = LayoutLMForTokenClassification.from_pretrained(model_path, config=config)
        self.model.to(self.device)

        self.input_ids = None
        self.attention_mask = None
        self.token_type_ids = None
        self.bbox = None
        self.token_actual_boxes = None

    def setup_data(self, args):
        # keep the args that were passed in so convert_example_to_features sees them
        if args is not None:
            self.args = args

        ocr_df, width, height = getOCRdata(self.image)
        boxes, words, actual_boxes = getBoxes(ocr_df, width, height)
        input_ids, input_mask, segment_ids, token_boxes, self.token_actual_boxes = \
            convert_example_to_features(image=self.image, words=words, boxes=boxes, actual_boxes=actual_boxes,
                                        tokenizer=self.tokenizer, args=self.args)

        self.input_ids = torch.tensor(input_ids, device=self.device).unsqueeze(0)
        self.attention_mask = torch.tensor(input_mask, device=self.device).unsqueeze(0)
        self.token_type_ids = torch.tensor(segment_ids, device=self.device).unsqueeze(0)
        self.bbox = torch.tensor(token_boxes, device=self.device).unsqueeze(0)
        self.outputs = self.model(input_ids=self.input_ids, bbox=self.bbox, attention_mask=self.attention_mask,
                                  token_type_ids=self.token_type_ids)
        assert self.outputs is not None, "Setup failed"
        print('Setup done')

    def inference(self):
        token_predictions = self.outputs.logits.argmax(-1).squeeze().tolist()
        word_level_predictions = []  # let's turn them into word level predictions
        final_boxes = []
        for id, token_pred, box in zip(self.input_ids.squeeze().tolist(), token_predictions, self.token_actual_boxes):
            if (self.tokenizer.decode([id]).startswith("##")) or (id in [self.tokenizer.cls_token_id,
                                                                         self.tokenizer.sep_token_id,
                                                                         self.tokenizer.pad_token_id]):
                # skip prediction + bounding box
                continue
            else:
                word_level_predictions.append(token_pred)
                final_boxes.append(box)

        label2color = {'i-question': 'blue', 'i-answer': 'green', 'i-header': 'orange', 'b-question': 'blue',
                       'b-answer': 'green', 'b-header': 'orange', 'e-question': 'blue', 'e-answer': 'green',
                       'e-header': 'orange', 's-question': 'blue', 's-answer': 'green', 's-header': 'orange',
                       'other': 'violet'}

        label_map = {0: 'B-ANSWER',
                     1: 'B-HEADER',
                     2: 'B-QUESTION',
                     3: 'E-ANSWER',
                     4: 'E-HEADER',
                     5: 'E-QUESTION',
                     6: 'I-ANSWER',
                     7: 'I-HEADER',
                     8: 'I-QUESTION',
                     9: 'O',
                     10: 'S-ANSWER',
                     11: 'S-HEADER',
                     12: 'S-QUESTION'}

        font = ImageFont.load_default()
        draw = ImageDraw.Draw(self.image)
        for prediction, box in zip(word_level_predictions, final_boxes):
            predicted_label = iob_to_label(label_map[prediction]).lower()
            if predicted_label != 'other':
                draw.rectangle(box, outline=label2color[predicted_label])
                draw.text((box[0] + 20, box[1] - 20), text=predicted_label, fill=label2color[predicted_label], font=font)

        return self.image
--------------------------------------------------------------------------------
/models/table_location_predictor.py:
--------------------------------------------------------------------------------
import PIL
import torch, torchvision
from os import listdir
from tqdm.notebook import tqdm
from detectron2.engine import DefaultTrainer
# Setup detectron2 logger
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()
from PIL import Image

# import some common libraries
import numpy as np
import os, json, cv2, random
from detectron2.structures import BoxMode

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog


def save_detected_tables(image, bnd_boxes_tables):
    tables_detected = []
    os.makedirs("output_images", exist_ok=True)
    count = 0
    cv2.imwrite("output_images/tables_detected.png", image)
    for i in bnd_boxes_tables:
        i = i.numpy()
        xmin = int(i[0])
        ymin = int(i[1])
        xmax = int(i[2])
        ymax = int(i[3])
        new_img = image[ymin:ymax, xmin:xmax]
        path = f"output_images/cropped_table{count}.png"
        cv2.imwrite(path, new_img)
        count += 1
        tables_detected.append(path)
    return tables_detected

def get_predictor(model_weights, threshold=0.75):
    cfg = get_cfg()
    # add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library
    cfg.merge_from_file("data/All_X152.yaml")
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold  # set threshold for this model
    # Find a model from detectron2's model zoo. You can use the https://dl.fbaipublicfiles... url as well
    cfg.MODEL.WEIGHTS = model_weights
    predictor = DefaultPredictor(cfg)
    return predictor, cfg


def Table_Detection(image_path, model_weights):
    bnd_boxes_tables = []
    threshold = 0.75
    im = cv2.imread(image_path)
    print("DETECTING TABLES................")
    while len(bnd_boxes_tables) == 0 and threshold >= 0.5:
        predictor, cfg = get_predictor(model_weights, threshold)
        outputs = predictor(im)
        bnd_boxes_tables = list(outputs["instances"].pred_boxes.to("cpu"))
        threshold = threshold - 0.25
    v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]))
    out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
    im1 = out.get_image()[:, :, ::-1]
    table_imgs = save_detected_tables(im1, bnd_boxes_tables)
    print("File path of saved tables are: ", table_imgs)
    return table_imgs
--------------------------------------------------------------------------------
/models/utils.py:
--------------------------------------------------------------------------------
from PIL import Image
import numpy as np
import pytesseract


def openImage(path):
    image = Image.open(path)
    return image


def getOCRdata(image):
    width, height = image.size
    w_scale = 1000 / width
    h_scale = 1000 / height

    ocr_df = pytesseract.image_to_data(image, output_type='data.frame')

    ocr_df = ocr_df.dropna().assign(left_scaled=ocr_df.left * w_scale,
                                    width_scaled=ocr_df.width * w_scale,
                                    top_scaled=ocr_df.top * h_scale,
                                    height_scaled=ocr_df.height * h_scale,
                                    right_scaled=lambda x: x.left_scaled + x.width_scaled,
                                    bottom_scaled=lambda x: x.top_scaled + x.height_scaled)

    float_cols = ocr_df.select_dtypes('float').columns
    ocr_df[float_cols] = ocr_df[float_cols].round(0).astype(int)
    ocr_df = ocr_df.replace(r'^\s*$', np.nan, regex=True)
    ocr_df = ocr_df.dropna().reset_index(drop=True)

    return ocr_df, width, height


def normalize_box(box, width, height):
    return [
        int(1000 * (box[0] / width)),
        int(1000 * (box[1] / height)),
        int(1000 * (box[2] / width)),
        int(1000 * (box[3] / height)),
    ]


def getBoxes(df, width, height):
    words = list(df.text)
    coordinates = df[['left', 'top', 'width', 'height', 'text']]
    actual_boxes = []
    named_boxes = []
    for idx, row in coordinates.iterrows():
        x, y, w, h, text = tuple(row)  # the row comes in (left, top, width, height) format
        actual_box = [x, y, x + w, y + h]  # we turn it into (left, top, left+width, top+height) to get the actual box
        text_box = actual_box + [text]
        actual_boxes.append(actual_box)
        named_boxes.append(text_box)

    boxes = []
    for box in actual_boxes:
        boxes.append(normalize_box(box, width, height))
    return boxes, words, actual_boxes


def convert_example_to_features(image, words, boxes, actual_boxes, tokenizer, args, cls_token_box=[0, 0, 0, 0],
                                sep_token_box=[1000, 1000, 1000, 1000],
                                pad_token_box=[0, 0, 0, 0]):
    width, height = image.size

    tokens = []
    token_boxes = []
    actual_bboxes = []  # we use an extra b because actual_boxes is already used
    token_actual_boxes = []
    for word, box, actual_bbox in zip(words, boxes, actual_boxes):
        word_tokens = tokenizer.tokenize(word)
        tokens.extend(word_tokens)
        token_boxes.extend([box] * len(word_tokens))
        actual_bboxes.extend([actual_bbox] * len(word_tokens))
        token_actual_boxes.extend([actual_bbox] * len(word_tokens))

    # Truncation: account for [CLS] and [SEP] with "- 2".
    special_tokens_count = 2
    if len(tokens) > args.max_seq_length - special_tokens_count:
        tokens = tokens[: (args.max_seq_length - special_tokens_count)]
        token_boxes = token_boxes[: (args.max_seq_length - special_tokens_count)]
        actual_bboxes = actual_bboxes[: (args.max_seq_length - special_tokens_count)]
        token_actual_boxes = token_actual_boxes[: (args.max_seq_length - special_tokens_count)]

    # add [SEP] token, with corresponding token boxes and actual boxes
    tokens += [tokenizer.sep_token]
    token_boxes += [sep_token_box]
    actual_bboxes += [[0, 0, width, height]]
    token_actual_boxes += [[0, 0, width, height]]

    segment_ids = [0] * len(tokens)

    # next: [CLS] token
    tokens = [tokenizer.cls_token] + tokens
    token_boxes = [cls_token_box] + token_boxes
    actual_bboxes = [[0, 0, width, height]] + actual_bboxes
    token_actual_boxes = [[0, 0, width, height]] + token_actual_boxes
    segment_ids = [1] + segment_ids

    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    # The mask has 1 for real tokens and 0 for padding tokens. Only real
    # tokens are attended to.
    input_mask = [1] * len(input_ids)

    # Zero-pad up to the sequence length.
    padding_length = args.max_seq_length - len(input_ids)
    input_ids += [tokenizer.pad_token_id] * padding_length
    input_mask += [0] * padding_length
    segment_ids += [tokenizer.pad_token_id] * padding_length
    token_boxes += [pad_token_box] * padding_length
    token_actual_boxes += [pad_token_box] * padding_length

    assert len(input_ids) == args.max_seq_length
    assert len(input_mask) == args.max_seq_length
    assert len(segment_ids) == args.max_seq_length
    assert len(token_boxes) == args.max_seq_length
    assert len(token_actual_boxes) == args.max_seq_length

    return input_ids, input_mask, segment_ids, token_boxes, token_actual_boxes
--------------------------------------------------------------------------------
/output_images/cropped_table0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nehamjain10/Finding_Tables/34a653da754bb8fb4019eefcdf23c2dfacbb1e5e/output_images/cropped_table0.png
--------------------------------------------------------------------------------
/output_images/output_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nehamjain10/Finding_Tables/34a653da754bb8fb4019eefcdf23c2dfacbb1e5e/output_images/output_image.png
--------------------------------------------------------------------------------
/output_images/processed_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nehamjain10/Finding_Tables/34a653da754bb8fb4019eefcdf23c2dfacbb1e5e/output_images/processed_image.png
--------------------------------------------------------------------------------
/output_images/tables_detected.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nehamjain10/Finding_Tables/34a653da754bb8fb4019eefcdf23c2dfacbb1e5e/output_images/tables_detected.png
--------------------------------------------------------------------------------
/pre/__pycache__/preprocess.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nehamjain10/Finding_Tables/34a653da754bb8fb4019eefcdf23c2dfacbb1e5e/pre/__pycache__/preprocess.cpython-38.pyc
--------------------------------------------------------------------------------
/pre/prep.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# run the preprocessing script on every image in data/images
for file in $(ls data/images)
do
    echo "file is $file"
    python preprocess.py $file data/images_new/
done
--------------------------------------------------------------------------------
/pre/preprocess.py:
--------------------------------------------------------------------------------
from PIL import Image
import pathlib
import cv2
import pytesseract
import urllib
import numpy as np
import re
import sys


def rotateImage(image):
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    gray = cv2.bitwise_not(gray)

    rot_data = pytesseract.image_to_osd(image)
    print("[OSD] " + rot_data)
    rot = re.search('(?<=Rotate: )\d+', rot_data)

    angle = float(rot.group(0))
    if angle > 0:
        angle = 360 - angle
    print("[ANGLE] " + str(angle))

    # rotate the image to deskew it
    (h, w) = image.shape[:2]
    center = (w // 2, h // 2)
    M = cv2.getRotationMatrix2D(center, angle, 1.0)
    rotated = cv2.warpAffine(image, M, (w, h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE)

    return rotated


def RemoveStrayLines(image):
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    mask = np.zeros(image.shape, dtype=np.uint8)

    cnts = cv2.findContours(gray, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cnts = cnts[0] if len(cnts) == 2 else cnts[1]

    cv2.fillPoly(mask, cnts, [255, 255, 255])
    mask = 255 - mask
    result = cv2.bitwise_or(image, mask)

    return result


def ostsuthresholding(img):
    th3 = cv2.adaptiveThreshold(img, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                cv2.THRESH_BINARY, 11, 2)
    return th3


def sharpenImage(img):
    kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
    im = cv2.filter2D(img, -1, kernel)
    return im


def preprocess_image(path):
    image = cv2.imread(path)
    img = RemoveStrayLines(image)
    img_final = sharpenImage(img)
    cv2.imwrite('output_images/processed_image.png', img_final)


if __name__ == "__main__":
    # allow calling this script directly on an image path, e.g. from pre/prep.sh
    preprocess_image(sys.argv[1])
--------------------------------------------------------------------------------
/pre/readme.md:
--------------------------------------------------------------------------------
The rest are all functions which take in an image path and return a cleaned image.
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
absl-py==0.11.0
adal==1.2.5
argon2-cffi @ file:///home/conda/feedstock_root/build_artifacts/argon2-cffi_1605217004767/work
async-generator==1.10
attrs @ file:///home/conda/feedstock_root/build_artifacts/attrs_1605083924122/work
azure-cognitiveservices-search-imagesearch==2.0.0
azure-common==1.1.26
backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work
backports.functools-lru-cache==1.6.1
bleach @ file:///home/conda/feedstock_root/build_artifacts/bleach_1600454382015/work
blis==0.4.1
brotlipy==0.7.0
cachetools==4.2.1
catalogue @ file:///home/conda/feedstock_root/build_artifacts/catalogue_1605613584677/work
certifi==2020.11.8
cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1602537222527/work
chardet @ file:///home/conda/feedstock_root/build_artifacts/chardet_1602255302199/work
click==7.1.2
cloudpickle==1.6.0
cryptography @ file:///home/conda/feedstock_root/build_artifacts/cryptography_1604179079864/work
cycler==0.10.0
cymem @ file:///home/conda/feedstock_root/build_artifacts/cymem_1604919356432/work
Cython==0.29.21
dataclasses==0.6
decorator==4.4.2
defusedxml==0.6.0
detectron2==0.3+cu110
entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1605121927639/work/dist/entrypoints-0.3-py2.py3-none-any.whl
fastai @ file:///home/jhoward/anaconda3/conda-bld/fastai_1604932388885/work
fastai2==0.0.30
fastbook==0.0.14
fastcore @ file:///home/jhoward/anaconda3/conda-bld/fastcore_1605551797847/work
fastprocess==2.0.0
fastprogress @ file:///home/conda/feedstock_root/build_artifacts/fastprogress_1597932925331/work
filelock==3.0.12
future==0.18.2
fvcore==0.1.3.post20210218
gdown==3.12.2
google-auth==1.27.0
google-auth-oauthlib==0.4.2
graphviz==0.15
grpcio==1.35.0
h5py==3.1.0
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1593328102638/work
importlib-metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1602263269022/work
iopath==0.1.3
ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1605455374814/work/dist/ipykernel-5.3.4-py3-none-any.whl
ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1604159561527/work
ipython-genutils==0.2.0
ipywidgets @ file:///home/conda/feedstock_root/build_artifacts/ipywidgets_1599554010055/work
isodate==0.6.0
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1605054524035/work
Jinja2==2.11.2
joblib @ file:///home/conda/feedstock_root/build_artifacts/joblib_1601671685479/work
jsonschema @ file:///home/conda/feedstock_root/build_artifacts/jsonschema_1602551949684/work
jupyter-client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1598486169312/work
jupyter-console @ file:///home/conda/feedstock_root/build_artifacts/jupyter_console_1598728807792/work
jupyter-core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1605735009305/work
jupyterlab-pygments @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_pygments_1601375948261/work
Keras==2.4.3
kiwisolver @ file:///home/conda/feedstock_root/build_artifacts/kiwisolver_1604322295622/work
layoutlm @ file:///notebooks/unilm/layoutlm
lxml==4.5.1
Markdown==3.3.3
MarkupSafe @ file:///home/conda/feedstock_root/build_artifacts/markupsafe_1602267312178/work
matplotlib @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-suite_1605180228501/work
mistune @ file:///home/conda/feedstock_root/build_artifacts/mistune_1605115651871/work
msrest==0.6.19
msrestazure==0.6.4
murmurhash==1.0.4
nbclient @ file:///home/conda/feedstock_root/build_artifacts/nbclient_1602859080374/work
nbconvert==5.6.1
nbdev==1.1.5
nbformat @ file:///home/conda/feedstock_root/build_artifacts/nbformat_1602732862338/work
nest-asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1605195931949/work
notebook @ file:///home/conda/feedstock_root/build_artifacts/notebook_1605103633466/work
numpy @ file:///home/conda/feedstock_root/build_artifacts/numpy_1604945996350/work
nvidia-ml-py3==7.352.0
oauthlib==3.1.0
olefile @ file:///home/conda/feedstock_root/build_artifacts/olefile_1602866521163/work
opencv-python==4.5.1.48
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1589925210001/work
pandas==1.1.4
pandocfilters==1.4.2
parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1595548966091/work
pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1602535608087/work
pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
Pillow==7.1.2
plac==0.9.6
portalocker==2.2.1
preshed @ file:///home/conda/feedstock_root/build_artifacts/preshed_1605166129992/work
prometheus-client @ file:///home/conda/feedstock_root/build_artifacts/prometheus_client_1605543085815/work
prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1605053337398/work
protobuf==3.14.0
ptyprocess==0.6.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycocotools==2.0.2
pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1593275161868/work
pydot==1.4.2
Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1603558917696/work
PyJWT==1.7.1
pyOpenSSL==19.1.0
pyparsing==2.4.7
PyQt5==5.12.3
PyQt5-sip==4.19.18
PyQtChart==5.12
PyQtWebEngine==5.12.1
pyrsistent @ file:///home/conda/feedstock_root/build_artifacts/pyrsistent_1605115595652/work
PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1602326928339/work
pytesseract==0.3.7
python-dateutil==2.8.1
pytz @ file:///home/conda/feedstock_root/build_artifacts/pytz_1604321279890/work
PyYAML==5.3.1
pyzmq==20.0.0
qtconsole @ file:///home/conda/feedstock_root/build_artifacts/qtconsole_1599147533948/work
QtPy==1.9.0
regex==2020.11.13
requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1605186911681/work
requests-oauthlib==1.3.0
rsa==4.7.1
sacremoses==0.0.43
scikit-learn @ file:///home/conda/feedstock_root/build_artifacts/scikit-learn_1604232448678/work
scipy @ file:///home/conda/feedstock_root/build_artifacts/scipy_1604304779838/work
Send2Trash==1.5.0
sentencepiece==0.1.94
seqeval==0.0.12
six @ file:///home/conda/feedstock_root/build_artifacts/six_1590081179328/work
spacy @ file:///home/conda/feedstock_root/build_artifacts/spacy_1605692171608/work
srsly @ file:///home/conda/feedstock_root/build_artifacts/srsly_1605085673973/work
tabulate==0.8.8
tensorboard==2.4.1
tensorboard-plugin-wit==1.8.0
tensorboardX==2.0
termcolor==1.1.0
terminado @ file:///home/conda/feedstock_root/build_artifacts/terminado_1602679586280/work
testpath==0.4.4
thinc @ file:///home/conda/feedstock_root/build_artifacts/thinc_1605620876750/work
threadpoolctl @ file:///tmp/tmp79xdzxkt/threadpoolctl-2.1.0-py3-none-any.whl
tokenizers==0.10.1
torch==1.7.1
torchvision==0.8.2
tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1604105045397/work
tqdm @ file:///home/conda/feedstock_root/build_artifacts/tqdm_1605543106900/work
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1602771532708/work
transformers @ file:///notebooks/transformers
typing-extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1602702424206/work
urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1603125704209/work
utils==1.0.1
wasabi @ file:///home/conda/feedstock_root/build_artifacts/wasabi_1600272362626/work
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1600965781394/work
webencodings==0.5.1
Werkzeug==1.0.1
widgetsnbextension @ file:///home/conda/feedstock_root/build_artifacts/widgetsnbextension_1605475534911/work
xmltodict==0.12.0
yacs==0.1.8
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1603668650351/work
--------------------------------------------------------------------------------
/run_inference.py:
--------------------------------------------------------------------------------
from models.model_layout import LayoutLM
from models.table_location_predictor import Table_Detection
import sys
import argparse
import cv2
import PIL
import warnings
warnings.filterwarnings("ignore")
from pre.preprocess import preprocess_image

# class to turn the keys of a dict into attributes (thanks Stackoverflow)
class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


if __name__ == "__main__":

    args = {'local_rank': -1,
            'overwrite_cache': True,
            'data_dir': '/content/data',
            'model_name_or_path': 'data/model_layoutLM.pt',
            'max_seq_length': 512,
            'model_type': 'layoutlm'}
    args = AttrDict(args)

    parser = argparse.ArgumentParser()
    parser.add_argument('image_path', help='Image file path (.PNG)')
    parser.add_argument('--layoutLM_model_path', default='data/model_layoutLM.pt', type=str, help='.pt file with the LayoutLM model weights')
    parser.add_argument('--table_detection_model_path', default='data/model_detectronV2.pth', type=str, help='.pth file with the Detectron2 model weights')
    parser.add_argument('--config', default="microsoft/layoutlm-base-uncased", type=str, help='model config path or Hugging Face model name')
    arguments = parser.parse_args()

    preprocess_image(arguments.image_path)

    detected_tables = Table_Detection("output_images/processed_image.png", arguments.table_detection_model_path)

    for path in detected_tables:
        layout_model = LayoutLM(path, arguments.layoutLM_model_path, arguments.config)
        layout_model.setup_data(args)
        Image = layout_model.inference()
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
from PIL import Image
import numpy as np
import pytesseract


def openImage(path):
    image = Image.open(path)
    return image


def getOCRdata(image):
    width, height = image.size
    w_scale = 1000 / width
    h_scale = 1000 / height

    ocr_df = pytesseract.image_to_data(image, output_type='data.frame')

    ocr_df = ocr_df.dropna().assign(left_scaled=ocr_df.left * w_scale,
                                    width_scaled=ocr_df.width * w_scale,
                                    top_scaled=ocr_df.top * h_scale,
                                    height_scaled=ocr_df.height * h_scale,
                                    right_scaled=lambda x: x.left_scaled + x.width_scaled,
                                    bottom_scaled=lambda x: x.top_scaled + x.height_scaled)

    float_cols = ocr_df.select_dtypes('float').columns
    ocr_df[float_cols] = ocr_df[float_cols].round(0).astype(int)
    ocr_df = ocr_df.replace(r'^\s*$', np.nan, regex=True)
    ocr_df = ocr_df.dropna().reset_index(drop=True)

    return ocr_df, width, height


def normalize_box(box, width, height):
    return [
        int(1000 * (box[0] / width)),
        int(1000 * (box[1] / height)),
        int(1000 * (box[2] / width)),
        int(1000 * (box[3] / height)),
    ]


def getBoxes(df, width, height):
    words = list(df.text)
    coordinates = df[['left', 'top', 'width', 'height', 'text']]
    actual_boxes = []
    named_boxes = []
    for idx, row in coordinates.iterrows():
        x, y, w, h, text = tuple(row)  # the row comes in (left, top, width, height) format
        actual_box = [x, y, x + w, y + h]  # we turn it into (left, top, left+width, top+height) to get the actual box
        text_box = actual_box + [text]
        actual_boxes.append(actual_box)
        named_boxes.append(text_box)

    boxes = []
    for box in actual_boxes:
        boxes.append(normalize_box(box, width, height))
    return boxes, words, actual_boxes


def convert_example_to_features(image, words, boxes, actual_boxes, tokenizer, args, cls_token_box=[0, 0, 0, 0],
                                sep_token_box=[1000, 1000, 1000, 1000],
                                pad_token_box=[0, 0, 0, 0]):
    width, height = image.size

    tokens = []
    token_boxes = []
    actual_bboxes = []  # we use an extra b because actual_boxes is already used
    token_actual_boxes = []
    for word, box, actual_bbox in zip(words, boxes, actual_boxes):
        word_tokens = tokenizer.tokenize(word)
        tokens.extend(word_tokens)
        token_boxes.extend([box] * len(word_tokens))
        actual_bboxes.extend([actual_bbox] * len(word_tokens))
        token_actual_boxes.extend([actual_bbox] * len(word_tokens))

    # Truncation: account for [CLS] and [SEP] with "- 2".
    special_tokens_count = 2
    if len(tokens) > args.max_seq_length - special_tokens_count:
        tokens = tokens[: (args.max_seq_length - special_tokens_count)]
        token_boxes = token_boxes[: (args.max_seq_length - special_tokens_count)]
        actual_bboxes = actual_bboxes[: (args.max_seq_length - special_tokens_count)]
        token_actual_boxes = token_actual_boxes[: (args.max_seq_length - special_tokens_count)]

    # add [SEP] token, with corresponding token boxes and actual boxes
    tokens += [tokenizer.sep_token]
    token_boxes += [sep_token_box]
    actual_bboxes += [[0, 0, width, height]]
    token_actual_boxes += [[0, 0, width, height]]

    segment_ids = [0] * len(tokens)

    # next: [CLS] token
    tokens = [tokenizer.cls_token] + tokens
    token_boxes = [cls_token_box] + token_boxes
    actual_bboxes = [[0, 0, width, height]] + actual_bboxes
    token_actual_boxes = [[0, 0, width, height]] + token_actual_boxes
    segment_ids = [1] + segment_ids

    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    # The mask has 1 for real tokens and 0 for padding tokens. Only real
    # tokens are attended to.
    input_mask = [1] * len(input_ids)

    # Zero-pad up to the sequence length.
    padding_length = args.max_seq_length - len(input_ids)
    input_ids += [tokenizer.pad_token_id] * padding_length
    input_mask += [0] * padding_length
    segment_ids += [tokenizer.pad_token_id] * padding_length
    token_boxes += [pad_token_box] * padding_length
    token_actual_boxes += [pad_token_box] * padding_length

    assert len(input_ids) == args.max_seq_length
    assert len(input_mask) == args.max_seq_length
    assert len(segment_ids) == args.max_seq_length
    assert len(token_boxes) == args.max_seq_length
    assert len(token_actual_boxes) == args.max_seq_length

    return input_ids, input_mask, segment_ids, token_boxes, token_actual_boxes
--------------------------------------------------------------------------------