76 | -
77 | -
78 | -
79 | -
80 | -
81 | -
82 | - C-tag
83 | - U-tag
84 | - L-tag
85 | - X-tag
86 | - NL-tag
87 | - R-tag
88 | -
89 |
--------------------------------------------------------------------------------
/dataset/data_preprocessing_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "SPLITS": ["train", "validation"],
3 | "OTSL_TAG": {
4 | "C": "C-tag",
5 | "L": "L-tag",
6 | "U": "U-tag",
7 | "X": "X-tag",
8 | "NL": "NL-tag"
9 | },
10 | "PUBTABNET_PATH": "TFLOP-dataset/meta_data/PubTabNet_2.0.0.jsonl",
11 | "AMBIGUOUS_DATA_PATH": "TFLOP-dataset/meta_data/erroneous_pubtabnet_data.json",
12 | "DR_COORD_PATH": {
13 | "train": {
14 | "0": "TFLOP-dataset/pse_results/train/detection_results_0.pkl",
15 | "1": "TFLOP-dataset/pse_results/train/detection_results_1.pkl",
16 | "2": "TFLOP-dataset/pse_results/train/detection_results_2.pkl",
17 | "3": "TFLOP-dataset/pse_results/train/detection_results_3.pkl",
18 | "4": "TFLOP-dataset/pse_results/train/detection_results_4.pkl",
19 | "5": "TFLOP-dataset/pse_results/train/detection_results_5.pkl",
20 | "6": "TFLOP-dataset/pse_results/train/detection_results_6.pkl",
21 | "7": "TFLOP-dataset/pse_results/train/detection_results_7.pkl"
22 | },
23 | "validation": {
24 | "0": "TFLOP-dataset/pse_results/val/detection_results_0.pkl"
25 | }
26 | }
27 | }
28 |
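For orientation, a minimal sketch of how this config is consumed (paths as in the config above; the pickle files are assumed to exist locally), mirroring the loading loop in `preprocess_data.py` below:

```python
import json
import pickle

# Load the config shown above (path is illustrative).
with open("dataset/data_preprocessing_config.json", "r") as f:
    data_config = json.load(f)

# DR_COORD_PATH maps shard indices to PSE detection pickles per split;
# index every detection result by its source image file name.
pse_det_data = {split: {} for split in data_config["SPLITS"]}
for split in data_config["SPLITS"]:
    for pickle_path in data_config["DR_COORD_PATH"][split].values():
        with open(pickle_path, "rb") as pf:
            for pse_data in pickle.load(pf):
                pse_det_data[split][pse_data["file_name"]] = pse_data
```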
--------------------------------------------------------------------------------
/dataset/preprocess_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import pickle
4 | import random
5 |
6 | import torch
7 | from tqdm import tqdm
8 |
9 | from .preprocess_data_utils import convert_html_to_otsl
10 |
11 |
12 | def format_pubtabnet_gold_coords(gold_bbox_collection):
13 | """Preprocess gold coordinate info for PubTabNet dataset.
14 |
15 | NOTE
16 |         - In PubTabNet, empty cells are marked by the absence of a 'bbox' key in the cell's dictionary
17 |
18 | Args:
19 | gold_bbox_collection List[Dict]: List of cell dictionaries
20 | Each dictionary has 'tokens' and 'bbox' (if the cell is filled) keys
21 |             E.g. [{'tokens': ['<b>', 'R', 'i', 's', 'k', ' ', 'F', 'a', 'c', 't', 'o', 'r', 's', '</b>'], 'bbox': [28, 5, 77, 14]}, ... ]
22 | """
23 | cells = []
24 | for cell in gold_bbox_collection:
25 | if "bbox" in cell:
26 | # This is a cell with filledContent
27 | string_coords = ["%.2f" % c for c in cell["bbox"]] + ["2"]
28 | # Add serialised string
29 | text = "".join(cell["tokens"])
30 | string_coords = string_coords + [text]
31 | cells.append(" ".join(string_coords))
32 | else:
33 | # This is an empty cell
34 | string_coords = ["-1.0", "-1.0", "-1.0", "-1.0", "1"]
35 | text = ""
36 | string_coords = string_coords + [text]
37 | cells.append(" ".join(string_coords))
38 |
39 | return cells
40 |
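A usage sketch with hypothetical cells: a filled cell serialises to `x1 y1 x2 y2 2 <text>` (coordinates to two decimals), an empty cell to four `-1.0` values, the flag `1`, and an empty string:

```python
gold_cells = [
    {"tokens": ["<b>", "R", "i", "s", "k", "</b>"], "bbox": [28, 5, 77, 14]},
    {"tokens": []},  # empty cell: no 'bbox' key
]
print(format_pubtabnet_gold_coords(gold_cells))
# ['28.00 5.00 77.00 14.00 2 <b>Risk</b>', '-1.0 -1.0 -1.0 -1.0 1 ']
```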
41 |
42 | def group_det_bbox(
43 | pred_bbox_tensor, gold_bbox_tensor, IOU_threshold=0.1, IOP_threshold=0.1
44 | ):
45 | """Map pred bbox to gold bbox based on IOU.
46 |
47 | Args:
48 | pred_bbox_tensor: torch.Tensor, (N, 4)
49 | gold_bbox_tensor: torch.Tensor, (M, 4)
50 | IOU_threshold: float, threshold for IOU
51 | IOP_threshold: float, threshold for IOP
52 |
53 | """
54 |
55 | x_left_y_top_tensor = torch.max(
56 | pred_bbox_tensor.unsqueeze(1)[:, :, :2], gold_bbox_tensor.unsqueeze(0)[:, :, :2]
57 | ) # (N, M, 2)
58 | x_right_y_bottom_tensor = torch.min(
59 | pred_bbox_tensor.unsqueeze(1)[:, :, 2:], gold_bbox_tensor.unsqueeze(0)[:, :, 2:]
60 | ) # (N, M, 2)
61 |
62 | x_left = x_left_y_top_tensor[:, :, 0]
63 | y_top = x_left_y_top_tensor[:, :, 1]
64 | x_right = x_right_y_bottom_tensor[:, :, 0]
65 | y_bottom = x_right_y_bottom_tensor[:, :, 1]
66 |
67 | # Compute the intersection area
68 | intersection_area = torch.logical_or(x_right < x_left, y_bottom < y_top).float()
69 | intersection_area = 1 - intersection_area # (N, M)
70 | intersection_area = intersection_area * (x_right - x_left) * (y_bottom - y_top)
71 |
72 | # Compute the area of both bounding boxes
73 | bbox1_area = (pred_bbox_tensor[:, 2] - pred_bbox_tensor[:, 0]) * (
74 | pred_bbox_tensor[:, 3] - pred_bbox_tensor[:, 1]
75 | ) # (N,)
76 | bbox2_area = (gold_bbox_tensor[:, 2] - gold_bbox_tensor[:, 0]) * (
77 | gold_bbox_tensor[:, 3] - gold_bbox_tensor[:, 1]
78 | ) # (M,)
79 |
80 | # Compute the IOU
81 | iou = intersection_area / (
82 | bbox1_area.unsqueeze(1) + bbox2_area.unsqueeze(0) - intersection_area + 1e-6
83 | ) # (N, M)
84 |
85 | # Map the pred bbox to gold bbox
86 | iou_max, iou_gold_bbox_idx = torch.max(iou, dim=1) # (N,)
87 | iou_pred_bbox_idx = torch.arange(pred_bbox_tensor.shape[0]) # (N,)
88 | iou_pred_bbox_idx = iou_pred_bbox_idx[iou_max > IOU_threshold]
89 | iou_gold_bbox_idx = iou_gold_bbox_idx[iou_max > IOU_threshold]
90 |
91 |     # For preds not associated with any gold bbox, recheck via IOP and associate if more than IOP_threshold of the pred bbox area overlaps a gold bbox
92 | iop = intersection_area / bbox1_area.unsqueeze(1) # (N, M)
93 | iop_max, iop_gold_bbox_idx = torch.max(iop, dim=1) # (N,)
94 | iop_pred_bbox_idx = torch.arange(pred_bbox_tensor.shape[0]) # (N,)
95 | bool_mask = torch.logical_and(iop_max > IOP_threshold, iou_max <= IOU_threshold)
96 | iop_pred_bbox_idx = iop_pred_bbox_idx[bool_mask]
97 | iop_gold_bbox_idx = iop_gold_bbox_idx[bool_mask]
98 |
99 | pred_bbox_idx = torch.cat([iou_pred_bbox_idx, iop_pred_bbox_idx], dim=0)
100 | gold_bbox_idx = torch.cat([iou_gold_bbox_idx, iop_gold_bbox_idx], dim=0)
101 |
102 | return pred_bbox_idx, gold_bbox_idx, iou, intersection_area
103 |
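A toy illustration (coordinates made up): the first detection overlaps the single gold cell with IoU 64/100 = 0.64 and is matched; the second overlaps nothing, failing both thresholds:

```python
import torch

pred = torch.tensor([[0.0, 0.0, 10.0, 10.0], [100.0, 100.0, 110.0, 110.0]])
gold = torch.tensor([[1.0, 1.0, 9.0, 9.0]])

pred_idx, gold_idx, iou, inter = group_det_bbox(pred, gold)
# pred_idx -> tensor([0]), gold_idx -> tensor([0]); detection 1 has zero
# intersection area, so neither the IOU nor the IOP check rescues it.
```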
104 |
105 | def preprocess_det_bbox(
106 | pred_bbox_collection, gold_bbox_collection, IOU_threshold=0.1, IOP_threshold=0.1
107 | ):
108 | """Preprocess detected bbox and gold bbox.
109 |
110 | Args:
111 | pred_bbox_collection: List[List[float]], list of detected bounding boxes
112 | Each detected bounding box is represented as [x1, y1, x2, y2, x3, y3, x4, y4]
113 | gold_bbox_collection: List[Dict], list of gold bounding boxes
114 | IOU_threshold: float, threshold for IOU (Intersection over Union)
115 | IOP_threshold: float, threshold for IOP (Intersection over Prediction)
116 | """
117 |
118 | # Reformat bounding boxes to [x_left, y_top, x_right, y_bottom]
119 | pred_cell_bboxes = [
120 | [
121 | min(coord[0], coord[2], coord[4], coord[6]),
122 | min(coord[1], coord[3], coord[5], coord[7]),
123 | max(coord[0], coord[2], coord[4], coord[6]),
124 | max(coord[1], coord[3], coord[5], coord[7]),
125 | ]
126 | for coord in pred_bbox_collection
127 | ]
128 |
129 | gold_cell_bboxes = [x["bbox"] for x in gold_bbox_collection if "bbox" in x]
130 | gold_cell_contents = [
131 | "".join(x["tokens"]) for x in gold_bbox_collection if "bbox" in x
132 | ]
133 | grp_to_filled_gold_idx_mapping = {}
134 | current_filled_idx = 0
135 | for gold_idx, gold_bbox in enumerate(gold_bbox_collection):
136 | if "bbox" in gold_bbox:
137 | grp_to_filled_gold_idx_mapping[current_filled_idx] = gold_idx
138 | current_filled_idx += 1
139 |
140 | pred_bbox_tensor = torch.tensor(pred_cell_bboxes)
141 | gold_bbox_tensor = torch.tensor(gold_cell_bboxes)
142 |
143 | pred_bbox_idx, gold_bbox_idx, iou, intersection_area = group_det_bbox(
144 | pred_bbox_tensor,
145 | gold_bbox_tensor,
146 | IOU_threshold=IOU_threshold,
147 | IOP_threshold=IOP_threshold,
148 | )
149 |
150 | # Group up pred_idx by common gold_bbox_idx
151 | pred_bbox_idx_group = {}
152 | for pred_idx, gold_idx in zip(pred_bbox_idx, gold_bbox_idx):
153 | if gold_idx.item() not in pred_bbox_idx_group:
154 | pred_bbox_idx_group[gold_idx.item()] = []
155 | pred_bbox_idx_group[gold_idx.item()].append(pred_idx.item())
156 |
157 | def sort_bbox_coords(bbox_coord_list):
158 | if len(bbox_coord_list) == 1:
159 | return bbox_coord_list
160 | else:
161 | new_list = []
162 | # 1. get min_height across all bboxes in the list
163 | min_height = 100000
164 | for bbox_coord in bbox_coord_list:
165 | min_height = min(min_height, bbox_coord[3] - bbox_coord[1])
166 |
167 | sorted_by_height_interval = {}
168 | for bbox_coord in bbox_coord_list:
169 | height_interval = int(bbox_coord[1] / min_height)
170 | if height_interval not in sorted_by_height_interval:
171 | sorted_by_height_interval[height_interval] = []
172 | sorted_by_height_interval[height_interval].append(bbox_coord)
173 |
174 | for height_interval in sorted(sorted_by_height_interval.keys()):
175 | bbox_coords = sorted_by_height_interval[height_interval]
176 | # sort bbox_coords by x_left
177 | bbox_coords = sorted(bbox_coords, key=lambda x: x[0])
178 | new_list.extend(bbox_coords)
179 |
180 | return new_list
181 |
182 | # First, serialize bbox coords within each group
183 | for gold_idx in pred_bbox_idx_group.keys():
184 | pred_bbox_idx_group[gold_idx] = sort_bbox_coords(
185 | [pred_cell_bboxes[x] for x in pred_bbox_idx_group[gold_idx]]
186 | )
187 |
188 | # Next, serialize groups
189 | gold_idx_first_bbox_list = [(k, v[0]) for k, v in pred_bbox_idx_group.items()]
190 | minimum_group_height = 100000
191 | for gold_idx, bbox_coord in gold_idx_first_bbox_list:
192 | minimum_group_height = min(minimum_group_height, bbox_coord[3] - bbox_coord[1])
193 |
194 | sorted_by_group_height_interval = {}
195 | for gold_idx, bbox_coord in gold_idx_first_bbox_list:
196 | height_interval = int(bbox_coord[1] / minimum_group_height)
197 | if height_interval not in sorted_by_group_height_interval:
198 | sorted_by_group_height_interval[height_interval] = []
199 | sorted_by_group_height_interval[height_interval].append((gold_idx, bbox_coord))
200 |
201 | sorted_gold_idx_first_bbox_list = []
202 | for height_interval in sorted(sorted_by_group_height_interval.keys()):
203 | gold_idx_first_bbox_list = sorted_by_group_height_interval[height_interval]
204 | # sort bbox_coords by x_left
205 | gold_idx_first_bbox_list = sorted(
206 | gold_idx_first_bbox_list, key=lambda x: x[1][0]
207 | )
208 | sorted_gold_idx_first_bbox_list.extend(gold_idx_first_bbox_list)
209 |
210 | # Finally, serialize the whole list
211 | serialized_pred_bbox = {}
212 | for group_idx, group_data in enumerate(sorted_gold_idx_first_bbox_list):
213 | serialized_pred_bbox[group_idx] = [
214 | pred_bbox_idx_group[group_data[0]],
215 | grp_to_filled_gold_idx_mapping[group_data[0]],
216 | gold_cell_contents[group_data[0]],
217 | ]
218 |
219 | return serialized_pred_bbox
220 |
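A sketch of the output shape on hypothetical inputs: one PSE quad (8 coordinates) and two gold cells, the second empty. Each group carries its serialised detection boxes, the index of the gold cell it maps to (over all cells, empty ones included), and that cell's text:

```python
pred_quads = [[0, 0, 10, 0, 10, 10, 0, 10]]  # x1, y1, ..., x4, y4
gold_cells = [{"tokens": ["A"], "bbox": [1, 1, 9, 9]}, {"tokens": []}]

print(preprocess_det_bbox(pred_quads, gold_cells))
# {0: [[[0, 0, 10, 10]], 0, 'A']}
```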
221 |
222 | if __name__ == "__main__":
223 | parser = argparse.ArgumentParser()
224 | parser.add_argument(
225 | "--data_config_path", type=str, help="Path to the data config file"
226 | )
227 | parser.add_argument(
228 | "--output_dir", type=str, help="Output directory to save the preprocessed data"
229 | )
230 | parser.add_argument(
231 | "--bin_idx", type=int, default=-1, help="Index of the bin to process"
232 | )
233 | parser.add_argument(
234 | "--num_bins",
235 | type=int,
236 | default=0,
237 | help="Number of bins to split the dataset into",
238 | )
239 | args = parser.parse_args()
240 |
241 | # sanity check
242 | if args.bin_idx != -1:
243 | assert args.num_bins > 0
244 |
245 | # config loading and setting up for preprocessing
246 | data_config = json.load(open(args.data_config_path, "r"))
247 | random.seed(42)
248 | generated_dataset = {
249 | k: [] for k in data_config["SPLITS"]
250 | } # This is to store all preprocessed data
251 |
252 | # Data loading
253 | meta_pubtabnet_data = open(data_config["PUBTABNET_PATH"], "r").readlines()
254 | pse_det_data = {"train": {}, "validation": {}}
255 | for split in data_config["SPLITS"]:
256 | for pickle_path in tqdm(
257 | data_config["DR_COORD_PATH"][split].values(),
258 | desc="Loading PSE Det result for %s" % split,
259 | ):
260 | pse_data_list_loaded = pickle.load(open(pickle_path, "rb"))
261 | for pse_data in pse_data_list_loaded:
262 | pse_det_data[split][pse_data["file_name"]] = pse_data
263 |
264 |     # Remove ambiguous HTML representations from the training and validation datasets. NOTE: this is not done for the test dataset
265 | split_dataset = {"train": [], "validation": []}
266 | ambiguous_data_filenames = []
267 | for split_type, split_amb_filenames in json.load(
268 | open(data_config["AMBIGUOUS_DATA_PATH"], "r")
269 | ).items():
270 | ambiguous_data_filenames.extend(split_amb_filenames)
271 |     for raw_data in tqdm(meta_pubtabnet_data, desc="Removing ambiguous data"):
272 | data = json.loads(raw_data)
273 | if data["filename"] in ambiguous_data_filenames:
274 | continue
275 | split = data["split"]
276 | if split == "val":
277 | split = "validation"
278 | assert split in ["train", "validation"], (
279 | "Invalid split %s" % split
280 | ) # NOTE: Test dataset is not processed at this stage
281 | split_dataset[split].append(raw_data)
282 |
283 | # Preprocessing
284 | # Preprocess Train and Validation dataset
285 | for split in ["train", "validation"]:
286 | if args.bin_idx != -1:
287 | # Split the dataset into bins
288 | num_data = len(split_dataset[split])
289 | bin_size = int(num_data / args.num_bins)
290 | start_idx = args.bin_idx * bin_size
291 | if args.bin_idx == args.num_bins - 1:
292 | end_idx = num_data
293 | else:
294 | end_idx = start_idx + bin_size
295 | sliced_dataset = split_dataset[split][start_idx:end_idx]
296 | else:
297 | sliced_dataset = split_dataset[split]
298 |
299 | for raw_data in tqdm(sliced_dataset, desc="Pre-processing split %s" % split):
300 | loaded_data = json.loads(raw_data)
301 | data_filename = loaded_data["filename"]
302 |
303 | # GET OTSL representation
304 | otsl_seq, num_rows, num_cols = convert_html_to_otsl(
305 | html_seq=loaded_data["html"]["structure"]["tokens"],
306 | otsl_tag_maps=data_config["OTSL_TAG"],
307 | )
308 | gold_bbox_seq = format_pubtabnet_gold_coords(loaded_data["html"]["cells"])
309 |
310 | # Get pse det result
311 | pse_det_result = pse_det_data[split][data_filename]
312 |
313 | # Get pred and gold bbox idx grouping
314 | # pred_bbox_idx is a dict mapping group_idx to list of bbox_coords
315 | pred_bbox_idx = preprocess_det_bbox(
316 | pse_det_result["bbox"],
317 | loaded_data["html"]["cells"],
318 | IOU_threshold=0.1,
319 | IOP_threshold=0.1,
320 | )
321 |
322 | data_entry = {
323 | "file_name": loaded_data["filename"],
324 | "dr_coord": pred_bbox_idx,
325 | "gold_coord": gold_bbox_seq,
326 | "org_html": loaded_data["html"]["structure"]["tokens"],
327 | "otsl_seq": otsl_seq,
328 | "num_rows": num_rows,
329 | "num_cols": num_cols,
330 | "split": split,
331 | }
332 |
333 | generated_dataset[split].append(data_entry)
334 |
335 | list_of_data = generated_dataset[split]
336 | if args.bin_idx != -1:
337 | savename = "%s/dataset_%s_%d_%d.jsonl" % (
338 | args.output_dir,
339 | split,
340 | args.bin_idx,
341 | args.num_bins,
342 | )
343 | else:
344 | savename = "%s/dataset_%s.jsonl" % (args.output_dir, split)
345 | with open(savename, "w") as f:
346 | for d in list_of_data:
347 | f.write(json.dumps(d) + "\n")
348 |
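The `--bin_idx`/`--num_bins` flags shard the preprocessing across workers. A minimal sketch of the slicing logic above (illustrative sizes); note the last bin absorbs the remainder:

```python
num_data, num_bins = 10, 3
bin_size = int(num_data / num_bins)
for bin_idx in range(num_bins):
    start_idx = bin_idx * bin_size
    end_idx = num_data if bin_idx == num_bins - 1 else start_idx + bin_size
    print(bin_idx, start_idx, end_idx)  # (0, 0, 3) (1, 3, 6) (2, 6, 10)
```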
--------------------------------------------------------------------------------
/dataset/preprocess_data_utils.py:
--------------------------------------------------------------------------------
1 | def convert_html_to_otsl(html_seq, otsl_tag_maps):
2 | """
3 | Convert list of html tokens to OTSL format
4 |
5 | Args:
6 | html_seq List[str]: list of html tokens
7 |             E.g. ['<thead>', '<tr>', '<td>', '</td>', ...]
8 |         otsl_tag_maps Dict[str, str]: mapping of otsl tag symbols
9 |             E.g. {"C": "C-tag", "L": "L-tag", "U": "U-tag", "X": "X-tag", "NL": "NL-tag"}
10 |
11 | Returns:
12 | List[str]: Full OTSL sequence
13 | Int: Number of rows in the table
14 | Int: Number of columns in the table
15 | """
16 |
17 | # 1. Split list of HTML tokens into head and body
18 |     end_of_head_index = html_seq.index("</thead>")
19 |     thead_seq = html_seq[:end_of_head_index]
20 |     thead_seq = [
21 |         x for x in thead_seq if x not in ["<thead>", "</thead>", "<tbody>", "</tbody>"]
22 |     ]
23 |     tbody_seq = html_seq[(end_of_head_index + 1) :]
24 |     tbody_seq = [
25 |         x for x in tbody_seq if x not in ["<thead>", "</thead>", "<tbody>", "</tbody>"]
26 |     ]
27 |
28 | # 2. Format HTML tags into row-wise list of tokens
29 | thead_row_wise_seq = get_row_wise(thead_seq)
30 | tbody_row_wise_seq = get_row_wise(tbody_seq)
31 |
32 | # 2.1 Check if the thead section is empty
33 | is_head_empty = False
34 | if len(thead_row_wise_seq) == 0:
35 | is_head_empty = True
36 |
37 | # 3. Convert row-wise list of tokens into OTSL array -> 4. Convert OTSL array into OTSL sequence
38 | thead_OTSL_array, num_head_rows, num_cols = None, 0, None
39 | thead_OTSL_seq = []
40 | if not is_head_empty:
41 | thead_OTSL_array, num_head_rows, num_cols = get_OTSL_array(
42 | thead_row_wise_seq, otsl_tag_maps
43 | )
44 | thead_OTSL_seq = convert_OTSL_array_to_OTSL_seq(
45 | thead_OTSL_array,
46 | num_rows=num_head_rows,
47 | num_cols=num_cols,
48 | otsl_tag_maps=otsl_tag_maps,
49 | )
50 |
51 | tbody_OTSL_array, num_body_rows, num_cols = get_OTSL_array(
52 | tbody_row_wise_seq, otsl_tag_maps, ref_num_cols=num_cols
53 | )
54 | tbody_OTSL_seq = convert_OTSL_array_to_OTSL_seq(
55 | tbody_OTSL_array,
56 | num_rows=num_body_rows,
57 | num_cols=num_cols,
58 | otsl_tag_maps=otsl_tag_maps,
59 | )
60 |
61 | # 5. Combine thead and tbody into one OTSL sequence
62 | combined_OTSL_seq = (
63 | [""]
64 | + thead_OTSL_seq
65 | + ["", ""]
66 | + tbody_OTSL_seq
67 | + [""]
68 | )
69 | num_rows = num_head_rows + num_body_rows
70 |
71 | return combined_OTSL_seq, num_rows, num_cols
72 |
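A worked example (hypothetical tokens, PubTabNet structure format) for a 2x2 table with a one-row header:

```python
html_tokens = (
    ["<thead>", "<tr>", "<td>", "</td>", "<td>", "</td>", "</tr>", "</thead>"]
    + ["<tbody>", "<tr>", "<td>", "</td>", "<td>", "</td>", "</tr>", "</tbody>"]
)
tag_maps = {"C": "C-tag", "L": "L-tag", "U": "U-tag", "X": "X-tag", "NL": "NL-tag"}

seq, num_rows, num_cols = convert_html_to_otsl(html_tokens, tag_maps)
# seq -> ['<thead>', 'C-tag', 'C-tag', 'NL-tag', '</thead>',
#         '<tbody>', 'C-tag', 'C-tag', 'NL-tag', '</tbody>']
# num_rows -> 2, num_cols -> 2
```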
73 |
74 | def get_OTSL_array(row_wise_html_tags, otsl_tag_maps, ref_num_cols=None):
75 | """Generate OTSL array from row-wise html tags.
76 |
77 | Args:
78 | row_wise_html_tags List[List[str]]: list of list of html tags, where each inner list is a row
79 |             E.g. [['<td>', '</td>'], ['<td>', '</td>']]
80 |         otsl_tag_maps Dict[str, str]: mapping of otsl tag symbols
81 |             E.g. {"C": "C-tag", "L": "L-tag", "U": "U-tag", "X": "X-tag", "NL": "NL-tag"}
82 | ref_num_cols int: reference number of columns to use. If None, will derive from row_wise_html_tags
83 | - Used to sanity check if tbody's num_cols match that of thead
84 |
85 | Returns:
86 | Tuple[List[List[str]], int, int]: OTSL array, number of rows, number of columns
87 | """
88 | num_rows, num_cols = get_num_rows_and_cols(row_wise_html_tags)
89 | if ref_num_cols is not None and ref_num_cols != num_cols:
90 | raise ValueError(
91 | "Number of columns in tbody does not match that of thead. Got %s but expected %s"
92 | % (num_cols, ref_num_cols)
93 | )
94 |
95 | # 1. Initialize OTSL array
96 | otsl_array = [list([None] * num_cols) for _ in range(num_rows)]
97 |
98 | # 2. Fill in OTSL array
99 | curr_row_ind, curr_col_ind = 0, 0
100 | current_data = {"standard": 0, "rowspan": 0, "colspan": 0}
101 | for row_tokens in row_wise_html_tags:
102 | for tok_i, tok in enumerate(row_tokens):
103 | # 2.1 sanity check token
104 |             if tok not in ["<td>", "</td>", "<td", ">"] and (
105 | "rowspan" not in tok and "colspan" not in tok
106 | ):
107 | raise ValueError("Invalid HTML %s" % tok)
108 |
109 | # 2.2 iter over tokens in the row
110 |             if tok in ["<td", ">"]:
111 |                 continue
112 |             elif tok == "<td>":
113 | current_data["standard"] += 1
114 | elif "rowspan" in tok:
115 | current_data["rowspan"] += int(tok.split("=")[1].split('"')[1])
116 | elif "colspan" in tok:
117 | current_data["colspan"] += int(tok.split("=")[1].split('"')[1])
118 |             elif tok == "</td>":
119 | # End of cell -> i.e. Time to start updating OTSL array with current_data
120 | # 2.2.1 Find row & col ind to insert data
121 | while otsl_array[curr_row_ind][curr_col_ind] is not None:
122 | curr_col_ind += 1
123 | if curr_col_ind >= num_cols:
124 | curr_col_ind = 0
125 | curr_row_ind += 1
126 | assert (
127 | curr_row_ind < num_rows
128 | ), "curr_row_ind %s >= num_rows %s" % (curr_row_ind, num_rows)
129 |
130 | # 2.2.2 Sanity check current_data before insertion
131 | sanity_check_move(current_data)
132 |
133 | # 2.2.3 Insert data
134 | otsl_array = insert_data_into_OTSL(
135 | current_data=current_data,
136 | OTSL_array=otsl_array,
137 | otsl_tag_maps=otsl_tag_maps,
138 | curr_row_ind=curr_row_ind,
139 | curr_col_ind=curr_col_ind,
140 | )
141 |
142 | # 2.2.4 reset current_data
143 | current_data = {"standard": 0, "rowspan": 0, "colspan": 0}
144 |
145 | else:
146 | raise ValueError("Invalid HTML %s" % tok)
147 |
148 | return otsl_array, num_rows, num_cols
149 |
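A span example on hypothetical rows: the colspan root becomes a C-tag and the spanned cell to its right an L-tag, while the second row fills in normally:

```python
tag_maps = {"C": "C-tag", "L": "L-tag", "U": "U-tag", "X": "X-tag", "NL": "NL-tag"}
rows = [
    ["<td", ' colspan="2"', ">", "</td>"],  # one cell spanning both columns
    ["<td>", "</td>", "<td>", "</td>"],
]
arr, n_rows, n_cols = get_OTSL_array(rows, tag_maps)
# arr -> [['C-tag', 'L-tag'], ['C-tag', 'C-tag']], n_rows -> 2, n_cols -> 2
```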
150 |
151 | def convert_OTSL_array_to_OTSL_seq(otsl_array, num_rows, num_cols, otsl_tag_maps):
152 | """Convert OTSL array to OTSL sequence.
153 |
154 | Args:
155 | otsl_array List[List[str]]: OTSL array
156 | num_rows int: number of rows in OTSL array
157 | num_cols int: number of columns in OTSL array
158 | otsl_tag_maps Dict[str, str]: mapping of otsl tag symbols
159 | E.g.{"C": "C-tag", "L": "L-tag", "U": "U-tag", "X": "X-tag", "NL": "NL-tag"}
160 |
161 | Returns:
162 | List[str]: OTSL sequence
163 | """
164 | OTSL_seq = []
165 |
166 | for row_ind in range(num_rows):
167 | for col_ind in range(num_cols):
168 | assert (
169 | otsl_array[row_ind][col_ind] is not None
170 | ), "row_ind %s, col_ind %s" % (row_ind, col_ind)
171 | OTSL_seq.append(otsl_array[row_ind][col_ind])
172 |
173 | OTSL_seq.append(otsl_tag_maps["NL"])
174 |
175 | return OTSL_seq
176 |
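Flattening simply walks the array row by row and appends an NL-tag after each row (hypothetical 1x2 input):

```python
tag_maps = {"C": "C-tag", "L": "L-tag", "U": "U-tag", "X": "X-tag", "NL": "NL-tag"}
print(convert_OTSL_array_to_OTSL_seq([["C-tag", "C-tag"]], 1, 2, tag_maps))
# ['C-tag', 'C-tag', 'NL-tag']
```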
177 |
178 | # -----Auxiliary Functions-----#
179 | def get_row_wise(tok_list):
180 | """Given list of HTML tokens, group them into row-wise format.
181 |
182 | NOTE:
183 |         Raises error if there are tokens not encapsulated by <tr> and </tr>
184 |
185 | Args:
186 | tok_list List[str]: list of html tokens
187 |             E.g. ['<tr>', '<td>', '</td>', '</tr>', ...]
188 |
189 | Returns:
190 | List[List[str]]: list of list of tokens, where each inner list is a row
191 | """
192 | row_wise_tokens = []
193 |
194 | is_within_row = False
195 | for tok in tok_list:
196 |         if tok == "<tr>":
197 | is_within_row = True
198 | tmp_row = []
199 |         elif tok == "</tr>":
200 | is_within_row = False
201 | row_wise_tokens.append(tmp_row)
202 | else:
203 |             assert is_within_row, "Token not encapsulated by <tr> tags"
204 | tmp_row.append(tok)
205 |
206 | return row_wise_tokens
207 |
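E.g. (hypothetical tokens) two single-cell rows group as:

```python
toks = ["<tr>", "<td>", "</td>", "</tr>", "<tr>", "<td>", "</td>", "</tr>"]
print(get_row_wise(toks))
# [['<td>', '</td>'], ['<td>', '</td>']]
```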
208 |
209 | def get_num_rows_and_cols(row_wise_html_tags):
210 | """Given row-wise html tags, derive number of rows and columns.
211 |
212 | Args:
213 | row_wise_html_tags List[List[str]]: list of list of html tags, where each inner list is a row
214 |             E.g. [['<td>', '</td>'], ['<td>', '</td>']]
215 |
216 | Returns:
217 | Tuple[int, int]: number of rows and columns
218 | """
219 |
220 | # Derive the number of rows in this table
221 | num_rows = len(row_wise_html_tags)
222 | num_cols = 0
223 | col_span_tracker = 0
224 |
225 | # Derive the number of columns in this table
226 | for first_row_tok in row_wise_html_tags[0]:
227 |         if first_row_tok == "</td>":
228 | if col_span_tracker == 0:
229 | num_cols += 1
230 | else:
231 | num_cols += col_span_tracker
232 | col_span_tracker = 0
233 | else:
234 | if "colspan" in first_row_tok:
235 | col_span_tracker += int(first_row_tok.split("=")[1].split('"')[1])
236 |
237 | return num_rows, num_cols
238 |
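A quick check on hypothetical rows: a colspan=2 cell plus a plain cell in the first row gives three columns:

```python
rows = [
    ["<td", ' colspan="2"', ">", "</td>", "<td>", "</td>"],
    ["<td>", "</td>", "<td>", "</td>", "<td>", "</td>"],
]
print(get_num_rows_and_cols(rows))  # (2, 3)
```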
239 |
240 | def sanity_check_move(current_data):
241 | """Sanity checker of current move data prior to updating OTSL array.
242 |
243 | Args:
244 | current_data Dict: current data of move
245 | E.g. {'standard': 1, 'rowspan': 0, 'colspan': 0}
246 |
247 | Checks:
248 | 1. If standard (i.e. single cell), then rowspan and colspan must be 0
249 | 2. If not standard, then rowspan or colspan must be > 0
250 | """
251 |
252 | if current_data["standard"] == 0:
253 | assert sum([current_data["rowspan"], current_data["colspan"]]) > 0
254 | else:
255 | assert current_data["standard"] == 1
256 | assert sum([current_data["rowspan"], current_data["colspan"]]) == 0
257 |
258 |
259 | def insert_data_into_OTSL(
260 | current_data, OTSL_array, otsl_tag_maps, curr_row_ind, curr_col_ind
261 | ):
262 | """Given current_data, insert data into OTSL array.
263 |
264 | Args:
265 | current_data Dict: current data of move
266 | E.g. {'standard': 1, 'rowspan': 0, 'colspan': 0}
267 | OTSL_array List[List[str]]: OTSL array
268 | otsl_tag_maps Dict: mapping of otsl tag symbols
269 | curr_row_ind int: current row index
270 | curr_col_ind int: current column index
271 |
272 | NOTE:
273 | This function updates the OTSL array based on the current_data.
274 | There are 4 cases in total:
275 | 1. Standard cell (i.e. single cell, no rowspan or colspan)
276 | 2. Colspan only
277 | 3. Rowspan only
278 | 4. Both rowspan and colspan
279 |
280 | Returns:
281 | List[List[str]]: updated OTSL array
282 | """
283 |
284 | if current_data["standard"] == 1:
285 | assert OTSL_array[curr_row_ind][curr_col_ind] is None
286 | OTSL_array[curr_row_ind][curr_col_ind] = otsl_tag_maps[
287 | "C"
288 | ] # single cell mapped as 'C' in OTSL
289 | else:
290 | # Colspan only
291 | if current_data["rowspan"] == 0:
292 | assert OTSL_array[curr_row_ind][curr_col_ind] is None
293 | OTSL_array[curr_row_ind][curr_col_ind] = otsl_tag_maps["C"]
294 | for i in range(1, current_data["colspan"]):
295 | assert OTSL_array[curr_row_ind][curr_col_ind + i] is None
296 | OTSL_array[curr_row_ind][curr_col_ind + i] = otsl_tag_maps[
297 | "L"
298 | ] # All cells other than root for colspan mapped as 'L' in OTSL
299 |
300 | # Rowspan only
301 | elif current_data["colspan"] == 0:
302 | assert OTSL_array[curr_row_ind][curr_col_ind] is None
303 | OTSL_array[curr_row_ind][curr_col_ind] = otsl_tag_maps["C"]
304 | for i in range(1, current_data["rowspan"]):
305 | assert OTSL_array[curr_row_ind + i][curr_col_ind] is None
306 | OTSL_array[curr_row_ind + i][curr_col_ind] = otsl_tag_maps[
307 | "U"
308 | ] # All cells other than root for rowspan mapped as 'U' in OTSL
309 |
310 | # Both rowspan and colspan
311 | else:
312 | assert OTSL_array[curr_row_ind][curr_col_ind] is None
313 | OTSL_array[curr_row_ind][curr_col_ind] = otsl_tag_maps["C"]
314 |
315 | for i in range(1, current_data["colspan"]):
316 | assert OTSL_array[curr_row_ind][curr_col_ind + i] is None
317 | OTSL_array[curr_row_ind][curr_col_ind + i] = otsl_tag_maps["L"]
318 |
319 | for i in range(1, current_data["rowspan"]):
320 | assert OTSL_array[curr_row_ind + i][curr_col_ind] is None
321 | OTSL_array[curr_row_ind + i][curr_col_ind] = otsl_tag_maps["U"]
322 |
323 | for i in range(1, current_data["rowspan"]):
324 | for j in range(1, current_data["colspan"]):
325 | assert OTSL_array[curr_row_ind + i][curr_col_ind + j] is None
326 | OTSL_array[curr_row_ind + i][curr_col_ind + j] = otsl_tag_maps["X"]
327 |
328 | return OTSL_array
329 |
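A sketch of case 4 (hypothetical grid): a rowspan=2, colspan=2 cell written at (0, 0) of an empty 2x3 grid marks the root C, its right neighbour L, the cell below U, and the diagonal X:

```python
tag_maps = {"C": "C-tag", "L": "L-tag", "U": "U-tag", "X": "X-tag", "NL": "NL-tag"}
grid = [[None] * 3 for _ in range(2)]
grid = insert_data_into_OTSL(
    current_data={"standard": 0, "rowspan": 2, "colspan": 2},
    OTSL_array=grid,
    otsl_tag_maps=tag_maps,
    curr_row_ind=0,
    curr_col_ind=0,
)
# grid -> [['C-tag', 'L-tag', None],
#          ['U-tag', 'X-tag', None]]
```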
330 |
331 | def calculate_pointer_index(
332 | curr_row_ind, curr_col_ind, row_offset, col_offset, is_table_body, num_cols
333 | ):
334 | """Given current row & col index, along with other info, calculate the index to point to for potsl."""
335 |
336 | # Apply offset values
337 | point_index = (curr_row_ind - row_offset) * num_cols + (curr_col_ind - col_offset)
338 |
339 | # Add number of rows to pointer as each row ends with NL tag
340 | point_index += curr_row_ind - row_offset
341 |
342 |     # If current table is tbody, offset by 3 since <thead>, </thead>, <tbody> precede it