├── LICENSE
├── README.md
├── header-annotation
│   ├── 01_extract_headers_from_html.py
│   ├── 02_generate_training_sequences.py
│   ├── 03_train_model.py
│   ├── 04_generate_probabilities.py
│   ├── 05_annotate_headers.py
│   ├── config.ini
│   ├── header_util.py
│   ├── missing_chapter_util.py
│   └── regex_util.py
└── segmentation
    ├── bert_full_window
    │   ├── 01_generate_training_sequences.py
    │   ├── 02_train_BERT_model.py
    │   └── 03_generate_BERT_probabilities.py
    ├── bert_single_para
    │   ├── 01_generate_training_data.py
    │   ├── 02_tokenize_sequences.py
    │   ├── 03_train_BERT_model.py
    │   ├── 04_generate_BERT_probabilities.py
    │   └── 05_generate_predictions_dp.py
    ├── bert_tokenize_test_books.py
    ├── generate_ground_truth.py
    ├── metrics
    │   └── generate_metrics.py
    ├── paragraph_to_sentence.py
    ├── tokenize_books.py
    └── weighted_overlap
        ├── compute_densities.py
        ├── compute_peaks_prominences.py
        └── get_predictions_dp.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Charuta Pethe
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # chapter-captor 2 | Code for the EMNLP 2020 paper titled "Chapter Captor: Text Segmentation in Novels" 3 | 4 | The dataset can be downloaded from: https://drive.google.com/file/d/13BAP8FjnabbSzb1wRWyIwmdbOP0d7fam/view?usp=sharing 5 | -------------------------------------------------------------------------------- /header-annotation/01_extract_headers_from_html.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import configparser 3 | import os 4 | import requests 5 | from bs4 import BeautifulSoup, NavigableString, Comment 6 | import re 7 | import glob 8 | import gzip 9 | import matplotlib.pyplot as plt 10 | import nltk 11 | import pandas as pd 12 | from sklearn.cluster import AgglomerativeClustering 13 | import os.path 14 | from functools import partial 15 | import multiprocessing 16 | 17 | # Function to strip unwanted tags (such as italics) from text 18 | def strip_tags(html, invalid_tags): 19 | soup = BeautifulSoup(html, features='lxml') 20 | for tag in soup.findAll(True): 21 | if tag.name in invalid_tags: 22 | s = "" 23 | for c in tag.contents: 24 | if not isinstance(c, NavigableString): 25 | c = strip_tags(str(c), invalid_tags) 26 | s += str(c) 27 | tag.replaceWith(s) 28 | return soup 29 | 30 | # Class to store a header-content pair 31 | class HeaderContent(object): 32 | def __init__(self, header, content): 33 | self.header = header 34 | self.content = content 35 | 36 | def add_header(self, header): 37 | self.header.append(header) 38 | 39 | def add_paragraph(self, paragraph): 40 | self.content.append(paragraph) 41 | 42 | def get_num_headers(self): 43 | return len(self.header) 44 | 45 | def get_num_paras(self): 46 | return len(self.content) 47 | 48 | def get_num_words(self): 49 | return len(' '.join([x.strip() for x in self.content]).split()) 50 | 51 | def print_headers(self): 52 | for elem in self.header: 53 | print(elem.strip()) 54 | 55 | def print_joined_headers(self): 56 | print(' '.join([elem.strip().replace('\n', ' ') for elem in self.header])) 57 | 58 | def print_short_content(self): 59 | for elem in self.content: 60 | print(elem.strip()[:20]) 61 | 62 | # Function to obtain list of HeaderContent objects 63 | def segment_book(html_location): 64 | url = html_location 65 | with open(url, 'r') as f: 66 | html = f.read() 67 | 68 | soup = strip_tags(html, ['b', 'i', 'u']) 69 | 70 | book = list() 71 | prev_header = False 72 | curr_obj = HeaderContent(header=list(), content=list()) 73 | 74 | for x in soup.find_all('span', {'class': 'pagenum'}): 75 | x.decompose() 76 | for x in soup.find_all('span', {'class': 'returnTOC'}): 77 | x.decompose() 78 | for x in soup.find_all(attrs={'class': 'figcenter'}): 79 | x.decompose() 80 | for x in soup.find_all(attrs={'class': 'caption'}): 81 | x.decompose() 82 | for x in soup.find_all(attrs={'class': 'totoc'}): 83 | x.decompose() 84 | for x in soup.find_all(['pre', 'img', 'style']): 85 | x.decompose() 86 | for x in soup.find_all(['h1', 'h2', 'h3', 'h4', 'h5', 'h6']): 87 | for elem in x.find_all(): 88 | if elem.name == 'i': 89 | elem.decompose() 90 | 91 | skip_count = 0 92 | for x in soup.find_all(): 93 | if x.name == 'html': 94 | continue 95 | 96 | if skip_count > 0: 97 | skip_count -= 1 98 | continue 99 | 100 | if x.name == 'hr' and x.has_attr('class') and 'pb' in x['class']: 101 | book.append(curr_obj) 102 | 
curr_obj = HeaderContent(header=list(), content=list()) 103 | 104 | if x.name in ['h1', 'h2', 'h3', 'h4', 'h5', 'h6']: 105 | if prev_header: 106 | curr_obj.add_header(x.text) 107 | else: 108 | book.append(curr_obj) 109 | curr_obj = HeaderContent(header=list(), content=list()) 110 | curr_obj.add_header(x.text) 111 | prev_header = True 112 | skip_count = len(x.find_all()) 113 | else: 114 | t = ''.join([elem for elem in x.find_all(text=True, recursive=False) if not isinstance(elem, Comment)]).strip() 115 | 116 | if t: 117 | if 'start of the project gutenberg' in ' '.join(t.lower().split()): 118 | book = list() 119 | curr_obj = HeaderContent(header=list(), content=list()) 120 | continue 121 | if 'end of the project gutenberg' in ' '.join(t.lower().split()): 122 | break 123 | if 'xml' in t and 'version' in t and 'encoding' in t: 124 | continue 125 | curr_obj.add_paragraph(t) 126 | prev_header = False 127 | book.append(curr_obj) 128 | return book 129 | 130 | 131 | def process_chapter(header_contents): 132 | retval = list() 133 | 134 | l = list() 135 | for elem in header_contents: 136 | headers = elem.header 137 | if len(headers) > 0 and headers[0].strip().lower().startswith('chapter'): 138 | l.append(len(headers)) 139 | max_count = max(set(l), key=l.count) 140 | 141 | for elem in header_contents: 142 | headers = elem.header 143 | found = False 144 | for idx in range(len(headers)): 145 | h = headers[idx] 146 | if idx < len(headers) - max_count: 147 | # Append as content 'C' 148 | retval.append(('C', h)) 149 | elif found: 150 | # Append words with label 'H' 151 | retval.append(('H', h)) 152 | elif h.strip().lower().startswith('chapter'): 153 | # Append this and all subsequent headers with label 'H' 154 | found = True 155 | retval.append(('H', h)) 156 | else: 157 | retval.append(('C', h)) 158 | contents = elem.content 159 | for c in contents: 160 | retval.append(('C', c)) 161 | 162 | return retval 163 | 164 | def process_part(header_contents): 165 | retval = list() 166 | 167 | l = list() 168 | for elem in header_contents: 169 | headers = elem.header 170 | if len(headers) > 0 and headers[0].strip().lower().startswith('part'): 171 | l.append(len(headers)) 172 | max_count = max(set(l), key=l.count) 173 | 174 | for elem in header_contents: 175 | headers = elem.header 176 | found = False 177 | for idx in range(len(headers)): 178 | h = headers[idx] 179 | if idx < len(headers) - max_count: 180 | # Append as content 'C' 181 | retval.append(('C', h)) 182 | elif found: 183 | # Append words with label 'H' 184 | retval.append(('H', h)) 185 | elif h.strip().lower().startswith('part'): 186 | # Append this and all subsequent headers with label 'H' 187 | found = True 188 | retval.append(('H', h)) 189 | else: 190 | retval.append(('C', h)) 191 | contents = elem.content 192 | for c in contents: 193 | retval.append(('C', c)) 194 | 195 | return retval 196 | 197 | def process_roman(header_contents): 198 | retval = list() 199 | 200 | pattern = re.compile("^(?=[MDCLXVI])M*(C[MD]|D?C{0,3})(X[CL]|L?X{0,3})(I[XV]|V?I{0,3})$") 201 | 202 | l = list() 203 | for elem in header_contents: 204 | headers = elem.header 205 | if len(headers) > 0: 206 | proc = re.sub(r'([^\s\w]|_)+', ' ', headers[0].strip()).split() 207 | if len(proc) > 0 and pattern.match(proc[0]): 208 | l.append(len(headers)) 209 | max_count = max(set(l), key=l.count) 210 | 211 | for elem in header_contents: 212 | headers = elem.header 213 | found = False 214 | for idx in range(len(headers)): 215 | h = headers[idx] 216 | proc = re.sub(r'([^\s\w]|_)+', ' ', 
h.strip()).split() 217 | if idx < len(headers) - max_count: 218 | # Append as content 'C' 219 | retval.append(('C', h)) 220 | elif found: 221 | # Append words with label 'H' 222 | retval.append(('H', h)) 223 | elif len(proc) > 0 and pattern.match(proc[0]): 224 | # Append this and all subsequent headers with label 'H' 225 | found = True 226 | retval.append(('H', h)) 227 | else: 228 | retval.append(('C', h)) 229 | contents = elem.content 230 | for c in contents: 231 | retval.append(('C', c)) 232 | 233 | return retval 234 | 235 | def process_number(header_contents): 236 | retval = list() 237 | 238 | l = list() 239 | for elem in header_contents: 240 | headers = elem.header 241 | if len(headers) > 0 and headers[0].strip().lower().replace('.', ' ').split()[0].isnumeric(): 242 | l.append(len(headers)) 243 | max_count = max(set(l), key=l.count) 244 | 245 | for elem in header_contents: 246 | headers = elem.header 247 | found = False 248 | for idx in range(len(headers)): 249 | h = headers[idx] 250 | if idx < len(headers) - max_count: 251 | # Append as content 'C' 252 | retval.append(('C', h)) 253 | elif found: 254 | # Append words with label 'H' 255 | retval.append(('H', h)) 256 | elif h.strip().lower().replace('.', ' ').split()[0].isnumeric(): 257 | # Append this and all subsequent headers with label 'H' 258 | found = True 259 | retval.append(('H', h)) 260 | else: 261 | retval.append(('C', h)) 262 | contents = elem.content 263 | for c in contents: 264 | retval.append(('C', c)) 265 | return retval 266 | 267 | def process_base_case(header_contents): 268 | retval = list() 269 | 270 | if len(header_contents) == 0: 271 | return retval 272 | 273 | if len(header_contents) == 1: 274 | # return text and labels directly 275 | headers = header_contents[0].header 276 | header_text = ''.join([h for h in headers]) 277 | 278 | contents = header_contents[0].content 279 | content_text = ''.join([c for c in contents]) 280 | 281 | return [('H', header_text), ('C', content_text)] 282 | 283 | word_nums = [elem.get_num_words() for elem in header_contents] 284 | agg = AgglomerativeClustering(n_clusters=2, linkage='average').fit([[x] for x in word_nums]) 285 | cluster_word_arrs = dict() 286 | cluster_word_arrs[0] = list() 287 | cluster_word_arrs[1] = list() 288 | for idx in range(len(word_nums)): 289 | label = agg.labels_[idx] 290 | cluster_word_arrs[label].append(word_nums[idx]) 291 | mean_0 = sum(cluster_word_arrs[0]) / len(cluster_word_arrs[0]) 292 | mean_1 = sum(cluster_word_arrs[1]) / len(cluster_word_arrs[1]) 293 | if mean_0 > mean_1: 294 | greater_cluster = 0 295 | else: 296 | greater_cluster = 1 297 | 298 | labels_agg = list(agg.labels_) 299 | 300 | first_occ = labels_agg.index(greater_cluster) 301 | last_occ = len(labels_agg) - 1 - labels_agg[::-1].index(greater_cluster) 302 | 303 | # Count number of headers in each chapter heading, take mode of that number 304 | l = [len(x.header) for x in header_contents] 305 | max_count = max(set(l), key=l.count) 306 | 307 | for idx in range(len(header_contents)): 308 | headers = header_contents[idx].header 309 | if idx >= first_occ and idx <= last_occ: 310 | # Add text with header tag 311 | start_header_index = len(headers) - max_count 312 | for index in range(len(headers)): 313 | h = headers[index] 314 | if index < start_header_index: 315 | ans_label = 'C' 316 | else: 317 | ans_label = 'H' 318 | retval.append((ans_label, h)) 319 | else: 320 | # Add text with content tag 321 | for h in headers: 322 | retval.append(('C', h)) 323 | 324 | contents = header_contents[idx].content 
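# Paragraphs grouped under this heading block are always written out with the content label 'C'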
325 | for c in contents: 326 | retval.append(('C', c)) 327 | 328 | return retval 329 | 330 | def process_header_contents(header_contents): 331 | headers = [' '.join([elem.strip().replace('\n', ' ') for elem in x.header]) for x in header_contents] 332 | 333 | chapter = 0 334 | for w in headers: 335 | if w.lower().startswith('chapter'): 336 | chapter += 1 337 | 338 | part = 0 339 | for w in headers: 340 | if w.lower().startswith('part'): 341 | part += 1 342 | 343 | pattern = re.compile("^(?=[MDCLXVI])M*(C[MD]|D?C{0,3})(X[CL]|L?X{0,3})(I[XV]|V?I{0,3})$") 344 | roman = 0 345 | for w in headers: 346 | w = re.sub(r'([^\s\w]|_)+', ' ', w) 347 | spl = w.strip().split() 348 | if len(spl) > 0 and pattern.match(spl[0]): 349 | roman += 1 350 | 351 | number = 0 352 | for w in headers: 353 | spl = w.lower().replace('.', ' ').split() 354 | if len(spl) > 0 and spl[0].isnumeric(): 355 | number += 1 356 | 357 | d = dict() 358 | d['chapter'] = chapter 359 | d['part'] = part 360 | d['roman'] = roman 361 | d['number'] = number 362 | 363 | descending = sorted(d.items(), key=lambda kv: -kv[1]) 364 | 365 | if descending[0][1] == 0: 366 | # Nothing worked 367 | pass 368 | 369 | elif descending[0][0] == 'chapter': 370 | return process_chapter(header_contents) 371 | 372 | elif descending[0][0] == 'part': 373 | return process_part(header_contents) 374 | 375 | elif descending[0][0] == 'roman': 376 | if descending[0][1] < 2: 377 | # failed, detected 'I' pronoun as number 378 | pass 379 | else: 380 | return process_roman(header_contents) 381 | 382 | elif descending[0][0] == 'number': 383 | if descending[0][1] < 2: 384 | # failed, false positive 385 | pass 386 | else: 387 | return process_number(header_contents) 388 | 389 | return process_base_case(header_contents) 390 | 391 | 392 | def process_book(html_dir, extracted_header_dir, book_id): 393 | input_location = os.path.join(html_dir, book_id + '.html') 394 | output_location = os.path.join(extracted_header_dir, book_id + '.csv') 395 | if os.path.exists(output_location): 396 | return book_id, 0 397 | try: 398 | header_content_pairs = segment_book(input_location) 399 | hc = process_header_contents(header_content_pairs) 400 | df = pd.DataFrame(hc, columns=['label', 'text']) 401 | df['key'] = (df['label'] != df['label'].shift(1)).astype(int).cumsum() 402 | df2 = pd.DataFrame(df.groupby(['key', 'label'])['text'].apply('\n\n'.join)) 403 | df2 = df2.reset_index()[['label', 'text']] 404 | df2.to_csv(output_location, index=False) 405 | return book_id, 0 406 | except: 407 | return book_id, -1 408 | 409 | 410 | if __name__ == '__main__': 411 | 412 | parser = argparse.ArgumentParser() 413 | parser.add_argument('--config_file', help='Configuration file', required=True) 414 | args = parser.parse_args() 415 | config_file = args.config_file 416 | 417 | config = configparser.ConfigParser() 418 | config.read_file(open(config_file)) 419 | 420 | # Read list of books to process 421 | book_list = config.get('01_Extract_headers_from_HTML', 'book_list') 422 | if not os.path.isfile(book_list): 423 | print('Please provide a valid file name for the list of book IDs in the "book_list" field.') 424 | exit() 425 | with open(book_list, 'r') as f: 426 | books = f.read().splitlines() 427 | 428 | # Read location of HTML files 429 | html_dir = config.get('01_Extract_headers_from_HTML', 'html_dir') 430 | if not os.path.isdir(html_dir): 431 | print('Please provide a valid directory name where the HTMLs are located, in the "html_dir" field.') 432 | exit() 433 | 434 | # Read location to store
extracted headers 435 | extracted_header_dir = config.get('01_Extract_headers_from_HTML', 'extracted_header_dir') 436 | if not os.path.isdir(extracted_header_dir): 437 | os.makedirs(extracted_header_dir) 438 | 439 | # Read number of processes to use 440 | num_procs = int(config.get('01_Extract_headers_from_HTML', 'num_procs')) 441 | 442 | # Read location to store status of header extraction 443 | log_file = config.get('01_Extract_headers_from_HTML', 'log_file') 444 | 445 | func = partial(process_book, html_dir, extracted_header_dir) 446 | 447 | pool = multiprocessing.Pool(processes=num_procs) 448 | data = pool.map(func, books) 449 | pool.close() 450 | pool.join() 451 | 452 | print('Done! Saving status results to log file...') 453 | 454 | df = pd.DataFrame(data, columns=['bookID', 'status']) 455 | df.to_csv(log_file, index=False) 456 | 457 | print('Saved results to log file!') 458 | 459 | -------------------------------------------------------------------------------- /header-annotation/02_generate_training_sequences.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import configparser 3 | import os 4 | import nltk 5 | import gzip 6 | import pandas as pd 7 | from pytorch_pretrained_bert import BertTokenizer 8 | import random 9 | import multiprocessing 10 | import nltk 11 | import pickle 12 | 13 | import multiprocessing 14 | from functools import partial 15 | 16 | 17 | # Function to remove Gutenberg-specific header and footer 18 | def remove_gutenberg_header_footer(lines): 19 | start_arr = [idx for idx in range(len(lines)) if '***' in lines[idx] and 'START' in lines[idx].upper() and 'GUTENBERG' in lines[idx].upper()] 20 | end_arr = [idx for idx in range(len(lines)) if '***' in lines[idx] and 'END' in lines[idx].upper() and 'GUTENBERG' in lines[idx].upper()] 21 | 22 | if len(start_arr) > 0 and len(end_arr) > 0: 23 | return lines[start_arr[0] + 1 : end_arr[0]] 24 | elif len(start_arr) > 0: 25 | return lines[start_arr[0] + 1:] 26 | elif len(end_arr) > 0: 27 | return lines[:end_arr[0]] 28 | return lines 29 | 30 | # Function to obtain line numbers for annotated headers from text file 31 | def get_annotated_headers(book_index, text_file_dir, header_content_dir): 32 | def find_index(heading, book_lines): 33 | ans = list() 34 | idx = 0 35 | partial_match = 0 36 | curr = '' 37 | while idx < len(book_lines): 38 | if len(book_lines[idx]) == 0: 39 | if partial_match > 0: 40 | partial_match += 1 41 | idx += 1 42 | continue 43 | 44 | if partial_match > 0: 45 | curr += ' ' + book_lines[idx] 46 | curr = ' '.join(nltk.word_tokenize(curr)) 47 | if curr == heading: 48 | ans.append((start_idx, partial_match + 1)) 49 | partial_match = 0 50 | curr = '' 51 | elif heading.startswith(curr): 52 | partial_match += 1 53 | else: 54 | partial_match = 0 55 | curr = '' 56 | 57 | elif book_lines[idx] == heading: 58 | ans.append((idx, 1)) 59 | 60 | elif heading.startswith(book_lines[idx]): 61 | curr += book_lines[idx] 62 | partial_match = 1 63 | start_idx = idx 64 | 65 | idx += 1 66 | return ans 67 | 68 | lines = list() 69 | with gzip.open(text_file_dir + book_index + '.txt.gz', 'rt') as f: 70 | for line in f: 71 | lines.append(line) 72 | 73 | lines = remove_gutenberg_header_footer(lines) 74 | mod = [' '.join(nltk.word_tokenize((str(x)).strip())) for x in lines] 75 | 76 | df = pd.read_csv(header_content_dir + book_index + '.csv') 77 | if 'token' in df.columns: 78 | df['text'] = df['token'] 79 | hc = list() 80 | for idx, row in df.iterrows(): 81 | if row['label'] == 
'H': 82 | header = row['text'] 83 | if idx + 1 < len(df): 84 | content = df['text'][idx + 1] 85 | else: 86 | content = ' ' 87 | hc.append((' '.join(nltk.word_tokenize(str(header).strip())), nltk.word_tokenize(str(content).strip())[0])) 88 | header_list = list() 89 | for header, content in hc: 90 | l = find_index(header, mod) 91 | for index, length in l: 92 | content_index = index + length 93 | while len(mod[content_index]) == 0: 94 | content_index += 1 95 | if mod[content_index].startswith(content): 96 | header_list.append((index, length)) 97 | else: 98 | pass 99 | return lines, header_list 100 | 101 | # Function to generate training sequences from a book 102 | def get_sequences_whitespace(lines, headers, seq_len, tokenizer): 103 | token_sequences = list() 104 | label_sequences = list() 105 | 106 | for index, length in headers: 107 | text = ' '.join(lines[index:index + length]) 108 | tokens = tokenizer.tokenize(text.replace('\n', ' [unused1] ')) 109 | 110 | need_tokens = seq_len - len(tokens) 111 | 112 | prev_tokens_needed = random.randint(0, need_tokens) 113 | next_tokens_needed = need_tokens - prev_tokens_needed 114 | 115 | # Generate previous tokens 116 | prev_tokens = list() 117 | idx = index - 1 118 | max_count = seq_len * 2 119 | while True: 120 | if idx < 0 or len(prev_tokens) >= prev_tokens_needed or max_count == 0: 121 | prev_tokens = prev_tokens[-prev_tokens_needed:] 122 | break 123 | try: 124 | prev_tokens = tokenizer.tokenize(lines[idx].replace('\n', ' [unused1] ')) + prev_tokens 125 | except: 126 | pass 127 | idx -= 1 128 | max_count -= 1 129 | 130 | # Generate next tokens 131 | next_tokens = list() 132 | idx = index + length 133 | max_count = seq_len * 2 134 | while True: 135 | if idx >= len(lines) or len(next_tokens) >= next_tokens_needed or max_count == 0: 136 | next_tokens = next_tokens[:next_tokens_needed] 137 | break 138 | try: 139 | next_tokens = next_tokens + tokenizer.tokenize(lines[idx].replace('\n', ' [unused1] ')) 140 | except: 141 | pass 142 | idx += 1 143 | max_count -= 1 144 | 145 | ts = prev_tokens + tokens + next_tokens 146 | ls = [0] * len(prev_tokens) + [1] * len(tokens) + [0] * len(next_tokens) 147 | 148 | if len(ts) == seq_len: 149 | token_sequences.append(tokenizer.convert_tokens_to_ids(ts)) 150 | label_sequences.append(ls) 151 | 152 | return token_sequences, label_sequences 153 | 154 | 155 | def process_book(text_files_dir, header_content_dir, seq_gen_dir, seq_len, tokenizer, book_index): 156 | try: 157 | lines, headers = get_annotated_headers(book_index, text_files_dir, header_content_dir) 158 | token_sequences, label_sequences = get_sequences_whitespace(lines, headers, seq_len, tokenizer) 159 | 160 | with open(os.path.join(seq_gen_dir, book_index + '_tokens.pkl'), 'wb') as f: 161 | pickle.dump(token_sequences, f) 162 | with open(os.path.join(seq_gen_dir, book_index + '_labels.pkl'), 'wb') as f: 163 | pickle.dump(label_sequences, f) 164 | return book_index, 0 165 | 166 | except Exception as e: 167 | print(e) 168 | return book_index, -1 169 | 170 | 171 | if __name__ == "__main__": 172 | parser = argparse.ArgumentParser() 173 | parser.add_argument('--config_file', help='Configuration file', required=True) 174 | args = parser.parse_args() 175 | config_file = args.config_file 176 | 177 | config = configparser.ConfigParser() 178 | config.read_file(open(config_file)) 179 | 180 | # Read the list of book IDs in the training set 181 | train_set_books_file = config.get('02_Generate_training_sequences', 'train_books_list') 182 | if not 
os.path.isfile(train_set_books_file): 183 | print('Please provide a valid file name for the list of training set book IDs in the "train_books_list" field.') 184 | exit() 185 | with open(train_set_books_file) as f: 186 | train_book_ids = [x.strip() for x in f.readlines()] 187 | 188 | 189 | # Read the directory which contains txt.gz files 190 | text_files_dir = config.get('02_Generate_training_sequences', 'text_files_dir') 191 | if not os.path.isdir(text_files_dir): 192 | print('Please provide a valid directory name where the txt.gz files for the training set books are stored, in the "text_files_dir" field.') 193 | exit() 194 | 195 | # Read the directory where the extracted headers are stored 196 | header_content_dir = config.get('01_Extract_headers_from_HTML', 'extracted_header_dir') 197 | if not os.path.isdir(header_content_dir): 198 | print('Please run 01_extract_headers_from_html.py first.') 199 | exit() 200 | 201 | # Read the directory where extracted training sequences are to be stored 202 | seq_gen_dir = config.get('02_Generate_training_sequences', 'generated_sequence_dir') 203 | if not os.path.isdir(seq_gen_dir): 204 | os.makedirs(seq_gen_dir) 205 | 206 | # Read the sequence length to be generated 207 | seq_len = int(config.get('02_Generate_training_sequences', 'seq_len')) 208 | 209 | # Read number of processes to use 210 | num_procs = int(config.get('02_Generate_training_sequences', 'num_procs')) 211 | 212 | # Read location to store status of sequence generation 213 | log_file = config.get('02_Generate_training_sequences', 'log_file') 214 | 215 | # Define tokenizer 216 | tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False, never_split=['[unused1]']) 217 | 218 | 219 | func = partial(process_book, text_files_dir, header_content_dir, seq_gen_dir, seq_len, tokenizer) 220 | 221 | pool = multiprocessing.Pool(processes=num_procs) 222 | data = pool.map(func, train_book_ids) 223 | pool.close() 224 | pool.join() 225 | 226 | print('Done!
Saving status results to log file...') 227 | 228 | df = pd.DataFrame(data, columns=['bookID', 'status']) 229 | df.to_csv(log_file, index=False) 230 | 231 | print('Saved results to log file!') -------------------------------------------------------------------------------- /header-annotation/03_train_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import configparser 3 | import os 4 | import tensorflow as tf 5 | import torch 6 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler 7 | from sklearn.model_selection import train_test_split 8 | from pytorch_pretrained_bert import BertTokenizer, BertConfig 9 | from pytorch_pretrained_bert import BertAdam, BertForTokenClassification 10 | from tqdm import tqdm, trange 11 | import numpy as np 12 | import pickle 13 | 14 | 15 | # Accuracy metrics function 16 | def flat_accuracy(seq_pred, seq_labels): 17 | m = seq_pred.argmax(axis=2) 18 | m2 = m.flatten() 19 | m2 = m2.detach().cpu().numpy() 20 | l2 = seq_labels.flatten() 21 | l2 = l2.to('cpu').numpy() 22 | tp, tn, fp, fn = 0, 0, 0, 0 23 | for idx in range(len(m2)): 24 | if l2[idx] == 1 and m2[idx] == 1: 25 | tp += 1 26 | elif l2[idx] == 1 and m2[idx] == 0: 27 | fn += 1 28 | elif l2[idx] == 0 and m2[idx] == 1: 29 | fp += 1 30 | elif l2[idx] == 0 and m2[idx] == 0: 31 | tn += 1 32 | return np.sum(m2 == l2) / len(l2), tp, tn, fp, fn 33 | 34 | 35 | if __name__ == "__main__": 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('--config_file', help='Configuration file', required=True) 38 | args = parser.parse_args() 39 | config_file = args.config_file 40 | 41 | config = configparser.ConfigParser() 42 | config.read_file(open(config_file)) 43 | 44 | # Read location to store model checkpoints 45 | checkpoint_dir = config.get('03_Train_model', 'checkpoint_dir') 46 | if not os.path.isdir(checkpoint_dir): 47 | os.makedirs(checkpoint_dir) 48 | 49 | # Read set of training books 50 | # Read the list of book IDs in the training set 51 | train_set_books_file = config.get('02_Generate_training_sequences', 'train_books_list') 52 | if not os.path.isfile(train_set_books_file): 53 | print('Please provide a valid file name for the list of training set book IDs in the "train_books_list" field.') 54 | exit() 55 | with open(train_set_books_file) as f: 56 | train_book_ids = [x.strip() for x in f.readlines()] 57 | 58 | # Read the number of epochs to train for 59 | num_epochs = int(config.get('03_Train_model', 'num_epochs')) 60 | 61 | 62 | # Read the directory location where generated sequences are stored 63 | seq_gen_dir = config.get('02_Generate_training_sequences', 'generated_sequence_dir') 64 | if not os.path.isdir(seq_gen_dir): 65 | print('Please run 02_generate_training_sequences.py first.') 66 | exit() 67 | 68 | 69 | # Read token and label sequences 70 | token_list = list() 71 | label_list = list() 72 | 73 | for book_index in train_book_ids: 74 | try: 75 | with open(os.path.join(seq_gen_dir, book_index + '_tokens.pkl'), 'rb') as f: 76 | t_list = pickle.load(f) 77 | with open(os.path.join(seq_gen_dir, book_index + '_labels.pkl'), 'rb') as f: 78 | l_list = pickle.load(f) 79 | token_list += t_list 80 | label_list += l_list 81 | except Exception as e: 82 | print('Could not fetch sequences for: ' + book_index) 83 | print(e) 84 | continue 85 | 86 | print(len(token_list), " sequences in training data") 87 | 88 | 89 | # Train-validation split for loaded data 90 | print("Train-validation split") 91 | train_inputs, 
validation_inputs, train_labels, validation_labels = train_test_split(token_list, label_list, random_state=2019, test_size=0.1) 92 | 93 | 94 | # Converting data to tensor form 95 | print("Initializing tensors") 96 | train_inputs = torch.tensor(train_inputs) 97 | validation_inputs = torch.tensor(validation_inputs) 98 | train_labels = torch.tensor(train_labels) 99 | validation_labels = torch.tensor(validation_labels) 100 | 101 | # Batch size 102 | batch_size = 32 103 | 104 | 105 | # Creating objects to use in training 106 | train_data = TensorDataset(train_inputs, train_labels) 107 | train_sampler = RandomSampler(train_data) 108 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size) 109 | 110 | validation_data = TensorDataset(validation_inputs, validation_labels) 111 | validation_sampler = SequentialSampler(validation_data) 112 | validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=batch_size) 113 | 114 | 115 | # GPU / CPU initialization 116 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 117 | n_gpu = torch.cuda.device_count() 118 | torch.cuda.get_device_name(0) 119 | 120 | 121 | 122 | # Model to fine-tune 123 | print("Initializing model") 124 | model = BertForTokenClassification.from_pretrained('bert-base-cased', num_labels=2) 125 | 126 | 127 | 128 | # Parameters 129 | param_optimizer = list(model.named_parameters()) 130 | no_decay = ['bias', 'gamma', 'beta'] 131 | optimizer_grouped_parameters = [ 132 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 133 | 'weight_decay_rate': 0.01}, 134 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 135 | 'weight_decay_rate': 0.0} 136 | ] 137 | 138 | 139 | 140 | # Optimizer 141 | optimizer = BertAdam(optimizer_grouped_parameters, lr=2e-5, warmup=.1) 142 | 143 | print("Making model GPU compatible") 144 | model = model.to(device) 145 | 146 | 147 | # Training loop 148 | train_loss_set = [] 149 | epochs = num_epochs 150 | 151 | count = 1 152 | 153 | 154 | for _ in trange(epochs, desc='Epoch'): 155 | print("Epoch " + str(count)) 156 | # Set to training mode 157 | model.train() 158 | 159 | tr_loss = 0 160 | nb_tr_examples, nb_tr_steps = 0, 0 161 | 162 | # Process each batch 163 | for step, batch in enumerate(train_dataloader): 164 | # Convert batch to GPU 165 | batch = tuple(t.to(device) for t in batch) 166 | # Unpack 167 | b_input_ids, b_labels = batch 168 | # Reset gradients to 0 169 | optimizer.zero_grad() 170 | # Compute loss 171 | loss = model(b_input_ids, labels=b_labels) 172 | # Append loss to list 173 | train_loss_set.append(loss.item()) 174 | # Back-prop 175 | loss.backward() 176 | # Update optimizer params 177 | optimizer.step() 178 | # Add to total training loss 179 | tr_loss += loss.item() 180 | # Add to number of training examples 181 | nb_tr_examples += b_input_ids.size(0) 182 | # Add to number of training steps 183 | nb_tr_steps += 1 184 | print("Train loss: {}".format(tr_loss/nb_tr_steps)) 185 | 186 | 187 | # Set to eval mode 188 | model.eval() 189 | 190 | eval_loss, eval_accuracy, tp, tn, fp, fn = 0, 0, 0, 0, 0, 0 191 | nb_eval_steps, nb_eval_examples = 0, 0 192 | 193 | # Process each batch 194 | for batch in validation_dataloader: 195 | # Convert batch to GPU 196 | batch = tuple(t.to(device) for t in batch) 197 | # Unpack 198 | b_input_ids, b_labels = batch 199 | # Do not update gradients 200 | with torch.no_grad(): 201 | # Get logits 202 | logits = model(b_input_ids, 
token_type_ids=None) 203 | # Compute metrics 204 | tmp_eval_accuracy, tmp_tp, tmp_tn, tmp_fp, tmp_fn = flat_accuracy(logits, b_labels) 205 | eval_accuracy += tmp_eval_accuracy 206 | tp += tmp_tp 207 | tn += tmp_tn 208 | fp += tmp_fp 209 | fn += tmp_fn 210 | nb_eval_steps += 1 211 | 212 | 213 | print('Validation Accuracy: {}'.format(eval_accuracy/nb_eval_steps)) 214 | print("TP = ", tp) 215 | print("TN = ", tn) 216 | print("FP = ", fp) 217 | print("FN = ", fn) 218 | print("----------------") 219 | 220 | # Save model 221 | torch.save({ 222 | 'epoch': count, 223 | 'model_state_dict': model.state_dict(), 224 | 'optimizer_state_dict': optimizer.state_dict(), 225 | 'loss': loss, 226 | }, os.path.join(checkpoint_dir, 'epoch_' + str(count) + '.pt')) 227 | 228 | count += 1 229 | 230 | print('Done!') 231 | 232 | -------------------------------------------------------------------------------- /header-annotation/04_generate_probabilities.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import configparser 3 | import os 4 | import lxml.etree as ET 5 | import pandas as pd 6 | from pytorch_pretrained_bert import BertTokenizer 7 | import torch 8 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler 9 | from pytorch_pretrained_bert import BertForTokenClassification 10 | import numpy as np 11 | import tensorflow as tf 12 | import pickle 13 | 14 | def is_whitespace(c): 15 | if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: 16 | return True 17 | return False 18 | 19 | def chunks(l, n): 20 | # For item i in a range that is a length of l, 21 | for i in range(0, len(l), n): 22 | # Create an index range for l of n items: 23 | yield l[i:i + n] 24 | 25 | 26 | def get_scores(tokens, model, device, tokenizer, sequence_length, slide_by): 27 | 28 | # Convert tokens to IDs 29 | chunk_list = list(chunks(tokens, 512)) 30 | toks = list() 31 | for c in chunk_list: 32 | toks += tokenizer.convert_tokens_to_ids(c) 33 | 34 | #toks = tokenizer.convert_tokens_to_ids(tokens) 35 | 36 | # Generate test sequences using sliding window 37 | test_sequences = list() 38 | test_labels_dummy = list() 39 | test_token_indices = list() 40 | 41 | idx = 0 42 | end_flag = False 43 | 44 | while idx < len(toks): 45 | if not end_flag and idx + sequence_length >= len(toks): 46 | idx = len(toks) - sequence_length 47 | end_flag = True 48 | # Get window 49 | s = toks[idx:idx + sequence_length] 50 | test_sequences.append(s) 51 | test_labels_dummy.append([0 for _ in s]) 52 | test_token_indices.append([elem for elem in range(idx, idx + sequence_length)]) 53 | idx += slide_by 54 | if end_flag: 55 | break 56 | 57 | # Get predictions for test sequences 58 | batch_size = 32 59 | prediction_inputs = torch.tensor(test_sequences) 60 | prediction_labels_dummy = torch.tensor(test_labels_dummy) 61 | 62 | prediction_data = TensorDataset(prediction_inputs, prediction_labels_dummy) 63 | prediction_sampler = SequentialSampler(prediction_data) 64 | prediction_dataloader = DataLoader(prediction_data, sampler=prediction_sampler, batch_size=batch_size, num_workers=1) 65 | 66 | model.eval() 67 | 68 | predictions = list() 69 | 70 | for batch in prediction_dataloader: 71 | batch = tuple(t.to(device) for t in batch) 72 | b_input_ids, b_labels_dummy = batch 73 | with torch.no_grad(): 74 | # Forward pass, calculate logit predictions 75 | logits = model(b_input_ids, token_type_ids=None) 76 | logits = logits.detach().cpu().numpy() 77 | predictions.append(logits) 78 
| 79 | # Flatten output 80 | flat_preds = list() 81 | for batch in predictions: 82 | # batch is 8 x 100 x 2 83 | for sequence in batch: 84 | # sequence is 100 x 2 85 | for probs in sequence: 86 | # probs is 1 x 2 87 | flat_preds.append(probs) 88 | 89 | flat_probabilities = list() 90 | for x in flat_preds: 91 | tmp0 = np.exp(x[0]) 92 | tmp1 = np.exp(x[1]) 93 | summ = tmp0 + tmp1 94 | flat_probabilities.append(tmp1 / summ) 95 | 96 | flat_token_indices = [item for sublist in test_token_indices for item in sublist] 97 | 98 | d_probs = dict() 99 | 100 | for iterator in range(len(flat_token_indices)): 101 | index = flat_token_indices[iterator] 102 | 103 | if index in d_probs: 104 | d_probs[index].append(flat_probabilities[iterator]) 105 | else: 106 | d_probs[index] = [flat_probabilities[iterator]] 107 | 108 | new_probs = [0] * (max(d_probs.keys()) + 1) 109 | for idx in d_probs.keys(): 110 | new_probs[idx] = max(d_probs[idx]) 111 | 112 | return new_probs 113 | 114 | 115 | def generate_header_probabilities_from_model(input_dir, output_dir, seq_len, model, device, tokenizer, book_id): 116 | try: 117 | base_path = os.path.join(input_dir, book_id + '.xml') 118 | header_csv_path = os.path.join(output_dir, book_id + '.csv') 119 | output_xml_path = os.path.join(output_dir, book_id + '_lines.xml') 120 | 121 | # Read from input XML 122 | parser = ET.XMLParser(huge_tree=True) 123 | tree = ET.parse(str(base_path), parser=parser) 124 | book = tree.getroot() 125 | 126 | # Get content from front matter and body (if present) 127 | front_matter = book.find('front') 128 | body = book.find('body') 129 | 130 | assert body is not None 131 | 132 | # Convert to lines format XML 133 | line_number = 0 134 | if front_matter is not None: 135 | if front_matter.text is not None: 136 | front_matter_lines = front_matter.text.splitlines(keepends=True) 137 | front_matter.text = None 138 | previous = None 139 | for elem in front_matter_lines: 140 | e = ET.Element("line") 141 | e.text = elem 142 | e.set("num", str(line_number)) 143 | if previous is None: 144 | front_matter.insert(0, e) 145 | else: 146 | previous.addnext(e) 147 | previous = e 148 | line_number += 1 149 | for child in front_matter.getchildren(): 150 | if child.tail is not None: 151 | child_lines = child.tail.splitlines(keepends=True) 152 | child.tail = None 153 | previous = child 154 | for elem in child_lines: 155 | e = ET.Element("line") 156 | e.text = elem 157 | e.set("num", str(line_number)) 158 | previous.addnext(e) 159 | previous = e 160 | line_number += 1 161 | 162 | if body.text is not None: 163 | body_lines = body.text.splitlines(keepends=True) 164 | body.text = None 165 | previous = None 166 | for elem in body_lines: 167 | e = ET.Element("line") 168 | e.text = elem 169 | e.set("num", str(line_number)) 170 | if previous is None: 171 | body.insert(0, e) 172 | else: 173 | previous.addnext(e) 174 | previous = e 175 | line_number += 1 176 | for child in body.getchildren(): 177 | if child.tail is not None: 178 | child_lines = child.tail.splitlines(keepends=True) 179 | child.tail = None 180 | previous = child 181 | for elem in child_lines: 182 | e = ET.Element("line") 183 | e.text = elem 184 | e.set("num", str(line_number)) 185 | previous.addnext(e) 186 | previous = e 187 | line_number += 1 188 | 189 | 190 | content = [x.text for x in book.findall(".//line")] 191 | 192 | 193 | # Generate probabilities per line 194 | 195 | # Convert to tokens 196 | matrix = list() 197 | for line_number, line in enumerate(content): 198 | text = line.replace('\n', ' [unused1] ') 199 
| doc_tokens = list() 200 | char_to_word_offset = list() 201 | prev_is_whitespace = True 202 | for c in text: 203 | if is_whitespace(c): 204 | prev_is_whitespace = True 205 | else: 206 | if prev_is_whitespace: 207 | doc_tokens.append(c) 208 | else: 209 | doc_tokens[-1] += c 210 | prev_is_whitespace = False 211 | char_to_word_offset.append(len(doc_tokens) - 1) 212 | tok_to_orig_index = list() 213 | all_doc_tokens = list() 214 | for i, token in enumerate(doc_tokens): 215 | sub_tokens = tokenizer.tokenize(token) 216 | for sub_token in sub_tokens: 217 | tok_to_orig_index.append(i) 218 | all_doc_tokens.append(sub_token) 219 | doc_tokens_to_line_index = [char_to_word_offset.index(x) for x in range(len(doc_tokens))] 220 | 221 | for final_token_idx, final_token in enumerate(all_doc_tokens): 222 | matrix.append([final_token, line_number, doc_tokens_to_line_index[tok_to_orig_index[final_token_idx]]]) 223 | 224 | df = pd.DataFrame(matrix) 225 | df.rename(columns={0:'token', 1:'line_number', 2:'token_word_pos'}, inplace=True) 226 | 227 | token_list = list(df['token'].apply(lambda x: str(x))) 228 | 229 | probs = get_scores(token_list, model, device, tokenizer, seq_len, seq_len // 2) 230 | 231 | df['prob'] = probs 232 | 233 | df = df.groupby(['line_number', 'token_word_pos'], as_index=False).agg({'token': (lambda x: ''.join([y[2:] if y.startswith('##') else y for y in x])), 'prob': 'mean'}) 234 | 235 | df = df[['token', 'line_number', 'token_word_pos', 'prob']] 236 | 237 | df.to_csv(header_csv_path, index=False) 238 | 239 | with open(output_xml_path, 'wb') as f: 240 | f.write(ET.tostring(book, pretty_print=True)) 241 | 242 | return book_id, 0 243 | 244 | except: 245 | return book_id, -1 246 | 247 | 248 | if __name__ == "__main__": 249 | parser = argparse.ArgumentParser() 250 | parser.add_argument('--config_file', help='Configuration file', required=True) 251 | args = parser.parse_args() 252 | config_file = args.config_file 253 | 254 | config = configparser.ConfigParser() 255 | config.read_file(open(config_file)) 256 | 257 | # Read set of test books 258 | # Read the list of book IDs in the test set 259 | test_set_books_file = config.get('04_Generate_test_probs', 'test_books_list') 260 | if not os.path.isfile(test_set_books_file): 261 | print('Please provide a valid file name for the list of test set book IDs in the "test_books_list" field.') 262 | exit() 263 | with open(test_set_books_file) as f: 264 | test_book_ids = [x.strip() for x in f.readlines()] 265 | 266 | # Read location where model checkpoints are stored 267 | checkpoint_dir = config.get('03_Train_model', 'checkpoint_dir') 268 | if not os.path.isdir(checkpoint_dir): 269 | print('Please run 03_train_model.py first.') 270 | exit() 271 | 272 | # Read number of epochs for which model was trained 273 | num_epochs = int(config.get('03_Train_model', 'num_epochs')) 274 | 275 | # Read test books location 276 | base_xml_files_dir = config.get('04_Generate_test_probs', 'base_xml_files_dir') 277 | if not os.path.isdir(base_xml_files_dir): 278 | print('Please provide a valid directory name where the xml files for the training set books are stored, in the "base_xml_files_dir" field.') 279 | exit() 280 | 281 | # Read location to store probability outputs 282 | prob_dir = config.get('04_Generate_test_probs', 'prob_dir') 283 | if not os.path.isdir(prob_dir): 284 | os.makedirs(prob_dir) 285 | 286 | # Read sequence length 287 | seq_len = int(config.get('02_Generate_training_sequences', 'seq_len')) 288 | 289 | # Read location to store status of header extraction 
290 | log_file = config.get('04_Generate_test_probs', 'log_file') 291 | 292 | 293 | # BERT model 294 | model = BertForTokenClassification.from_pretrained('bert-base-cased', num_labels=2) 295 | checkpoint = torch.load(os.path.join(checkpoint_dir, 'epoch_' + str(num_epochs) + '.pt')) 296 | model.load_state_dict(checkpoint['model_state_dict']) 297 | 298 | # Device 299 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 300 | model = model.to(device) 301 | 302 | # BERT tokenizer 303 | tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False, never_split=['[unused1]']) 304 | 305 | 306 | l = list() 307 | print(len(test_book_ids), 'books') 308 | for idx, book_id in enumerate(test_book_ids): 309 | print(idx, book_id) 310 | book, status = generate_header_probabilities_from_model(base_xml_files_dir, prob_dir, seq_len, model, device, tokenizer, book_id) 311 | l.append((book, status)) 312 | 313 | print('Done! Saving status results to log file...') 314 | 315 | df = pd.DataFrame(l, columns=['bookID', 'status']) 316 | 317 | df.to_csv(log_file, index=False) 318 | 319 | print('Saved results to log file!') 320 | 321 | 322 | 323 | -------------------------------------------------------------------------------- /header-annotation/05_annotate_headers.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import configparser 3 | import os 4 | import lxml.etree as ET 5 | import pandas as pd 6 | from pytorch_pretrained_bert import BertTokenizer 7 | import torch 8 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler 9 | from pytorch_pretrained_bert import BertForTokenClassification 10 | import numpy as np 11 | import tensorflow as tf 12 | import pickle 13 | 14 | from functools import partial 15 | import multiprocessing 16 | 17 | from header_util import get_header_lines 18 | from regex_util import get_rules 19 | 20 | 21 | import time 22 | import timeout_decorator 23 | 24 | @timeout_decorator.timeout(60) 25 | def annotate_headers(input_dir, output_xml_dir, output_pickle_dir, regex_rules_cache, book_id): 26 | input_xml_path = os.path.join(input_dir, book_id + '_lines.xml') 27 | header_csv_path = os.path.join(input_dir, book_id + '.csv') 28 | output_xml_path = os.path.join(output_xml_dir, book_id + '_headers.xml') 29 | 30 | if os.path.exists(output_xml_path): 31 | return book_id, 1 32 | 33 | print('started', book_id) 34 | 35 | # Read from input XML 36 | parser = ET.XMLParser(huge_tree=True) 37 | tree = ET.parse(str(input_xml_path), parser=parser) 38 | book = tree.getroot() 39 | 40 | body = book.find('body') 41 | body_start = int(body.find('.//line').attrib['num']) 42 | 43 | # Read token-wise predictions 44 | toks_df = pd.read_csv(header_csv_path) 45 | 46 | # Read regex rules from cache 47 | with open(regex_rules_cache, 'rb') as f: 48 | all_seqs, all_seqs_orig, rules, priority = pickle.load(f) 49 | 50 | output_pickle_prefix = os.path.join(output_pickle_dir, book_id) 51 | # Get list of header line numbers 52 | header_lines = get_header_lines(toks_df, all_seqs, all_seqs_orig, rules, priority, body_start, output_pickle_prefix, book) 53 | 54 | 55 | # First line: last line 56 | d = {x[0]:x[-1] for x, y in header_lines} 57 | # First line: attributes 58 | attrs = {x[0]:y for x, y in header_lines} 59 | 60 | # Enclose line numbers in header tags contained in attrs 61 | 62 | # Delete section tags 63 | ET.strip_tags(book, "section") 64 | 65 | for from_line, to_line in d.items(): 66 | desc, 
number, number_text, number_type, title, rule_text = attrs[from_line] 67 | f = book.find('.//line[@num="' + str(from_line) + '"]') 68 | 69 | new_element = ET.Element('header') 70 | new_element.set('desc', str(desc)) 71 | new_element.set('number', str(number)) 72 | new_element.set('number_text', str(number_text)) 73 | new_element.set('number_type', str(number_type)) 74 | new_element.set('title', str(title)) 75 | new_element.set('rule_text', str(rule_text).strip(',')) 76 | 77 | prev = f.getprevious() 78 | if prev is not None: 79 | for line_num in range(from_line, to_line + 1): 80 | e = book.find('.//line[@num="' + str(line_num) + '"]') 81 | new_element.append(e) 82 | prev.addnext(new_element) 83 | else: 84 | parent = f.getparent() 85 | for line_num in range(from_line, to_line + 1): 86 | e = book.find('.//line[@num="' + str(line_num) + '"]') 87 | new_element.append(e) 88 | parent.insert(0, new_element) 89 | 90 | 91 | ET.strip_tags(book, "line") 92 | 93 | # Write to file 94 | with open(output_xml_path, 'wb') as f: 95 | f.write(ET.tostring(book, pretty_print=True)) 96 | 97 | return book_id, 0 98 | 99 | 100 | def process_book(input_dir, output_xml_dir, output_pickle_dir, regex_rules_cache, book_id): 101 | return annotate_headers(input_dir, output_xml_dir, output_pickle_dir, regex_rules_cache, book_id) 102 | 103 | if __name__ == "__main__": 104 | 105 | parser = argparse.ArgumentParser() 106 | parser.add_argument('--config_file', help='Configuration file', required=True) 107 | args = parser.parse_args() 108 | config_file = args.config_file 109 | 110 | config = configparser.ConfigParser() 111 | config.read_file(open(config_file)) 112 | 113 | # Read set of test books 114 | # Read the list of book IDs in the test set 115 | test_set_books_file = config.get('04_Generate_test_probs', 'test_books_list') 116 | if not os.path.isfile(test_set_books_file): 117 | print('Please provide a valid file name for the list of test set book IDs in the "test_books_list" field.') 118 | exit() 119 | with open(test_set_books_file) as f: 120 | test_book_ids = [x.strip() for x in f.readlines()] 121 | 122 | # Read location to store probability outputs 123 | input_dir = config.get('04_Generate_test_probs', 'prob_dir') 124 | if not os.path.isdir(input_dir): 125 | print('Please run 04_generate_test_probs.py first.') 126 | exit() 127 | 128 | # Read location to store annotated XML output 129 | output_xml_dir = config.get('05_Annotate_headers', 'output_xml_dir') 130 | if not os.path.isdir(output_xml_dir): 131 | os.makedirs(output_xml_dir) 132 | 133 | # Read location to store staged header output 134 | output_pickle_dir = config.get('05_Annotate_headers', 'output_pickle_dir') 135 | if not os.path.isdir(output_pickle_dir): 136 | os.makedirs(output_pickle_dir) 137 | 138 | 139 | # Read number of processes to use 140 | num_procs = int(config.get('05_Annotate_headers', 'num_procs')) 141 | 142 | # Read location to store status of header extraction 143 | log_file = config.get('05_Annotate_headers', 'log_file') 144 | 145 | 146 | 147 | regex_rules_cache = './regex_rules_cache.pkl' 148 | # Generating regex rules 149 | if not os.path.exists(regex_rules_cache): 150 | print("Generating regex rules...") 151 | all_seqs, all_seqs_orig, rules, priority = get_rules() 152 | with open(regex_rules_cache, 'wb') as f: 153 | pickle.dump((all_seqs, all_seqs_orig, rules, priority), f) 154 | 155 | 156 | func = partial(process_book, input_dir, output_xml_dir, output_pickle_dir, regex_rules_cache) 157 | 158 | pool = multiprocessing.Pool(processes=num_procs) 
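# Each worker returns a (book_id, status) pair: 1 if the annotated XML already existed, 0 once a new file is written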
159 | data = pool.map(func, test_book_ids) 160 | pool.close() 161 | pool.join() 162 | 163 | print('Done! Saving status results to log file...') 164 | 165 | df = pd.DataFrame(data, columns=['bookID', 'status']) 166 | df.to_csv(log_file, index=False) 167 | 168 | print('Saved results to log file!') 169 | -------------------------------------------------------------------------------- /header-annotation/config.ini: -------------------------------------------------------------------------------- 1 | [01_Extract_headers_from_HTML] 2 | 3 | book_list = file_location_with_one_book_id_per_line 4 | html_dir = directory_location_with_books_in_html_format 5 | extracted_header_dir = directory_location_to_store_extracted_headers 6 | num_procs = number_of_processes_to_use_for_parallelization 7 | log_file = location_to_store_status_results 8 | 9 | 10 | [02_Generate_training_sequences] 11 | 12 | train_books_list = file_location_with_one_train_book_id_per_line 13 | text_files_dir = directory_location_with_books_in_txt_gz_format 14 | generated_sequence_dir = directory_location_to_store_generated_sequences 15 | seq_len = length_of_sequences_to_generate_(we_use_120_in_the_paper) 16 | num_procs = number_of_processes_to_use_for_parallelization 17 | log_file = location_to_store_status_results 18 | 19 | [03_Train_model] 20 | 21 | checkpoint_dir = directory_location_to_store_model_checkpoints 22 | num_epochs = number_of_epochs_to_train_for 23 | 24 | [04_Generate_test_probs] 25 | 26 | test_books_list = file_location_with_one_test_book_id_per_line 27 | base_xml_files_dir = directory_location_with_books_in_tagged_format 28 | prob_dir = directory_location_to_store_predicted_probabilities 29 | log_file = location_to_store_status_results 30 | 31 | [05_Annotate_headers] 32 | 33 | output_xml_dir = location_to_store_annotated_xmls 34 | output_pickle_dir = location_to_store_stagewise_results 35 | num_procs = number_of_processes_to_use_for_parallelization 36 | log_file = location_to_store_status_results -------------------------------------------------------------------------------- /header-annotation/header_util.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from regex_util import get_best_rules_for_all 4 | from regex_util import get_matching_rule_beginning 5 | from regex_util import get_corresponding_rule_regex_text 6 | from missing_chapter_util import find_missing_chapters 7 | from missing_chapter_util import convert_to_int 8 | 9 | import itertools 10 | import re 11 | 12 | import pickle 13 | import lxml.etree as ET 14 | 15 | import copy 16 | 17 | 18 | def group_ranges(L): 19 | """ 20 | Collapses a list of integers into a list of the start and end of 21 | consecutive runs of numbers. Returns a generator of generators. 
22 | >>> [list(x) for x in group_ranges([1, 2, 3, 5, 6, 8])] 23 | [[1, 3], [5, 6], [8]] 24 | """ 25 | for w, z in itertools.groupby(L, lambda x, y=itertools.count(): next(y)-x): 26 | grouped = list(z) 27 | yield (x for x in [grouped[0], grouped[-1]][:len(grouped)]) 28 | 29 | def header_to_xml(header_lines, book, output_xml_path): 30 | 31 | header_lines = [[y for y in x] for x in group_ranges(header_lines)] 32 | 33 | # First line: last line 34 | d = {x[0]:x[-1] for x in header_lines} 35 | 36 | # Delete section tags 37 | ET.strip_tags(book, "section") 38 | 39 | for from_line, to_line in d.items(): 40 | f = book.find('.//line[@num="' + str(from_line) + '"]') 41 | 42 | new_element = ET.Element('header') 43 | 44 | prev = f.getprevious() 45 | if prev is not None: 46 | for line_num in range(from_line, to_line + 1): 47 | e = book.find('.//line[@num="' + str(line_num) + '"]') 48 | new_element.append(e) 49 | prev.addnext(new_element) 50 | else: 51 | parent = f.getparent() 52 | for line_num in range(from_line, to_line + 1): 53 | e = book.find('.//line[@num="' + str(line_num) + '"]') 54 | new_element.append(e) 55 | parent.insert(0, new_element) 56 | 57 | 58 | ET.strip_tags(book, "line") 59 | 60 | # Write to file 61 | with open(output_xml_path, 'wb') as f: 62 | f.write(ET.tostring(book, pretty_print=True)) 63 | 64 | 65 | 66 | 67 | def get_high_vicinity_indices(very_high_threshold, very_low_threshold, new_probs, body_start): 68 | 69 | # Find indices with probability above very high threshold 70 | pos_pred_indices = [idx for idx in range(len(new_probs)) if new_probs[idx] > very_high_threshold and idx >= body_start] 71 | 72 | mod_pos_pred_groups = list() 73 | seen = set() 74 | 75 | for index in pos_pred_indices: 76 | # If index is already seen, ignore 77 | if index in seen: 78 | continue 79 | # Find consecutive indices that come after this index, and have probability > very_low_threshold 80 | # (Expand to the right) 81 | idx = index 82 | new_group = [idx] 83 | seen.add(idx) 84 | idx += 1 85 | while idx < len(new_probs) and new_probs[idx] > very_low_threshold: 86 | new_group.append(idx) 87 | seen.add(idx) 88 | idx += 1 89 | # Append to list of all indices 90 | mod_pos_pred_groups.append(new_group) 91 | 92 | return mod_pos_pred_groups 93 | 94 | 95 | def merge_close_groups(list_of_groups, tokens): 96 | new_list_of_groups = list() 97 | idx = 0 98 | while idx < len(list_of_groups) - 1: 99 | if abs(max(list_of_groups[idx]) - min(list_of_groups[idx + 1])) < 20 and '[unused1]' not in tokens[max(list_of_groups[idx]) + 1:min(list_of_groups[idx + 1])]: 100 | start = min(list_of_groups[idx]) 101 | end = max(list_of_groups[idx + 1]) 102 | curr = [x for x in range(start, end + 1)] 103 | new_list_of_groups.append(curr) 104 | idx += 1 105 | else: 106 | new_list_of_groups.append(list_of_groups[idx]) 107 | idx += 1 108 | if idx == len(list_of_groups) - 1: 109 | new_list_of_groups.append(list_of_groups[idx]) 110 | 111 | if new_list_of_groups == list_of_groups: 112 | return new_list_of_groups 113 | 114 | return merge_close_groups(new_list_of_groups, tokens) 115 | 116 | 117 | 118 | def get_header_attrs(text, rule): 119 | #print([text], rule) 120 | desc = None 121 | number = None 122 | number_text = None 123 | number_type = None 124 | title = None 125 | rule_text = ','.join(rule) 126 | 127 | number_list = ['roman_upper', 'roman_lower', 'numeral', 'word_number_one', 'word_number_One', 'word_number_ONE', 'word_number_first', 'word_number_First', 'word_number_FIRST', 'numeral'] 128 | 129 | curr_text = text 130 | for element in 
rule: 131 | if len(element) == 0: 132 | break 133 | 134 | if element == 'whitespace': 135 | curr_text = curr_text.lstrip() 136 | continue 137 | 138 | r = re.compile(get_corresponding_rule_regex_text(element)) 139 | m = r.match(curr_text) 140 | start, end = m.span() 141 | 142 | if element == 'desc': 143 | desc = curr_text[start:end].strip() 144 | elif element in number_list: 145 | number = convert_to_int(curr_text[start:end].strip(), element) 146 | number_text = curr_text[start:end].strip() 147 | number_type = element 148 | elif element in ['title_upper', 'title_lower']: 149 | title = curr_text[start:end].strip() 150 | curr_text = curr_text[end:] 151 | #print('Returning ', (desc, number, number_text, number_type, title)) 152 | return desc, number, number_text, number_type, title, rule_text 153 | 154 | 155 | 156 | def get_header_lines(df, all_seqs, all_seqs_orig, rules, priority, body_start, output_pickle_prefix, book): 157 | tokens = [str(x) for x in list(df['token'])] 158 | probs = list(df['prob']) 159 | line_nums = list(df['line_number']) 160 | 161 | body_start_word = line_nums.index(body_start) 162 | 163 | # Get indices in vicinity of high probability indices 164 | 165 | # Top 10% tokens 166 | top_10_percent = len(df) // 10 167 | very_high_threshold = sorted(df.loc[body_start_word:]['prob'])[-top_10_percent] 168 | 169 | 170 | 171 | very_low_threshold = 0.1 172 | pos_pred_groups = get_high_vicinity_indices(very_high_threshold, very_low_threshold, probs, body_start_word) 173 | 174 | # Merge groups close to each other 175 | mod_pos_pred_groups = merge_close_groups(pos_pred_groups, tokens) 176 | 177 | # Replace the [unused1] token with \n and strip 178 | stripped_texts = [' '.join([tokens[idx].replace('[unused1]', '\n') for idx in x]).strip() for x in mod_pos_pred_groups] 179 | 180 | # Keep only those groups which have at least one alphanumeric character when stripped 181 | mod_pos_pred_groups = [mod_pos_pred_groups[idx] for idx in range(len(mod_pos_pred_groups)) if any([x.isalnum() for x in stripped_texts[idx]])] 182 | 183 | 184 | # Remove newlines from beginning 185 | mod_pos_pred_groups_2 = list() 186 | for group in mod_pos_pred_groups: 187 | first_idx = 0 188 | while first_idx < len(group) and tokens[group[first_idx]] == '[unused1]': 189 | first_idx += 1 190 | mod_pos_pred_groups_2.append(group[first_idx:]) 191 | mod_pos_pred_groups = mod_pos_pred_groups_2 192 | 193 | # Extend to entire previous and next part of line 194 | mod_pos_pred_groups_2 = list() 195 | for group in mod_pos_pred_groups: 196 | tmp = group 197 | # Previous 198 | first_idx = group[0] - 1 199 | line = line_nums[group[0]] 200 | while line_nums[first_idx] == line: 201 | tmp.insert(0, first_idx) 202 | first_idx -= 1 203 | # Next 204 | first_idx = group[-1] + 1 205 | line = line_nums[group[-1]] 206 | while first_idx < len(line_nums) and line_nums[first_idx] == line: 207 | tmp.append(first_idx) 208 | first_idx += 1 209 | mod_pos_pred_groups_2.append(tmp) 210 | mod_pos_pred_groups = mod_pos_pred_groups_2 211 | 212 | # Add newlines at the end if present (Adding for title_upper to match) 213 | mod_pos_pred_groups_2 = list() 214 | for group in mod_pos_pred_groups: 215 | new_group = group 216 | last_idx = group[-1] + 1 217 | while last_idx < len(tokens) and tokens[last_idx] == '[unused1]': 218 | new_group.append(last_idx) 219 | last_idx += 1 220 | mod_pos_pred_groups_2.append(new_group) 221 | mod_pos_pred_groups = mod_pos_pred_groups_2 222 | 223 | # Convert groups to texts for regex matching 224 | likely_headers = [' 
'.join([tokens[idx].replace('[unused1]', '\n') for idx in x]) for x in mod_pos_pred_groups] 225 | 226 | header_lines = set() 227 | for g in mod_pos_pred_groups: 228 | for x in g: 229 | header_lines.add(line_nums[x]) 230 | header_lines = sorted(list(header_lines)) 231 | header_to_xml(header_lines, copy.deepcopy(book), output_pickle_prefix + '_stage1_headers.xml') 232 | 233 | # Find the best matching rule for each header 234 | rules_found = get_best_rules_for_all(likely_headers, all_seqs, rules, priority) 235 | 236 | 237 | # Look for missing chapters using each rule 238 | rules_found = [[x for x in r if x] if r else [] for r in rules_found] 239 | num_set = set(['roman_upper', 'roman_lower', 'numeral', 'word_number_one', 'word_number_One', 'word_number_ONE', 'word_number_first', 'word_number_First', 'word_number_FIRST']) 240 | clipped_rules_found = list() 241 | for r in rules_found: 242 | if len(r) == 0: 243 | clipped_rules_found.append(r) 244 | continue 245 | tmp = list() 246 | tmp_idx = 0 247 | while tmp_idx < len(r): 248 | tmp.append(r[tmp_idx]) 249 | if r[tmp_idx] in num_set: 250 | break 251 | tmp_idx += 1 252 | clipped_rules_found.append(tmp) 253 | 254 | d_rules = dict() 255 | for r_idx, r in enumerate(clipped_rules_found): 256 | if len(r) == 0: 257 | continue 258 | t = tuple(r) 259 | if t not in d_rules: 260 | d_rules[t] = [r_idx] 261 | else: 262 | d_rules[t].append(r_idx) 263 | 264 | d_rules_new = dict() 265 | for t in d_rules: 266 | r = list(t) 267 | d_rules_new[t] = find_missing_chapters(df, [mod_pos_pred_groups[r_idx] for r_idx in d_rules[t]], r, body_start_word) 268 | 269 | # Improve each match to include title if present 270 | valid_indices = [idx for idx in range(len(all_seqs_orig))] 271 | new_seqs = [all_seqs[idx] for idx in valid_indices] 272 | new_rules = [rules[idx] for idx in valid_indices] 273 | 274 | ans = list() 275 | attrs_dict = dict() 276 | 277 | all_header_groups = list() 278 | for r in d_rules_new.keys(): 279 | if list(r) == ['title_upper']: 280 | continue 281 | rule_groups = d_rules_new[r] 282 | all_header_groups += rule_groups 283 | all_header_groups = sorted(all_header_groups) 284 | 285 | next_group_index = dict() 286 | for idx in range(len(all_header_groups) - 1): 287 | next_group_index[all_header_groups[idx][0]] = all_header_groups[idx + 1][0] 288 | 289 | sorted_rule_list = list(d_rules_new.keys()) 290 | sorted_rule_list.sort(key=lambda x:[priority.index(y) for y in x]) 291 | seen = set() 292 | 293 | for rule in sorted_rule_list: 294 | for group_idx, group in enumerate(d_rules_new[rule]): 295 | start_idx = group[0] 296 | 297 | end_idx = start_idx + 100 298 | if start_idx in next_group_index: 299 | end_idx = min(end_idx, next_group_index[start_idx]) 300 | 301 | tmp_text = ' '.join([x if x != '[unused1]' else '\n' for x in tokens[start_idx:end_idx]]) 302 | matched_rules = get_matching_rule_beginning(tmp_text, new_seqs, new_rules, priority) 303 | if len(matched_rules) > 0: 304 | matched_bool = False 305 | for r_orig in matched_rules: 306 | try: 307 | r = new_rules[new_seqs.index(r_orig)] 308 | m = r.match(tmp_text) 309 | start, end = m.span() 310 | match_len = end - start 311 | if tmp_text[start:end].strip(' ')[-1] != '\n' and tmp_text[end:].strip(' ')[0] != '\n': 312 | continue 313 | 314 | attrs = get_header_attrs(tmp_text[start:end], r_orig) 315 | matched_bool = True 316 | break 317 | except: 318 | pass 319 | if not matched_bool: 320 | continue 321 | tmp = list() 322 | curr_str = '' 323 | 324 | new_iter = start_idx 325 | while new_iter < len(tokens): 326 | 
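# Walk forward from start_idx, rebuilding the space-joined text token by token, and
# collect token indices until the rebuilt string covers the matched span; this maps
# the character-level regex match back onto token positions.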
curr_str += (tokens[new_iter] if tokens[new_iter] != '[unused1]' else '\n') 327 | if len(curr_str) > match_len: 328 | break 329 | tmp.append(new_iter) 330 | curr_str += ' ' 331 | new_iter += 1 332 | if tmp: 333 | if not any([index in seen for index in tmp]): 334 | ans.append(tmp) 335 | seen.update(tmp) 336 | attrs_dict[line_nums[start_idx]] = attrs 337 | 338 | 339 | 340 | line_nos = set() 341 | for group in ans: 342 | for idx in group: 343 | line_nos.add(line_nums[idx]) 344 | 345 | final_ans = list() 346 | line_nos_sorted = sorted(list(line_nos)) 347 | 348 | idx = 0 349 | while idx < len(line_nos_sorted): 350 | if line_nos_sorted[idx] in attrs_dict: 351 | tmp = list() 352 | tmp.append(line_nos_sorted[idx]) 353 | inner_idx = idx + 1 354 | while inner_idx < len(line_nos_sorted) and line_nos_sorted[inner_idx] not in attrs_dict: 355 | tmp.append(line_nos_sorted[inner_idx]) 356 | inner_idx += 1 357 | final_ans.append((tmp, attrs_dict[line_nos_sorted[idx]])) 358 | 359 | idx = inner_idx 360 | continue 361 | 362 | idx += 1 363 | 364 | 365 | header_lines = set() 366 | for g, _ in final_ans: 367 | for x in g: 368 | header_lines.add(x) 369 | header_lines = sorted(list(header_lines)) 370 | 371 | header_to_xml(header_lines, copy.deepcopy(book), output_pickle_prefix + '_stage2_headers.xml') 372 | 373 | header_lines = final_ans 374 | 375 | 376 | 377 | # Keep only those groups which do not start with a I'm, I've, I'd, I'll 378 | punctuation = '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~‘’' 379 | stripped_texts = [' '.join([tokens[idx].replace('[unused1]', '\n') for idx in x[0]]).strip() for x in header_lines] 380 | 381 | header_lines_new = list() 382 | for idx in range(len(stripped_texts)): 383 | if not any([stripped_texts[idx].lstrip().lstrip(punctuation).lstrip().startswith(x) for x in {"I'm", "I've", "I’m", "I’ve", "I'd", "I’d", "I'll", "I’ll"}]): 384 | header_lines_new.append(header_lines[idx]) 385 | 386 | header_lines = header_lines_new 387 | 388 | # Remove false positives 389 | remove_indices = set() 390 | for idx, hl in enumerate(header_lines): 391 | line_nums, attrs = hl 392 | desc, number, number_text, number_type, title, rule_text = attrs 393 | if number == None: 394 | continue 395 | 396 | i = idx + 1 397 | while i < len(header_lines) and header_lines[i][1][1] is None: 398 | i += 1 399 | if i < len(header_lines) and header_lines[i][1][1] == number + 1 and header_lines[i][1][3] == number_type: 400 | for i2 in range(idx + 1, i): 401 | remove_indices.add(i2) 402 | final_ans = [elem for idx, elem in enumerate(header_lines) if idx not in remove_indices] 403 | 404 | 405 | header_lines = set() 406 | for g, _ in final_ans: 407 | for x in g: 408 | header_lines.add(x) 409 | header_lines = sorted(list(header_lines)) 410 | header_to_xml(header_lines, copy.deepcopy(book), output_pickle_prefix + '_stage3_headers.xml') 411 | 412 | return final_ans 413 | -------------------------------------------------------------------------------- /header-annotation/missing_chapter_util.py: -------------------------------------------------------------------------------- 1 | import inflect 2 | from collections import Counter 3 | from collections import OrderedDict 4 | import collections 5 | import re 6 | import pandas as pd 7 | 8 | def get_regex_rule(number_type): 9 | 10 | p = inflect.engine() 11 | 12 | numeral = "[0-9]+" 13 | 14 | roman_upper = "(?=[MDCLXVI])M*(C[MD]|D?C{0,3})(X[CL]|L?X{0,3})(I[XV]|V?I{0,3})(?![A-Za-z0-9'\"])" 15 | roman_lower = "(?=[mdxlxvi])m*(c[md]|d?c{0,3})(x[cl]|l?x{0,3})(i[xv]|v?i{0,3})(?![A-Za-z0-9'\"])" 16 | 
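# The alternations built below are joined longest-first because Python's re module
# returns the first alternative that matches, not the longest one.  A minimal hedged
# illustration of why the ordering matters (assuming inflect's default spellings):
#
#   >>> p = inflect.engine()
#   >>> words = sorted([p.number_to_words(x) for x in range(201)], key=len)[::-1]
#   >>> re.match('|'.join(words), 'twenty-one').group(0)
#   'twenty-one'
#
# Joined shortest-first, the prefix 'twenty' would win instead.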
17 | word_number_one = '|'.join(sorted([p.number_to_words(x) for x in range(201)], key=len)[::-1]) 18 | word_number_One = '|'.join(sorted([p.number_to_words(x).title() for x in range(201)], key=len)[::-1]) 19 | word_number_ONE = '|'.join(sorted([p.number_to_words(x).upper() for x in range(201)], key=len)[::-1]) 20 | word_number_first = '|'.join(sorted([p.ordinal(p.number_to_words(x)) for x in range(201)], key=len)[::-1]) 21 | word_number_First = '|'.join(sorted([p.ordinal(p.number_to_words(x)).title() for x in range(201)], key=len)[::-1]) 22 | word_number_FIRST = '|'.join(sorted([p.ordinal(p.number_to_words(x)).upper() for x in range(201)], key=len)[::-1]) 23 | 24 | tmp = [p.number_to_words(x) for x in range(201)] 25 | tmp2 = list() 26 | for elem in tmp: 27 | if '-' in elem: 28 | tmp2.append(elem.replace('-', '[\s]+-[\s]+')) 29 | word_number_one += '|' + '|'.join(tmp2) 30 | 31 | tmp = [p.number_to_words(x).title() for x in range(201)] 32 | tmp2 = list() 33 | for elem in tmp: 34 | if '-' in elem: 35 | tmp2.append(elem.replace('-', '[\s]+-[\s]+')) 36 | word_number_One += '|' + '|'.join(tmp2) 37 | 38 | tmp = [p.number_to_words(x).upper() for x in range(201)] 39 | tmp2 = list() 40 | for elem in tmp: 41 | if '-' in elem: 42 | tmp2.append(elem.replace('-', '[\s]+-[\s]+')) 43 | word_number_ONE += '|' + '|'.join(tmp2) 44 | 45 | tmp = [p.ordinal(p.number_to_words(x)) for x in range(201)] 46 | tmp2 = list() 47 | for elem in tmp: 48 | if '-' in elem: 49 | tmp2.append(elem.replace('-', '[\s]+-[\s]+')) 50 | word_number_first += '|' + '|'.join(tmp2) 51 | 52 | tmp = [p.ordinal(p.number_to_words(x)).title() for x in range(201)] 53 | tmp2 = list() 54 | for elem in tmp: 55 | if '-' in elem: 56 | tmp2.append(elem.replace('-', '[\s]+-[\s]+')) 57 | word_number_First += '|' + '|'.join(tmp2) 58 | 59 | tmp = [p.ordinal(p.number_to_words(x)).upper() for x in range(201)] 60 | tmp2 = list() 61 | for elem in tmp: 62 | if '-' in elem: 63 | tmp2.append(elem.replace('-', '[\s]+-[\s]+')) 64 | word_number_FIRST += '|' + '|'.join(tmp2) 65 | 66 | string_to_pattern = dict() 67 | string_to_pattern['numeral'] = numeral 68 | string_to_pattern['roman_upper'] = roman_upper 69 | string_to_pattern['roman_lower'] = roman_lower 70 | string_to_pattern['word_number_one'] = word_number_one 71 | string_to_pattern['word_number_One'] = word_number_One 72 | string_to_pattern['word_number_ONE'] = word_number_ONE 73 | string_to_pattern['word_number_first'] = word_number_first 74 | string_to_pattern['word_number_First'] = word_number_First 75 | string_to_pattern['word_number_FIRST'] = word_number_FIRST 76 | 77 | return re.compile(string_to_pattern[number_type]) 78 | 79 | 80 | def convert_to_int(number, number_type): 81 | 82 | def convert_one_to_int(number): 83 | p = inflect.engine() 84 | #number = number.replace(' ', '') 85 | number = ' '.join(number.strip().split()) 86 | number = '-'.join([x.strip() for x in number.split('-')]) 87 | l = [p.number_to_words(x) for x in range(201)] 88 | return l.index(number) 89 | 90 | def convert_first_to_int(number): 91 | p = inflect.engine() 92 | #number = number.replace(' ', '') 93 | number = ' '.join(number.strip().split()) 94 | number = '-'.join([x.strip() for x in number.split('-')]) 95 | l = [p.ordinal(p.number_to_words(x)) for x in range(201)] 96 | return l.index(number) 97 | 98 | def convert_roman_to_int(number): 99 | 100 | def value(r): 101 | if (r == 'I'): 102 | return 1 103 | if (r == 'V'): 104 | return 5 105 | if (r == 'X'): 106 | return 10 107 | if (r == 'L'): 108 | return 50 109 | if (r == 'C'): 
110 | return 100 111 | if (r == 'D'): 112 | return 500 113 | if (r == 'M'): 114 | return 1000 115 | return -1 116 | res = 0 117 | i = 0 118 | 119 | number = number.upper() 120 | while (i < len(number)): 121 | # Getting value of symbol s[i] 122 | s1 = value(number[i]) 123 | if (i+1 < len(number)): 124 | # Getting value of symbol s[i+1] 125 | s2 = value(number[i+1]) 126 | # Comparing both values 127 | if (s1 >= s2): 128 | # Value of current symbol is greater 129 | # or equal to the next symbol 130 | res = res + s1 131 | i = i + 1 132 | else: 133 | # Value of current symbol is greater 134 | # or equal to the next symbol 135 | res = res + s2 - s1 136 | i = i + 2 137 | else: 138 | res = res + s1 139 | i = i + 1 140 | return res 141 | 142 | if number_type == 'numeral': 143 | return int(number) 144 | if number_type.lower() == 'word_number_one': 145 | return convert_one_to_int(number.lower()) 146 | if number_type.lower() == 'word_number_first': 147 | return convert_first_to_int(number.lower()) 148 | if number_type.startswith('roman'): 149 | return convert_roman_to_int(number) 150 | 151 | return 0 152 | 153 | 154 | 155 | def findLISindices(arrA): 156 | LIS = [0 for i in range(len(arrA))] 157 | for i in range(len(arrA)): 158 | maximum = -1 159 | for j in range(i): 160 | if arrA[i] > arrA[j]: 161 | if maximum == -1 or maximum < LIS[j] + 1: 162 | maximum = 1 + LIS[j] 163 | if maximum == -1: 164 | maximum = 1 165 | LIS[i] = maximum 166 | 167 | result = -1 168 | index = -1 169 | 170 | for i in range(len(LIS)): 171 | if result < LIS[i]: 172 | result = LIS[i] 173 | index = i 174 | 175 | answer = list() 176 | answer.insert(0, index) 177 | res = result - 1 178 | for i in range(index - 1, -1, -1): 179 | if LIS[i] == res: 180 | answer.insert(0, i) 181 | res -= 1 182 | 183 | return answer 184 | 185 | 186 | def write_roman(num): 187 | 188 | roman = OrderedDict() 189 | roman[1000] = "M" 190 | roman[900] = "CM" 191 | roman[500] = "D" 192 | roman[400] = "CD" 193 | roman[100] = "C" 194 | roman[90] = "XC" 195 | roman[50] = "L" 196 | roman[40] = "XL" 197 | roman[10] = "X" 198 | roman[9] = "IX" 199 | roman[5] = "V" 200 | roman[4] = "IV" 201 | roman[1] = "I" 202 | 203 | def roman_num(num): 204 | for r in roman.keys(): 205 | x, y = divmod(num, r) 206 | yield roman[r] * x 207 | num -= (r * x) 208 | if num <= 0: 209 | break 210 | 211 | return "".join([a for a in roman_num(num)]) 212 | 213 | def convert_to_number_type_regex_string(x, number_type): 214 | if number_type == 'numeral': 215 | return str(x) + "(?![A-Za-z0-9'\"])" 216 | 217 | if number_type == 'roman_upper': 218 | return write_roman(x) + "(?![A-Za-z0-9'\"])" 219 | elif number_type == 'roman_lower': 220 | return write_roman(x).lower() + "(?![A-Za-z0-9'\"])" 221 | 222 | p = inflect.engine() 223 | conv = p.number_to_words(x) 224 | 225 | if number_type == 'word_number_One': 226 | conv = conv.title() 227 | elif number_type == 'word_number_ONE': 228 | conv = conv.upper() 229 | 230 | elif number_type == 'word_number_first': 231 | conv = p.ordinal(conv) 232 | elif number_type == 'word_number_First': 233 | conv = p.ordinal(conv).title() 234 | elif number_type == 'word_number_FIRST': 235 | conv = p.ordinal(conv).upper() 236 | 237 | if '-' in conv: 238 | conv = conv.replace('-', '[\s]*-[\s]*') 239 | 240 | return conv + "(?![A-Za-z0-9'\"])" 241 | 242 | 243 | def find_in_book(tokens, search_pattern, from_index, to_index, last_occurrence=False): 244 | curr = 0 245 | char_indices = list() 246 | for elem in tokens: 247 | char_indices.append(curr) 248 | if elem == '[unused1]': 
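# '[unused1]' is rendered as '\n' when the tokens are space-joined below, so it
# advances the character cursor by 2 (the newline plus the joining space); every
# other token advances it by len(token) + 1.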
249 | curr += 2 250 | else: 251 | curr += 1 + len(elem) 252 | 253 | text = ' '.join([x if x != '[unused1]' else '\n' for x in tokens]) 254 | 255 | if from_index >= len(char_indices): 256 | return [] 257 | if to_index >= len(char_indices): 258 | to_index = len(char_indices) - 1 259 | 260 | lower = char_indices[from_index] 261 | upper = min(char_indices[to_index], len(text)) + 1 262 | ans = list() 263 | for elem in search_pattern.finditer(text[lower:upper]): 264 | if elem.span()[0] + char_indices[from_index] in char_indices and text[elem.span()[0] + char_indices[from_index] - 2] == '\n': 265 | ans.append(elem) 266 | 267 | ans2 = list() 268 | 269 | for m in ans: 270 | 271 | start, end = m.span() 272 | start += char_indices[from_index] 273 | end += char_indices[from_index] 274 | 275 | if start in char_indices: 276 | start_idx = char_indices.index(start) 277 | idx = start_idx 278 | while idx < len(char_indices) and char_indices[idx] < end: 279 | idx += 1 280 | if idx - 2 != start_idx: 281 | ans2.append([x for x in range(start_idx, idx)]) 282 | ans2.append([start_idx]) 283 | 284 | return ans2 285 | 286 | def findRestartIndices(nums): 287 | if len(nums) == 0: 288 | return [] 289 | ans = [0] 290 | if len(nums) == 1: 291 | return ans 292 | 293 | idx = 1 294 | while idx < len(nums): 295 | if nums[idx] <= nums[idx - 1]: 296 | return ans 297 | ans.append(idx) 298 | idx += 1 299 | return ans 300 | 301 | def hunt(pred_texts, pred_indices, matched_rule, tokens, body_start_token_idx, number_restart=False): 302 | # Get the number type from the matched rule 303 | number_type = matched_rule[-1] 304 | 305 | # Get the regex rule corresponding to the number type 306 | number_regex_rule = get_regex_rule(number_type) 307 | 308 | # Get the position of the number type in the matched rule 309 | number_index = matched_rule.index(number_type) 310 | 311 | number_match_dict = collections.defaultdict(lambda: list()) 312 | found_missing_dict = collections.defaultdict(lambda: list()) 313 | indices_groups_dict = collections.defaultdict(lambda: list()) 314 | 315 | if number_index > 0: 316 | # Rule for descriptor that may occur before number 317 | desc = "(CHAPTER|Chapter|CHAP|Chap|PART|Part|BOOK|Book|STORY|Story|LETTER|Letter|VOLUME|Volume|VOL|Vol|CASE|Case|THE|The)[\s]*(THE|The|the|NO|No|no|NO\.|No\.|no\.|NUMBER|Number|number|NUMBER\.|Number\.|number\.)*[\s]*" 318 | desc_rule = re.compile(desc) 319 | for idx, text in enumerate(pred_texts): 320 | desc_match = desc_rule.match(text) 321 | if desc_match: 322 | number_start = desc_match.span()[1] 323 | matched_text = text[:number_start] 324 | rem_header = text[number_start:] 325 | number_match = number_regex_rule.match(rem_header) 326 | if number_match: 327 | start, end = number_match.span() 328 | number_match_dict[matched_text].append(rem_header[start:end]) 329 | indices_groups_dict[matched_text].append(pred_indices[idx]) 330 | 331 | else: 332 | for idx, text in enumerate(pred_texts): 333 | number_match = number_regex_rule.match(text) 334 | if number_match: 335 | start, end = number_match.span() 336 | number_match_dict[''].append(text[start:end]) 337 | indices_groups_dict[''].append(pred_indices[idx]) 338 | 339 | seen = set() 340 | for group in pred_indices: 341 | seen.update(group) 342 | for descriptor in number_match_dict: 343 | pred_indices_tmp = indices_groups_dict[descriptor] 344 | numbers = number_match_dict[descriptor] 345 | converted_numbers = [convert_to_int(n, number_type) for n in numbers] 346 | 347 | queue = list() 348 | queue.append((converted_numbers, 0, 
body_start_token_idx, len(tokens) - 1)) 349 | 350 | while queue: 351 | converted_numbers, offset, last_from, last_to = queue.pop(0) 352 | if number_restart: 353 | lis_indices = findRestartIndices(converted_numbers) 354 | else: 355 | lis_indices = findLISindices(converted_numbers) 356 | 357 | if len(lis_indices) > 0: 358 | if lis_indices[0] > 0: 359 | a = converted_numbers[:lis_indices[0]] 360 | b = offset 361 | c = last_from 362 | d = pred_indices_tmp[offset + lis_indices[0]][-1] 363 | queue.append((a, b, c, d)) 364 | if lis_indices[-1] < len(converted_numbers) - 1: 365 | a = converted_numbers[lis_indices[-1] + 1:] 366 | b = offset + lis_indices[-1] + 1 367 | c = pred_indices_tmp[offset + lis_indices[-1]][-1] + 1 368 | d = last_to 369 | queue.append((a, b, c, d)) 370 | 371 | smallest_number = converted_numbers[lis_indices[0]] - 1 372 | while smallest_number > 0: 373 | from_index = last_from 374 | to_index = pred_indices_tmp[offset + lis_indices[0]][-1] 375 | if number_index > 0: 376 | search_pattern = descriptor.strip() + '[\s]*' + convert_to_number_type_regex_string(smallest_number, number_type) 377 | else: 378 | search_pattern = convert_to_number_type_regex_string(smallest_number, number_type) 379 | search_pattern = re.compile(search_pattern) 380 | found_indices = find_in_book(tokens, search_pattern, from_index, to_index) 381 | if len(found_indices) > 0: 382 | for group in found_indices: 383 | if not any(x in seen for x in group): 384 | found_missing_dict[descriptor].insert(0, group) 385 | seen.update(group) 386 | to_index = found_indices[-1][0] - 1 387 | else: 388 | break 389 | smallest_number -= 1 390 | 391 | idx = 0 392 | for idx in range(len(lis_indices) - 1): 393 | temp = converted_numbers[lis_indices[idx]] + 1 394 | from_index = pred_indices_tmp[offset + lis_indices[idx]][-1] + 1 395 | to_index = pred_indices_tmp[offset + lis_indices[idx + 1]][-1] 396 | while converted_numbers[lis_indices[idx + 1]] > temp: 397 | if number_index > 0: 398 | search_pattern = descriptor.strip() + '[\s]*' + convert_to_number_type_regex_string(temp, number_type) 399 | else: 400 | search_pattern = convert_to_number_type_regex_string(temp, number_type) 401 | search_pattern = re.compile(search_pattern) 402 | found_indices = find_in_book(tokens, search_pattern, from_index, to_index) 403 | if len(found_indices) > 0: 404 | for group in found_indices: 405 | if not any(x in seen for x in group): 406 | found_missing_dict[descriptor].append(group) 407 | seen.update(group) 408 | from_index = found_indices[0][-1] + 1 409 | temp += 1 410 | 411 | largest_number = converted_numbers[lis_indices[-1]] 412 | while True: 413 | largest_number += 1 414 | from_index = pred_indices_tmp[offset + lis_indices[-1]][-1] + 1 415 | to_index = last_to 416 | if number_index > 0: 417 | search_pattern = descriptor.strip() + '[\s]*' + convert_to_number_type_regex_string(largest_number, number_type) 418 | else: 419 | search_pattern = convert_to_number_type_regex_string(largest_number, number_type) 420 | search_pattern = re.compile(search_pattern) 421 | found_indices = find_in_book(tokens, search_pattern, from_index, to_index) 422 | if len(found_indices) > 0: 423 | for group in found_indices: 424 | if not any(x in seen for x in group): 425 | found_missing_dict[descriptor].append(group) 426 | seen.update(group) 427 | from_index = found_indices[0][-1] + 1 428 | else: 429 | break 430 | ans = list() 431 | for descriptor in found_missing_dict: 432 | for group in found_missing_dict[descriptor]: 433 | ans.append(group) 434 | return ans 435 | 436 | 
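# Illustrative sketch (not exhaustive) of how the helpers above fit together inside
# hunt(): detected header numbers are converted to integers, the longest increasing
# subsequence is treated as trusted anchors, and the gaps around those anchors are
# then searched for missing chapter numbers.
#
#   >>> nums = [convert_to_int(x, 'roman_upper') for x in ['I', 'II', 'V', 'IV', 'VI']]
#   >>> nums
#   [1, 2, 5, 4, 6]
#   >>> findLISindices(nums)    # keeps the increasing run 1, 2, 4, 6 and drops the stray 5
#   [0, 1, 3, 4]
#   >>> write_roman(3)          # used to build a search pattern for a missing chapter
#   'III'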
def find_missing_chapters(df, pred_indices, matched_rule, body_start_token_idx): 437 | number_list = ['roman_upper', 'roman_lower', 'numeral', 'word_number_one', 'word_number_One', 'word_number_ONE', 'word_number_first', 'word_number_First', 'word_number_FIRST'] 438 | 439 | # If the matched rule does not contain a number form, return as is 440 | if len(set(number_list).intersection(set(matched_rule))) == 0: 441 | return pred_indices 442 | 443 | 444 | tokens = [str(x) for x in list(df['token'])] 445 | line_nums = list(df['line_number']) 446 | pred_texts = [' '.join([tokens[x] if tokens[x] != '[unused1]' else '\n' for x in group]) for group in pred_indices] 447 | 448 | missing_indices = hunt(pred_texts, pred_indices, matched_rule, tokens, body_start_token_idx, number_restart=False) 449 | 450 | new_pred_indices = pred_indices + missing_indices 451 | new_pred_texts = [' '.join([tokens[x] if tokens[x] != '[unused1]' else '\n' for x in group]) for group in new_pred_indices] 452 | 453 | new_missing_indices = hunt(new_pred_texts, new_pred_indices, matched_rule, tokens, body_start_token_idx, number_restart=True) 454 | 455 | return sorted(pred_indices + missing_indices + new_missing_indices) 456 | -------------------------------------------------------------------------------- /header-annotation/regex_util.py: -------------------------------------------------------------------------------- 1 | import inflect 2 | from collections import Counter 3 | import re 4 | from itertools import groupby 5 | 6 | def generate_sequences(l): 7 | if len(l) == 0: 8 | return [] 9 | subsequent = generate_sequences(l[1:]) 10 | answer = list() 11 | if len(subsequent) > 0: 12 | answer += subsequent 13 | for elem in l[0]: 14 | answer.append([elem]) 15 | for elem2 in subsequent: 16 | answer.append([elem] + elem2) 17 | return answer 18 | 19 | def remove_duplicates(l): 20 | res = [] 21 | for i in l: 22 | if i not in res: 23 | res.append(i) 24 | return res 25 | 26 | def remove_consecutives(l): 27 | res = list() 28 | for elem in l: 29 | tmp = [x[0] for x in groupby(elem)] 30 | if tmp not in res: 31 | res.append(tmp) 32 | return res 33 | 34 | def remove_whitespace_from_ends(l): 35 | res = list() 36 | for elem in l: 37 | start = 0 38 | while start < len(elem) and elem[start] == 'whitespace': 39 | start += 1 40 | end = len(elem) - 1 41 | while end >= 0 and elem[end] == 'whitespace': 42 | end -= 1 43 | tmp = elem[start:end + 1] 44 | if tmp and tmp not in res: 45 | res.append(elem[start:end + 1]) 46 | return res 47 | 48 | def issublist(b, a): 49 | return b in [a[i:len(b)+i] for i in range(len(a))] 50 | 51 | def remove_incompatible_consecutives(l, incompatible): 52 | inc_list = list() 53 | for x in incompatible: 54 | for y in incompatible: 55 | inc_list.append([x, y]) 56 | res = list() 57 | for elem in l: 58 | if any([issublist(x, elem) for x in inc_list]): 59 | pass 60 | else: 61 | res.append(elem) 62 | return res 63 | 64 | 65 | def get_corresponding_rule_regex_text(rule_name): 66 | word_numbers = ['word_number_one', 'word_number_One', 'word_number_ONE', 'word_number_first', 'word_number_First', 'word_number_FIRST'] 67 | 68 | if rule_name == 'desc': 69 | return "[A-Za-z0-9\S]*(CHAPTER|Chapter|CHAP|Chap|PART|Part|BOOK|Book|STORY|Story|LETTER|Letter|VOLUME|Volume|VOL|Vol|CASE|Case|THE|The)[\s]*(THE|The|the|NO|No|no|NO\.|No\.|no\.|NUMBER|Number|number|NUMBER\.|Number\.|number\.)*" 70 | 71 | elif rule_name == 'roman_upper': 72 | return "(?=[MDCLXVI])M*(C[MD]|D?C{0,3})(X[CL]|L?X{0,3})(I[XV]|V?I{0,3})(?![A-Za-z0-9'\"])" 73 | elif rule_name 
== 'roman_lower': 74 | return "(?=[mdxlxvi])m*(c[md]|d?c{0,3})(x[cl]|l?x{0,3})(i[xv]|v?i{0,3})(?![A-Za-z0-9'\"])" 75 | 76 | elif rule_name == 'numeral': 77 | return "[0-9]+" 78 | 79 | elif rule_name == 'punctuation': 80 | return "(?=\S)[^a-zA-Z\d\s]+" 81 | 82 | elif rule_name == 'title_upper': 83 | return "([^a-z\s\.]*(?=.*[A-Z]+.*)[^a-z]+[\n]+)+" 84 | elif rule_name == 'title_lower': 85 | return "((?=.*[a-z]+.*)(?=.*[A-Z]+.*)[^\r\n]*)\n[^\S\n]*\n" 86 | 87 | elif rule_name == 'whitespace': 88 | return "[\s]+" 89 | 90 | elif rule_name in word_numbers: 91 | p = inflect.engine() 92 | if rule_name == 'word_number_one': 93 | tmp = [p.number_to_words(x) for x in range(201)] 94 | elif rule_name == 'word_number_One': 95 | tmp = [p.number_to_words(x).title() for x in range(201)] 96 | elif rule_name == 'word_number_ONE': 97 | tmp = [p.number_to_words(x).upper() for x in range(201)] 98 | elif rule_name == 'word_number_first': 99 | tmp = [p.ordinal(p.number_to_words(x)) for x in range(201)] 100 | elif rule_name == 'word_number_First': 101 | tmp = [p.ordinal(p.number_to_words(x)).title() for x in range(201)] 102 | elif rule_name == 'word_number_FIRST': 103 | tmp = [p.ordinal(p.number_to_words(x)).upper() for x in range(201)] 104 | 105 | tmp2 = list() 106 | for elem in tmp: 107 | if '-' in elem: 108 | tmp2.append(elem.replace('-', '[\s]*-[\s]*')) 109 | else: 110 | tmp2.append(elem) 111 | reg = '|'.join(sorted(tmp2, key=len)[::-1]) 112 | return reg 113 | 114 | return None 115 | 116 | 117 | 118 | def get_rules(): 119 | p = inflect.engine() 120 | 121 | desc = get_corresponding_rule_regex_text('desc') 122 | 123 | roman_upper = get_corresponding_rule_regex_text('roman_upper') 124 | roman_lower = get_corresponding_rule_regex_text('roman_lower') 125 | 126 | numeral = get_corresponding_rule_regex_text('numeral') 127 | 128 | punctuation = get_corresponding_rule_regex_text('punctuation') 129 | 130 | title_upper = get_corresponding_rule_regex_text('title_upper') 131 | title_lower = get_corresponding_rule_regex_text('title_lower') 132 | 133 | whitespace = get_corresponding_rule_regex_text('whitespace') 134 | 135 | word_number_one = get_corresponding_rule_regex_text('word_number_one') 136 | word_number_One = get_corresponding_rule_regex_text('word_number_One') 137 | word_number_ONE = get_corresponding_rule_regex_text('word_number_ONE') 138 | word_number_first = get_corresponding_rule_regex_text('word_number_first') 139 | word_number_First = get_corresponding_rule_regex_text('word_number_First') 140 | word_number_FIRST = get_corresponding_rule_regex_text('word_number_FIRST') 141 | 142 | 143 | l2 = list() 144 | 145 | l2.append(['desc']) 146 | l2.append(['whitespace']) 147 | l2.append(['punctuation']) 148 | l2.append(['whitespace']) 149 | l2.append(['roman_upper', 'roman_lower', 'numeral', 'word_number_one', 'word_number_One', 'word_number_ONE', 'word_number_first', 'word_number_First', 'word_number_FIRST']) 150 | l2.append(['whitespace']) 151 | l2.append(['punctuation']) 152 | l2.append(['whitespace']) 153 | l2.append(['title_upper', 'title_lower']) 154 | 155 | 156 | string_to_pattern = dict() 157 | string_to_pattern['desc'] = desc 158 | string_to_pattern['roman_upper'] = roman_upper 159 | string_to_pattern['roman_lower'] = roman_lower 160 | string_to_pattern['numeral'] = numeral 161 | string_to_pattern['punctuation'] = punctuation 162 | string_to_pattern['title_upper'] = title_upper 163 | string_to_pattern['title_lower'] = title_lower 164 | string_to_pattern['whitespace'] = whitespace 165 | 166 | 
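# The string_to_pattern entries below (together with those above) are what turn an
# abstract rule sequence into a compiled regex further down: each element is wrapped
# in a capturing group and the pieces are concatenated.  A minimal hedged sketch:
#
#   >>> seq = ['desc', 'whitespace', 'roman_upper']
#   >>> rule = re.compile(''.join('(' + string_to_pattern[x] + ')' for x in seq))
#   >>> bool(rule.match('CHAPTER IV'))
#   True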
string_to_pattern['word_number_one'] = word_number_one 167 | string_to_pattern['word_number_One'] = word_number_One 168 | string_to_pattern['word_number_ONE'] = word_number_ONE 169 | string_to_pattern['word_number_first'] = word_number_first 170 | string_to_pattern['word_number_First'] = word_number_First 171 | string_to_pattern['word_number_FIRST'] = word_number_FIRST 172 | 173 | 174 | all_seqs = generate_sequences(l2) 175 | all_seqs_2 = remove_duplicates(all_seqs) 176 | all_seqs_3 = remove_consecutives(all_seqs_2) 177 | all_seqs_4 = remove_whitespace_from_ends(all_seqs_3) 178 | 179 | all_seqs_5 = remove_incompatible_consecutives(all_seqs_4, ['title_upper', 'title_lower', 'roman_upper', 'roman_lower']) 180 | blacklist_sequences = list() 181 | blacklist_sequences.append(['desc', 'title_upper']) 182 | blacklist_sequences.append(['punctuation', 'title_lower']) 183 | blacklist_sequences.append(['punctuation']) 184 | blacklist_sequences.append(['punctuation', 'whitespace', 'punctuation']) 185 | blacklist_sequences.append(['punctuation', 'whitespace', 'title_upper']) 186 | blacklist_sequences.append(['punctuation', 'whitespace', 'title_lower']) 187 | blacklist_sequences.append(['punctuation', 'whitespace', 'punctuation', 'whitespace', 'title_lower']) 188 | blacklist_sequences.append(['title_lower']) 189 | blacklist_sequences.append(['roman_lower', 'whitespace', 'title_lower']) 190 | 191 | all_seqs_5 = [x for x in all_seqs_5 if x not in blacklist_sequences] 192 | 193 | all_seqs_5 = [x for x in all_seqs_5 if not issublist(['punctuation', 'whitespace', 'punctuation'], x)] 194 | 195 | number_list = ['roman_upper', 'roman_lower', 'numeral', 'word_number_one', 'word_number_One', 'word_number_ONE', 'word_number_first', 'word_number_First', 'word_number_FIRST'] 196 | all_seqs_5 = [x for x in all_seqs_5 if 'desc' not in x or ('desc' in x and any([m in x for m in number_list]))] 197 | 198 | all_seqs_5 = [x for x in all_seqs_5 if x[0] != 'word_number_one' and x[0] != 'word_number_first'] 199 | 200 | all_seqs_5 = [x for x in all_seqs_5 if x[0] != 'punctuation' or (x[0] == 'punctuation' and 'punctuation' in x[1:])] 201 | 202 | 203 | all_seqs_5 = [x for x in all_seqs_5 if len(x) < 2 or not(x[-1] == 'title_lower' and x[-2] == 'whitespace')] 204 | 205 | 206 | tmp_all_seqs = list() 207 | first_list = ['word_number_First', 'word_number_FIRST'] 208 | for seq in all_seqs_5: 209 | if 'desc' in seq: 210 | if 'word_number_First' in seq: 211 | word = 'word_number_First' 212 | elif 'word_number_FIRST' in seq: 213 | word = 'word_number_FIRST' 214 | else: 215 | tmp_all_seqs.append(seq) 216 | continue 217 | tmp_seq = list() 218 | for r in seq: 219 | if r == 'desc': 220 | tmp_seq.append(word) 221 | elif r in first_list: 222 | tmp_seq.append('desc') 223 | else: 224 | tmp_seq.append(r) 225 | tmp_all_seqs.append(tmp_seq) 226 | tmp_all_seqs.append(seq) 227 | all_seqs_5 = tmp_all_seqs 228 | 229 | words = ['desc', 'title_upper', 'title_lower', 'word_number_first', 'word_number_First', 'word_number_FIRST', 'word_number_one', 'word_number_One', 'word_number_ONE', 'roman_upper', 'roman_lower'] 230 | 231 | seqs_new = list() 232 | for s in all_seqs_5: 233 | b = ''.join(['1' if x in words else '0' for x in s]) 234 | if '11' not in b: 235 | seqs_new.append(s) 236 | all_seqs_5 = seqs_new 237 | 238 | priority = ['desc', 'roman_upper', 'roman_lower', 'numeral', 'word_number_first', 'word_number_First', 'word_number_FIRST', 'word_number_one', 'word_number_One', 'word_number_ONE', 'whitespace', 'title_upper', 'title_lower', 'punctuation', 
''] 239 | 240 | # Make all rule sequences of equal length by appending empty strings 241 | m = max([len(x) for x in all_seqs_5]) 242 | all_seqs_5_new = list() 243 | for elem in all_seqs_5: 244 | all_seqs_5_new.append(elem + [''] * (m - len(elem))) 245 | all_seqs_5 = all_seqs_5_new 246 | # Sort the rules found using pre-defined priority 247 | all_seqs_5.sort(key=lambda x:[priority.index(y) for y in x]) 248 | # Remove the empty strings we appended earlier 249 | all_seqs_no_empty = [[elem for elem in x if elem] for x in all_seqs_5] 250 | 251 | rule_texts = [''.join(['(' + string_to_pattern[x] + ')' for x in y]) for y in all_seqs_no_empty] 252 | 253 | rules = [re.compile(x) for x in rule_texts] 254 | 255 | 256 | return all_seqs_5, all_seqs_no_empty, rules, priority 257 | 258 | 259 | def get_best_matching_rule(text, text_rules, regex_rules, priority): 260 | answers = list() 261 | for idx in range(len(regex_rules)): 262 | r = regex_rules[idx] 263 | m = r.match(text) 264 | if m: 265 | if m.span()[0] == 0 and len(text[m.span()[1]:].strip()) == 0: 266 | answers.append(text_rules[idx]) 267 | answers.sort(key=lambda x:[priority.index(y) for y in x]) 268 | 269 | return answers 270 | 271 | 272 | def get_best_rules_for_all(texts, sequences_as_lists, rules, priority): 273 | rules_found = list() 274 | for idx, text in enumerate(texts): 275 | ans = get_best_matching_rule(text, sequences_as_lists, rules, priority) 276 | if len(ans) > 0: 277 | rules_found.append(ans[0]) 278 | else: 279 | rules_found.append(None) 280 | if len(rules_found) == 0: 281 | return None 282 | return rules_found 283 | 284 | 285 | 286 | def find_highest_priority_rule(rules_found): 287 | c = Counter(rules_found) 288 | for rule in rules_found: 289 | if c[rule] > 1: 290 | return rule 291 | return rules_found[0] 292 | 293 | 294 | def get_matching_rule_beginning(text, text_rules, regex_rules, priority): 295 | answers = list() 296 | for idx in range(len(regex_rules)): 297 | r = regex_rules[idx] 298 | m = r.match(text) 299 | if m: 300 | if m.span()[0] == 0: 301 | answers.append(text_rules[idx]) 302 | answers.sort(key=lambda x:[priority.index(y) for y in x]) 303 | return answers 304 | -------------------------------------------------------------------------------- /segmentation/bert_full_window/01_generate_training_sequences.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import lxml.etree as ET 3 | import random 4 | 5 | from transformers import BertTokenizer 6 | 7 | def get_examples_book(pred_dir, tokenizer, book_id): 8 | try: 9 | filename = pred_dir + str(book_id) + '.xml' 10 | parser = ET.XMLParser(huge_tree=True) 11 | tree = ET.parse(filename, parser=parser) 12 | book = tree.getroot() 13 | b = book.find('.//body') 14 | 15 | headers = b.findall('.//header') 16 | 17 | 18 | start_para_nums = list() 19 | for h in headers: 20 | t = h.getnext() 21 | if t.tag == 'p': 22 | start_para_nums.append(int(t.attrib['num'])) 23 | 24 | start_para_nums.append(int(b.findall('.//p')[-1].attrib['num']) + 1) 25 | examples = list() 26 | prev_section_start = -1 27 | for idx, p_num in enumerate(start_para_nums[:-1]): 28 | 29 | # Positive example 30 | prev_tokens = list() 31 | prev_idx = p_num - 1 32 | while prev_idx >= prev_section_start and prev_idx >= 0 and len(prev_tokens) < 254: 33 | prev_elem = b.find('.//p[@num=\'' + str(prev_idx) + '\']') 34 | prev_text = ' '.join([x.text for x in prev_elem.findall('.//s')]) 35 | tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(str(prev_text))) 36 | 
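# Each side of the candidate break point is capped at 254 WordPiece ids; with [CLS]
# and two [SEP] tokens added later, that is 254 + 254 + 3 = 511 ids, which fits the
# 512-token BERT input that the training script pads to.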
prev_tokens = tokens + prev_tokens 37 | prev_tokens = prev_tokens[-254:] 38 | prev_idx -= 1 39 | 40 | next_section_start = start_para_nums[idx + 1] 41 | next_tokens = list() 42 | next_idx = p_num 43 | while next_idx < next_section_start and len(next_tokens) < 254: 44 | next_elem = b.find('.//p[@num=\'' + str(next_idx) + '\']') 45 | next_text = ' '.join([x.text for x in next_elem.findall('.//s')]) 46 | tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(str(next_text))) 47 | next_tokens = next_tokens + tokens 48 | next_tokens = next_tokens[:254] 49 | next_idx += 1 50 | 51 | examples.append([prev_tokens, next_tokens, 1]) 52 | 53 | 54 | # Previous chapter 55 | if prev_section_start != -1: 56 | idx_use = random.randint(prev_section_start, p_num - 2) 57 | 58 | prev_tokens = list() 59 | prev_idx = idx_use - 1 60 | while prev_idx >= prev_section_start and prev_idx >= 0 and len(prev_tokens) < 254: 61 | prev_elem = b.find('.//p[@num=\'' + str(prev_idx) + '\']') 62 | prev_text = ' '.join([x.text for x in prev_elem.findall('.//s')]) 63 | tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(str(prev_text))) 64 | prev_tokens = tokens + prev_tokens 65 | prev_tokens = prev_tokens[-254:] 66 | prev_idx -= 1 67 | 68 | next_section_start = start_para_nums[idx + 1] 69 | next_tokens = list() 70 | next_idx = idx_use 71 | while next_idx < p_num and len(next_tokens) < 254: 72 | next_elem = b.find('.//p[@num=\'' + str(next_idx) + '\']') 73 | next_text = ' '.join([x.text for x in next_elem.findall('.//s')]) 74 | tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(str(next_text))) 75 | next_tokens = next_tokens + tokens 76 | next_tokens = next_tokens[:254] 77 | next_idx += 1 78 | 79 | examples.append([prev_tokens, next_tokens, 0]) 80 | prev_section_start = p_num 81 | 82 | # Next chapter 83 | idx_use = random.randint(p_num + 1, start_para_nums[idx + 1] - 2) 84 | 85 | prev_tokens = list() 86 | prev_idx = idx_use - 1 87 | while prev_idx >= p_num and prev_idx >= 0 and len(prev_tokens) < 254: 88 | prev_elem = b.find('.//p[@num=\'' + str(prev_idx) + '\']') 89 | prev_text = ' '.join([x.text for x in prev_elem.findall('.//s')]) 90 | tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(str(prev_text))) 91 | prev_tokens = tokens + prev_tokens 92 | prev_tokens = prev_tokens[-254:] 93 | prev_idx -= 1 94 | 95 | next_section_start = start_para_nums[idx + 1] 96 | next_tokens = list() 97 | next_idx = idx_use 98 | while next_idx < start_para_nums[idx + 1] and len(next_tokens) < 254: 99 | next_elem = b.find('.//p[@num=\'' + str(next_idx) + '\']') 100 | next_text = ' '.join([x.text for x in next_elem.findall('.//s')]) 101 | tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(str(next_text))) 102 | next_tokens = next_tokens + tokens 103 | next_tokens = next_tokens[:254] 104 | next_idx += 1 105 | 106 | examples.append([prev_tokens, next_tokens, 0]) 107 | return examples 108 | except: 109 | return [] 110 | 111 | 112 | 113 | if __name__ == '__main__': 114 | # Use appropriate locations 115 | train_books_list_file = 'train_books.txt' 116 | pred_dir = 'use_books_sentencized/' 117 | save_loc = 'train_sequences_tokenized.csv' 118 | 119 | with open(train_books_list_file, 'r') as f: 120 | train_book_ids = [x.strip() for x in f.readlines()] 121 | 122 | 123 | print(len(train_book_ids), 'books') 124 | 125 | tokenizer=BertTokenizer.from_pretrained('bert-base-uncased') 126 | 127 | training_data = list() 128 | 129 | for idx, book_id in enumerate(train_book_ids): 130 | training_data += get_examples_book(pred_dir, 
tokenizer, book_id) 131 | if idx % 100 == 0: 132 | print(idx, len(training_data)) 133 | 134 | df = pd.DataFrame(training_data) 135 | 136 | df.rename(columns={0:'para1_tokens', 1:'para2_tokens', 2:'label'}, inplace=True) 137 | 138 | df['para1_len'] = df['para1_tokens'].apply(lambda x: len(x)) 139 | df['para2_len'] = df['para2_tokens'].apply(lambda x: len(x)) 140 | 141 | df.to_csv(save_loc, index=False) 142 | 143 | 144 | -------------------------------------------------------------------------------- /segmentation/bert_full_window/02_train_BERT_model.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | from transformers import BertTokenizer, BertForNextSentencePrediction 4 | from sklearn.model_selection import train_test_split 5 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler 6 | 7 | from pytorch_pretrained_bert import BertAdam 8 | import numpy as np 9 | import os 10 | from keras.preprocessing.sequence import pad_sequences 11 | 12 | import time 13 | import datetime 14 | 15 | import pickle 16 | 17 | from ast import literal_eval 18 | 19 | def format_time(elapsed): 20 | ''' 21 | Takes a time in seconds and returns a string hh:mm:ss 22 | ''' 23 | # Round to the nearest second. 24 | elapsed_rounded = int(round((elapsed))) 25 | 26 | # Format as hh:mm:ss 27 | return str(datetime.timedelta(seconds=elapsed_rounded)) 28 | 29 | def get_tokens(tokenizer, toks1, toks2, cls, sep): 30 | toks1 = [cls] + toks1 + [sep] 31 | toks2 = toks2 + [sep] 32 | 33 | indexed_tokens = toks1 + toks2 34 | segments_ids = [0] * len(toks1) + [1] * len(toks2) 35 | 36 | return indexed_tokens, segments_ids 37 | 38 | def flat_accuracy(preds, labels): 39 | pred_flat = np.argmax(preds, axis=1).flatten() 40 | labels_flat = labels.flatten() 41 | return np.sum(pred_flat == labels_flat) / len(labels_flat) 42 | 43 | 44 | if __name__ == '__main__': 45 | 46 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 47 | 48 | print(device) 49 | 50 | # Use appropriate locations 51 | training_data_loc = 'train_sequences_tokenized.csv' 52 | output_loc = 'trained_model/' 53 | 54 | print('Reading training data file...') 55 | 56 | df = pd.read_csv(training_data_loc, usecols=['para1_tokens', 'para2_tokens', 'para1_len', 'para2_len', 'label']) 57 | 58 | df = df[(df['para1_len'] > 0) & (df['para2_len'] > 0)] 59 | 60 | df['para1_tokens'] = df['para1_tokens'].apply(literal_eval) 61 | df['para2_tokens'] = df['para2_tokens'].apply(literal_eval) 62 | 63 | 64 | print('Loading tokenizer and BertNSP...') 65 | tokenizer=BertTokenizer.from_pretrained('bert-base-uncased') 66 | BertNSP=BertForNextSentencePrediction.from_pretrained('bert-base-uncased') 67 | 68 | 69 | 70 | cls, sep = tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]"]) 71 | 72 | # Use appropriate locations 73 | try: 74 | with open(os.path.join(output_loc, 'input_tokens.pkl'), 'rb') as f: 75 | input_tokens = pickle.load(f) 76 | with open(os.path.join(output_loc, 'input_seg_ids.pkl'), 'rb') as f: 77 | input_seg_ids = pickle.load(f) 78 | with open(os.path.join(output_loc, 'labels.pkl'), 'rb') as f: 79 | labels = pickle.load(f) 80 | 81 | except: 82 | print('Generating training input...') 83 | input_tokens = list() 84 | input_seg_ids = list() 85 | labels = list() 86 | for idx, row in df.iterrows(): 87 | if idx % 10000 == 0: 88 | print(idx) 89 | indexed_tokens, segments_ids = get_tokens(tokenizer, row['para1_tokens'], row['para2_tokens'], cls, sep) 90 | 
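# get_tokens lays each pair out in BERT next-sentence-prediction style:
#   [CLS] <tokens before the break> [SEP] <tokens after the break> [SEP]
# with segment ids 0 for the first span and 1 for the second; label 1 marks a true
# chapter boundary between the two spans, label 0 a within-chapter transition.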
input_tokens.append(indexed_tokens) 91 | input_seg_ids.append(segments_ids) 92 | labels.append(row['label']) 93 | # Use appropriate locations 94 | with open(os.path.join(output_loc, 'input_tokens.pkl'), 'wb') as f: 95 | pickle.dump(input_tokens, f) 96 | with open(os.path.join(output_loc, 'input_seg_ids.pkl'), 'wb') as f: 97 | pickle.dump(input_seg_ids, f) 98 | with open(os.path.join(output_loc, 'labels.pkl'), 'wb') as f: 99 | pickle.dump(labels, f) 100 | 101 | input_ids = pad_sequences(input_tokens, maxlen=512, dtype="long", value=0, truncating="pre", padding="post") 102 | seg_ids = pad_sequences(input_seg_ids, maxlen=512, dtype="long", value=1, truncating="pre", padding="post") 103 | attention_masks = [[int(token_id > 0) for token_id in sent] for sent in input_ids] 104 | 105 | 106 | train_input_ids, validation_input_ids, train_seg_ids, validation_seg_ids, train_attention_masks, validation_attention_masks, train_labels, validation_labels = train_test_split(input_ids, seg_ids, attention_masks, labels, random_state=2019, test_size=0.1) 107 | 108 | 109 | train_input_ids = torch.tensor(train_input_ids) 110 | validation_input_ids = torch.tensor(validation_input_ids) 111 | train_seg_ids = torch.tensor(train_seg_ids) 112 | validation_seg_ids = torch.tensor(validation_seg_ids) 113 | train_attention_masks = torch.tensor(train_attention_masks) 114 | validation_attention_masks = torch.tensor(validation_attention_masks) 115 | train_labels = torch.tensor(train_labels) 116 | validation_labels = torch.tensor(validation_labels) 117 | 118 | 119 | batch_size = 32 120 | 121 | 122 | train_data = TensorDataset(train_input_ids, train_seg_ids, train_attention_masks, train_labels) 123 | train_sampler = RandomSampler(train_data) 124 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size) 125 | 126 | 127 | validation_data = TensorDataset(validation_input_ids, validation_seg_ids, validation_attention_masks, validation_labels) 128 | validation_sampler = RandomSampler(validation_data) 129 | validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=batch_size) 130 | 131 | 132 | print("Initializing GPU...") 133 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 134 | n_gpu = torch.cuda.device_count() 135 | torch.cuda.get_device_name(0) 136 | 137 | 138 | model = BertNSP 139 | 140 | 141 | # Parameters 142 | param_optimizer = list(model.named_parameters()) 143 | no_decay = ['bias', 'gamma', 'beta'] 144 | optimizer_grouped_parameters = [ 145 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 146 | 'weight_decay_rate': 0.01}, 147 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 148 | 'weight_decay_rate': 0.0} 149 | ] 150 | 151 | # Optimizer 152 | optimizer = BertAdam(optimizer_grouped_parameters, lr=2e-5, warmup=.1) 153 | 154 | print("Making model GPU compatible") 155 | model = model.to(device) 156 | 157 | 158 | epochs = 4 159 | 160 | print('Starting training...') 161 | 162 | loss_values = [] 163 | for epoch_i in range(0, epochs): 164 | print('Epoch ', epoch_i) 165 | 166 | model.train() 167 | 168 | t0 = time.time() 169 | 170 | total_loss = 0 171 | 172 | for step, batch in enumerate(train_dataloader): 173 | 174 | if step % 40 == 0 and not step == 0: 175 | # Calculate elapsed time in minutes. 176 | elapsed = format_time(time.time() - t0) 177 | print(' Batch {:>5,} of {:>5,}. 
Elapsed: {:}.'.format(step, len(train_dataloader), elapsed)) 178 | 179 | 180 | batch = tuple(t.to(device) for t in batch) 181 | b_input_ids, b_seg_ids, b_attention_masks, b_labels = batch 182 | 183 | optimizer.zero_grad() 184 | 185 | outputs = model(b_input_ids, token_type_ids=b_seg_ids, attention_mask=b_attention_masks, next_sentence_label=b_labels) 186 | 187 | loss = outputs[0] 188 | 189 | total_loss += loss.item() 190 | 191 | loss.backward() 192 | 193 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 194 | 195 | optimizer.step() 196 | 197 | avg_train_loss = total_loss / len(train_dataloader) 198 | loss_values.append(avg_train_loss) 199 | print(" Average training loss: {0:.2f}".format(avg_train_loss)) 200 | print(" Training epoch took: {:}".format(format_time(time.time() - t0))) 201 | 202 | model.eval() 203 | t0 = time.time() 204 | eval_loss, eval_accuracy = 0, 0 205 | nb_eval_steps, nb_eval_examples = 0, 0 206 | 207 | for batch in validation_dataloader: 208 | batch = tuple(t.to(device) for t in batch) 209 | b_input_ids, b_seg_ids, b_attention_masks, b_labels = batch 210 | with torch.no_grad(): 211 | outputs = model(b_input_ids, token_type_ids=b_seg_ids, attention_mask=b_attention_masks) 212 | 213 | logits = outputs[0] 214 | 215 | logits = logits.detach().cpu().numpy() 216 | label_ids = b_labels.to('cpu').numpy() 217 | 218 | tmp_eval_accuracy = flat_accuracy(logits, label_ids) 219 | eval_accuracy += tmp_eval_accuracy 220 | nb_eval_steps += 1 221 | 222 | print(" Accuracy: {0:.2f}".format(eval_accuracy/nb_eval_steps)) 223 | print(" Validation took: {:}".format(format_time(time.time() - t0))) 224 | 225 | output_dir = output_loc + 'model_' + str(epoch_i) 226 | if not os.path.exists(output_dir): 227 | os.makedirs(output_dir) 228 | 229 | model_to_save = model.module if hasattr(model, 'module') else model 230 | model_to_save.save_pretrained(output_dir) 231 | tokenizer.save_pretrained(output_dir) 232 | 233 | print('Saved model to ' + output_dir) 234 | -------------------------------------------------------------------------------- /segmentation/bert_full_window/03_generate_BERT_probabilities.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import pandas as pd 4 | 5 | import torch 6 | from transformers import BertTokenizer, BertForNextSentencePrediction 7 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler 8 | from keras.preprocessing.sequence import pad_sequences 9 | 10 | from transformers import BertTokenizer 11 | 12 | import sys 13 | 14 | def process_book(bert_tok_dir, pred_scores_dir, BertNSP, device, cls, sep, book_id): 15 | with open(os.path.join(bert_tok_dir, book_id + '.pkl'), 'rb') as f: 16 | d = pickle.load(f) 17 | 18 | m = max(d.keys()) 19 | 20 | scores = dict() 21 | for idx in range(0, m - 1): 22 | toks1 = d[idx] 23 | toks2 = d[idx + 1] 24 | 25 | i = idx - 1 26 | while len(toks1) < 254 and i >= 0: 27 | toks1 = d[i] + toks1 28 | i -= 1 29 | toks1 = toks1[-254:] 30 | 31 | i = idx + 2 32 | while len(toks2) < 254 and i <= m: 33 | toks2 = toks2 + d[i] 34 | i += 1 35 | toks2 = toks2[:254] 36 | 37 | 38 | ids1 = [cls] + toks1 + [sep] 39 | ids2 = toks2 + [sep] 40 | 41 | indexed_tokens = ids1 + ids2 42 | segments_ids = [0] * len(ids1) + [1] * len(ids2) 43 | 44 | indexed_tokens = pad_sequences([indexed_tokens], maxlen=512, dtype='long', value=0, truncating="pre", padding="post") 45 | segments_ids = pad_sequences([segments_ids], maxlen=512, dtype="long", value=1, 
truncating="pre", padding="post") 46 | attention_masks = [[int(token_id > 0) for token_id in sent] for sent in indexed_tokens] 47 | 48 | tokens_tensor = torch.tensor(indexed_tokens) 49 | segments_tensors = torch.tensor(segments_ids) 50 | attention_tensor = torch.tensor(attention_masks) 51 | 52 | tokens_tensor = tokens_tensor.to(device) 53 | segments_tensors = segments_tensors.to(device) 54 | attention_tensor = attention_tensor.to(device) 55 | 56 | BertNSP.eval() 57 | prediction = BertNSP(tokens_tensor, token_type_ids=segments_tensors, attention_mask=attention_tensor) 58 | prediction = prediction[0] # tuple to tensor 59 | softmax = torch.nn.Softmax(dim=1) 60 | prediction_sm = softmax(prediction) 61 | 62 | scores[idx] = prediction_sm[0][1].item() 63 | 64 | with open(os.path.join(pred_scores_dir, book_id + '.pkl'), 'wb') as f: 65 | pickle.dump(scores, f) 66 | 67 | return 68 | 69 | 70 | if __name__ == '__main__': 71 | # Use appropriate locations 72 | test_books_list_file = 'test_books.txt' 73 | 74 | bert_tok_dir = 'test_books_bert_tok/' 75 | pred_scores_dir = 'test_preds/' 76 | model_dir = 'model_3/' 77 | 78 | with open(test_books_list_file, 'r') as f: 79 | test_book_ids = [x.strip() for x in f.readlines()] 80 | 81 | partition = int(sys.argv[1]) 82 | from_idx = partition * 1000 83 | to_idx = (partition + 1) * 1000 84 | 85 | test_book_ids = test_book_ids[from_idx:to_idx] 86 | print(len(test_book_ids), 'books') 87 | 88 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 89 | 90 | cls, sep = tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]"]) 91 | 92 | 93 | device = torch.device("cuda:" + str(partition + 1) if torch.cuda.is_available() else "cpu") 94 | print(device) 95 | 96 | print(torch.cuda.device_count(), "GPUs") 97 | 98 | model = BertForNextSentencePrediction.from_pretrained(model_dir) 99 | model = model.to(device) 100 | 101 | 102 | for book_id in test_book_ids: 103 | print(book_id) 104 | try: 105 | process_book(bert_tok_dir, pred_scores_dir, model, device, cls, sep, book_id) 106 | except Exception as e: 107 | print(book_id, e) 108 | 109 | print('Done!') 110 | -------------------------------------------------------------------------------- /segmentation/bert_single_para/01_generate_training_data.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import lxml.etree as ET 3 | import random 4 | 5 | def get_examples_book(pred_dir, book_id): 6 | try: 7 | filename = pred_dir + str(book_id) + '.xml' 8 | parser = ET.XMLParser(huge_tree=True) 9 | tree = ET.parse(filename, parser=parser) 10 | book = tree.getroot() 11 | b = book.find('.//body') 12 | 13 | headers = b.findall('.//header') 14 | 15 | 16 | start_para_nums = list() 17 | for h in headers: 18 | t = h.getnext() 19 | if t.tag == 'p': 20 | start_para_nums.append(int(t.attrib['num'])) 21 | 22 | start_para_nums.append(int(b.findall('.//p')[-1].attrib['num']) + 1) 23 | examples = list() 24 | prev_start = None 25 | for idx, p_num in enumerate(start_para_nums[:-1]): 26 | # This and previous 27 | prev_elem = b.find('.//p[@num=\'' + str(p_num - 1) + '\']') 28 | curr_elem = b.find('.//p[@num=\'' + str(p_num) + '\']') 29 | if prev_elem is None or curr_elem is None: 30 | continue 31 | prev_text = ' '.join([x.text for x in prev_elem.findall('.//s')]) 32 | curr_text = ' '.join([x.text for x in curr_elem.findall('.//s')]) 33 | examples.append([prev_text, curr_text, 1]) 34 | 35 | if prev_start is not None: 36 | # Find a random number >= prev_start and < p_num - 1 37 | idx_use = 
random.randint(prev_start, p_num - 2) 38 | prev_elem = b.find('.//p[@num=\'' + str(idx_use) + '\']') 39 | curr_elem = b.find('.//p[@num=\'' + str(idx_use + 1) + '\']') 40 | prev_text = ' '.join([x.text for x in prev_elem.findall('.//s')]) 41 | curr_text = ' '.join([x.text for x in curr_elem.findall('.//s')]) 42 | examples.append([prev_text, curr_text, 0]) 43 | prev_start = p_num 44 | 45 | next_start = start_para_nums[idx + 1] 46 | # Find a random number >= p_num and < start_para_nums[idx + 1] - 1 47 | idx_use = random.randint(p_num, start_para_nums[idx + 1] - 2) 48 | prev_elem = b.find('.//p[@num=\'' + str(idx_use) + '\']') 49 | curr_elem = b.find('.//p[@num=\'' + str(idx_use + 1) + '\']') 50 | prev_text = ' '.join([x.text for x in prev_elem.findall('.//s')]) 51 | curr_text = ' '.join([x.text for x in curr_elem.findall('.//s')]) 52 | examples.append([prev_text, curr_text, 0]) 53 | return examples 54 | except: 55 | return [] 56 | 57 | 58 | 59 | if __name__ == '__main__': 60 | # Use appropriate locations 61 | train_books_list_file = 'train_books.txt' 62 | pred_dir = 'use_books_sentencized/' 63 | save_loc = 'train_sequences.csv' 64 | 65 | with open(train_books_list_file, 'r') as f: 66 | train_book_ids = [x.strip() for x in f.readlines()] 67 | 68 | 69 | print(len(train_book_ids), 'books') 70 | 71 | training_data = list() 72 | 73 | for idx, book_id in enumerate(train_book_ids): 74 | training_data += get_examples_book(pred_dir, book_id) 75 | if idx % 100 == 0: 76 | print(idx, len(training_data)) 77 | 78 | df = pd.DataFrame(training_data) 79 | 80 | df.rename(columns={0:'para1', 1:'para2', 2:'label'}, inplace=True) 81 | 82 | df.to_csv(save_loc, index=False) 83 | 84 | 85 | -------------------------------------------------------------------------------- /segmentation/bert_single_para/02_tokenize_sequences.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | from transformers import BertTokenizer 4 | import numpy as np 5 | import os 6 | 7 | import pickle 8 | 9 | if __name__ == '__main__': 10 | # Use appropriate locations 11 | training_data_loc = 'train_sequences.csv' 12 | output_loc = 'train_sequences_tokenized.csv' 13 | 14 | print('Reading training data file...') 15 | 16 | df = pd.read_csv(training_data_loc) 17 | 18 | print('Loading tokenizer...') 19 | tokenizer=BertTokenizer.from_pretrained('bert-base-uncased') 20 | 21 | print('Tokenizing para1...') 22 | df['para1_tokens'] = df['para1'].apply(lambda x: tokenizer.convert_tokens_to_ids(tokenizer.tokenize(str(x)))) 23 | df['para1_len'] = df['para1_tokens'].apply(lambda x: len(x)) 24 | 25 | print('Tokenizing para2...') 26 | df['para2_tokens'] = df['para2'].apply(lambda x: tokenizer.convert_tokens_to_ids(tokenizer.tokenize(str(x)))) 27 | df['para2_len'] = df['para2_tokens'].apply(lambda x: len(x)) 28 | 29 | 30 | df.to_csv(output_loc) 31 | -------------------------------------------------------------------------------- /segmentation/bert_single_para/03_train_BERT_model.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | from transformers import BertTokenizer, BertForNextSentencePrediction 4 | from sklearn.model_selection import train_test_split 5 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler 6 | 7 | from pytorch_pretrained_bert import BertAdam 8 | import numpy as np 9 | import os 10 | from keras.preprocessing.sequence import pad_sequences 11 | 12 | import 
time 13 | import datetime 14 | 15 | import pickle 16 | 17 | from ast import literal_eval 18 | 19 | def format_time(elapsed): 20 | ''' 21 | Takes a time in seconds and returns a string hh:mm:ss 22 | ''' 23 | # Round to the nearest second. 24 | elapsed_rounded = int(round((elapsed))) 25 | 26 | # Format as hh:mm:ss 27 | return str(datetime.timedelta(seconds=elapsed_rounded)) 28 | 29 | def get_tokens(tokenizer, toks1, toks2, cls, sep): 30 | l1 = len(toks1) 31 | l2 = len(toks2) 32 | if l1 + l2 >= 297: 33 | if l1 > 148 and l2 > 148: 34 | toks1 = toks1[-148:] 35 | toks2 = toks2[:148] 36 | elif l1 > 148: 37 | rem_len = 297 - l2 38 | toks1 = toks1[-rem_len:] 39 | elif l2 > 148: 40 | rem_len = 297 - l1 41 | toks2 = toks2[:rem_len] 42 | 43 | toks1 = [cls] + toks1 + [sep] 44 | toks2 = toks2 + [sep] 45 | 46 | indexed_tokens = toks1 + toks2 47 | segments_ids = [0] * len(toks1) + [1] * len(toks2) 48 | 49 | return indexed_tokens, segments_ids 50 | 51 | def flat_accuracy(preds, labels): 52 | pred_flat = np.argmax(preds, axis=1).flatten() 53 | labels_flat = labels.flatten() 54 | return np.sum(pred_flat == labels_flat) / len(labels_flat) 55 | 56 | 57 | if __name__ == '__main__': 58 | 59 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 60 | 61 | print(device) 62 | 63 | # Use appropriate locations 64 | training_data_loc = 'train_sequences_tokenized.csv' 65 | output_loc = 'trained_model/' 66 | 67 | print('Reading training data file...') 68 | 69 | df = pd.read_csv(training_data_loc, usecols=['para1_tokens', 'para2_tokens', 'label']) 70 | 71 | df['para1_tokens'] = df['para1_tokens'].apply(literal_eval) 72 | df['para2_tokens'] = df['para2_tokens'].apply(literal_eval) 73 | 74 | print('Loading tokenizer and BertNSP...') 75 | tokenizer=BertTokenizer.from_pretrained('bert-base-uncased') 76 | BertNSP=BertForNextSentencePrediction.from_pretrained('bert-base-uncased') 77 | 78 | 79 | 80 | cls, sep = tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]"]) 81 | 82 | try: 83 | with open(os.path.join(output_loc, 'input_tokens.pkl'), 'rb') as f: 84 | input_tokens = pickle.load(f) 85 | with open(os.path.join(output_loc, 'input_seg_ids.pkl'), 'rb') as f: 86 | input_seg_ids = pickle.load(f) 87 | with open(os.path.join(output_loc, 'labels.pkl'), 'rb') as f: 88 | labels = pickle.load(f) 89 | 90 | except: 91 | print('Generating training input...') 92 | input_tokens = list() 93 | input_seg_ids = list() 94 | labels = list() 95 | for idx, row in df.iterrows(): 96 | if idx % 10000 == 0: 97 | print(idx) 98 | indexed_tokens, segments_ids = get_tokens(tokenizer, row['para1_tokens'], row['para2_tokens'], cls, sep) 99 | input_tokens.append(indexed_tokens) 100 | input_seg_ids.append(segments_ids) 101 | labels.append(row['label']) 102 | 103 | with open(os.path.join(output_loc, 'input_tokens.pkl'), 'wb') as f: 104 | pickle.dump(input_tokens, f) 105 | with open(os.path.join(output_loc, 'input_seg_ids.pkl'), 'wb') as f: 106 | pickle.dump(input_seg_ids, f) 107 | with open(os.path.join(output_loc, 'labels.pkl'), 'wb') as f: 108 | pickle.dump(labels, f) 109 | 110 | input_ids = pad_sequences(input_tokens, maxlen=300, dtype="long", value=0, truncating="pre", padding="post") 111 | seg_ids = pad_sequences(input_seg_ids, maxlen=300, dtype="long", value=1, truncating="pre", padding="post") 112 | attention_masks = [[int(token_id > 0) for token_id in sent] for sent in input_ids] 113 | 114 | 115 | train_input_ids, validation_input_ids, train_seg_ids, validation_seg_ids, train_attention_masks, validation_attention_masks, train_labels, 
validation_labels = train_test_split(input_ids, seg_ids, attention_masks, labels, random_state=2019, test_size=0.1) 116 | 117 | 118 | train_input_ids = torch.tensor(train_input_ids) 119 | validation_input_ids = torch.tensor(validation_input_ids) 120 | train_seg_ids = torch.tensor(train_seg_ids) 121 | validation_seg_ids = torch.tensor(validation_seg_ids) 122 | train_attention_masks = torch.tensor(train_attention_masks) 123 | validation_attention_masks = torch.tensor(validation_attention_masks) 124 | train_labels = torch.tensor(train_labels) 125 | validation_labels = torch.tensor(validation_labels) 126 | 127 | 128 | batch_size = 32 129 | 130 | 131 | train_data = TensorDataset(train_input_ids, train_seg_ids, train_attention_masks, train_labels) 132 | train_sampler = RandomSampler(train_data) 133 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size) 134 | 135 | 136 | validation_data = TensorDataset(validation_input_ids, validation_seg_ids, validation_attention_masks, validation_labels) 137 | validation_sampler = RandomSampler(validation_data) 138 | validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=batch_size) 139 | 140 | 141 | print("Initializing GPU...") 142 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 143 | n_gpu = torch.cuda.device_count() 144 | torch.cuda.get_device_name(0) 145 | 146 | 147 | model = BertNSP 148 | 149 | 150 | # Parameters 151 | param_optimizer = list(model.named_parameters()) 152 | no_decay = ['bias', 'gamma', 'beta'] 153 | optimizer_grouped_parameters = [ 154 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 155 | 'weight_decay_rate': 0.01}, 156 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 157 | 'weight_decay_rate': 0.0} 158 | ] 159 | 160 | # Optimizer 161 | optimizer = BertAdam(optimizer_grouped_parameters, lr=2e-5, warmup=.1) 162 | 163 | print("Making model GPU compatible") 164 | model = model.to(device) 165 | 166 | 167 | epochs = 4 168 | 169 | print('Starting training...') 170 | 171 | loss_values = [] 172 | for epoch_i in range(0, epochs): 173 | print('Epoch ', epoch_i) 174 | 175 | model.train() 176 | 177 | t0 = time.time() 178 | 179 | total_loss = 0 180 | 181 | for step, batch in enumerate(train_dataloader): 182 | 183 | if step % 40 == 0 and not step == 0: 184 | # Calculate elapsed time in minutes. 185 | elapsed = format_time(time.time() - t0) 186 | print(' Batch {:>5,} of {:>5,}. 
Elapsed: {:}.'.format(step, len(train_dataloader), elapsed)) 187 | 188 | 189 | batch = tuple(t.to(device) for t in batch) 190 | b_input_ids, b_seg_ids, b_attention_masks, b_labels = batch 191 | 192 | optimizer.zero_grad() 193 | 194 | outputs = model(b_input_ids, token_type_ids=b_seg_ids, attention_mask=b_attention_masks, next_sentence_label=b_labels) 195 | 196 | loss = outputs[0] 197 | 198 | total_loss += loss.item() 199 | 200 | loss.backward() 201 | 202 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 203 | 204 | optimizer.step() 205 | 206 | avg_train_loss = total_loss / len(train_dataloader) 207 | loss_values.append(avg_train_loss) 208 | print(" Average training loss: {0:.2f}".format(avg_train_loss)) 209 | print(" Training epoch took: {:}".format(format_time(time.time() - t0))) 210 | 211 | model.eval() 212 | t0 = time.time() 213 | eval_loss, eval_accuracy = 0, 0 214 | nb_eval_steps, nb_eval_examples = 0, 0 215 | 216 | for batch in validation_dataloader: 217 | batch = tuple(t.to(device) for t in batch) 218 | b_input_ids, b_seg_ids, b_attention_masks, b_labels = batch 219 | with torch.no_grad(): 220 | outputs = model(b_input_ids, token_type_ids=b_seg_ids, attention_mask=b_attention_masks) 221 | 222 | logits = outputs[0] 223 | 224 | logits = logits.detach().cpu().numpy() 225 | label_ids = b_labels.to('cpu').numpy() 226 | 227 | tmp_eval_accuracy = flat_accuracy(logits, label_ids) 228 | eval_accuracy += tmp_eval_accuracy 229 | nb_eval_steps += 1 230 | 231 | print(" Accuracy: {0:.2f}".format(eval_accuracy/nb_eval_steps)) 232 | print(" Validation took: {:}".format(format_time(time.time() - t0))) 233 | 234 | output_dir = output_loc + 'model_' + str(epoch_i) 235 | if not os.path.exists(output_dir): 236 | os.makedirs(output_dir) 237 | 238 | model_to_save = model.module if hasattr(model, 'module') else model 239 | model_to_save.save_pretrained(output_dir) 240 | tokenizer.save_pretrained(output_dir) 241 | 242 | print('Saved model to ' + output_dir) 243 | -------------------------------------------------------------------------------- /segmentation/bert_single_para/04_generate_BERT_probabilities.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import pandas as pd 4 | 5 | import torch 6 | from transformers import BertTokenizer, BertForNextSentencePrediction 7 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler 8 | from keras.preprocessing.sequence import pad_sequences 9 | 10 | from transformers import BertTokenizer 11 | 12 | import sys 13 | 14 | def process_book(bert_tok_dir, pred_scores_dir, BertNSP, device, cls, sep, book_id): 15 | with open(os.path.join(bert_tok_dir, book_id + '.pkl'), 'rb') as f: 16 | d = pickle.load(f) 17 | 18 | m = max(d.keys()) 19 | 20 | scores = dict() 21 | for idx in range(0, m - 1): 22 | toks1 = d[idx] 23 | toks2 = d[idx + 1] 24 | 25 | l1 = len(toks1) 26 | l2 = len(toks2) 27 | if l1 + l2 >= 297: 28 | if l1 > 148 and l2 > 148: 29 | toks1 = toks1[-148:] 30 | toks2 = toks2[:148] 31 | elif l1 > 148: 32 | rem_len = 297 - l2 33 | toks1 = toks1[-rem_len:] 34 | elif l2 > 148: 35 | rem_len = 297 - l1 36 | toks2 = toks2[:rem_len] 37 | 38 | ids1 = [cls] + toks1 + [sep] 39 | ids2 = toks2 + [sep] 40 | 41 | indexed_tokens = ids1 + ids2 42 | segments_ids = [0] * len(ids1) + [1] * len(ids2) 43 | 44 | indexed_tokens = pad_sequences([indexed_tokens], maxlen=300, dtype='long', value=0, truncating="pre", padding="post") 45 | segments_ids = pad_sequences([segments_ids], 
maxlen=300, dtype="long", value=1, truncating="pre", padding="post") 46 | attention_masks = [[int(token_id > 0) for token_id in sent] for sent in indexed_tokens] 47 | 48 | tokens_tensor = torch.tensor(indexed_tokens) 49 | segments_tensors = torch.tensor(segments_ids) 50 | attention_tensor = torch.tensor(attention_masks) 51 | 52 | tokens_tensor = tokens_tensor.to(device) 53 | segments_tensors = segments_tensors.to(device) 54 | attention_tensor = attention_tensor.to(device) 55 | 56 | BertNSP.eval() 57 | prediction = BertNSP(tokens_tensor, token_type_ids=segments_tensors, attention_mask=attention_tensor) 58 | prediction = prediction[0] # tuple to tensor 59 | softmax = torch.nn.Softmax(dim=1) 60 | prediction_sm = softmax(prediction) 61 | 62 | scores[idx] = prediction_sm[0][1].item() 63 | 64 | with open(os.path.join(pred_scores_dir, book_id + '.pkl'), 'wb') as f: 65 | pickle.dump(scores, f) 66 | 67 | return 68 | 69 | 70 | if __name__ == '__main__': 71 | # Use appropriate locations 72 | test_books_list_file = 'test_books.txt' 73 | 74 | bert_tok_dir = 'test_books_bert_tok/' 75 | pred_scores_dir = 'test_preds/' 76 | 77 | model_dir = 'model_3/' 78 | 79 | with open(test_books_list_file, 'r') as f: 80 | test_book_ids = [x.strip() for x in f.readlines()] 81 | 82 | partition = int(sys.argv[1]) 83 | from_idx = partition * 1000 84 | to_idx = (partition + 1) * 1000 85 | 86 | test_book_ids = test_book_ids[from_idx:to_idx] 87 | print(len(test_book_ids), 'books') 88 | 89 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 90 | 91 | cls, sep = tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]"]) 92 | 93 | 94 | device = torch.device("cuda:" + str(partition + 1) if torch.cuda.is_available() else "cpu") 95 | print(device) 96 | 97 | print(torch.cuda.device_count(), "GPUs") 98 | 99 | model = BertForNextSentencePrediction.from_pretrained(model_dir) 100 | model = model.to(device) 101 | 102 | 103 | for book_id in test_book_ids: 104 | print(book_id) 105 | try: 106 | process_book(bert_tok_dir, pred_scores_dir, model, device, cls, sep, book_id) 107 | except Exception as e: 108 | print(book_id, e) 109 | 110 | print('Done!') 111 | -------------------------------------------------------------------------------- /segmentation/bert_single_para/05_generate_predictions_dp.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import pickle 4 | import pandas as pd 5 | import numpy as np 6 | 7 | from functools import partial 8 | import multiprocessing 9 | 10 | 11 | class DPSolver(): 12 | def __init__(self, peaks, prominences, num_sentences, breaks_to_insert, alpha): 13 | self.peaks = list(peaks) 14 | self.peaks.append(num_sentences) 15 | self.prominences = prominences 16 | self.prominences.append(0) 17 | self.num_sentences = num_sentences 18 | self.k = breaks_to_insert 19 | self.dp_dict = dict() 20 | self.N = len(self.peaks) - 1 21 | self.ideal_length = num_sentences / (self.k + 1) 22 | self.prev_dict = dict() 23 | self.alpha = alpha 24 | 25 | def get_distance(self, idx1, idx2): 26 | if idx1 == -1: 27 | return self.peaks[idx2] / self.ideal_length 28 | sent_1 = self.peaks[idx1] 29 | sent_2 = self.peaks[idx2] 30 | return (abs((sent_2 - sent_1) - self.ideal_length)) / float(self.ideal_length) 31 | 32 | def get_mins(self, costs): 33 | keys = list(costs.keys()) 34 | min_key = keys[0] 35 | min_val = costs[keys[0]] 36 | for k in keys[1:]: 37 | if costs[k] < min_val: 38 | min_val = costs[k] 39 | min_key = k 40 | return min_key, min_val 41 | 42 | def 
dp_func(self, N, k): 43 | if k > N: 44 | return None 45 | # Memoized 46 | if (N, k) in self.dp_dict: 47 | return self.dp_dict[(N, k)] 48 | # Base case 49 | if k == 0: 50 | self.dp_dict[(N, k)] = -(self.prominences[N] * self.alpha) + (self.get_distance(-1, N) * (1 - self.alpha)) 51 | return self.dp_dict[(N, k)] 52 | 53 | # Recursive call 54 | costs = dict() 55 | for i in range(0, N): 56 | c = self.dp_func(i, k - 1) 57 | if c: 58 | costs[i] = c + (self.get_distance(i, N) * (1 - self.alpha)) 59 | if len(costs) == 0: 60 | self.dp_dict[(N, k)] = None 61 | return None 62 | 63 | min_N, min_cost = self.get_mins(costs) 64 | 65 | ans = min_cost - (self.prominences[N] * self.alpha) 66 | self.dp_dict[(N, k)] = ans 67 | self.prev_dict[(N, k)] = min_N 68 | return ans 69 | 70 | def solve(self): 71 | x = self.dp_func(self.N, self.k) 72 | return x 73 | 74 | def get_best_sequence(self): 75 | x = self.solve() 76 | ans_seq = list() 77 | N = self.N 78 | k = self.k 79 | while True: 80 | if (N, k) not in self.prev_dict: 81 | break 82 | previous = self.prev_dict[(N, k)] 83 | ans_seq.append(previous) 84 | N = previous 85 | k -= 1 86 | return ans_seq[::-1] 87 | 88 | 89 | 90 | def get_predictions(peaks, prominences, num_preds, max_sent_num, alpha): 91 | 92 | 93 | dps = DPSolver(peaks, prominences, max_sent_num + 1, num_preds, alpha) 94 | preds = dps.get_best_sequence() 95 | dp_predictions = [peaks[x] for x in preds] 96 | return dp_predictions 97 | 98 | 99 | def process_book(break_probs_dir, para_to_sent_dir, gt_dir, output_dir, book_id): 100 | with open(os.path.join(break_probs_dir, book_id + '.pkl'), 'rb') as f: 101 | break_probs = pickle.load(f) 102 | with open(os.path.join(para_to_sent_dir, book_id + '.pkl'), 'rb') as f: 103 | para_to_sent = pickle.load(f) 104 | 105 | peaks = list() 106 | prominences = list() 107 | for n, prob in break_probs.items(): 108 | if prob > 0.9: 109 | peaks.append(para_to_sent[n]) 110 | prominences.append(np.log(prob)) 111 | 112 | with open(os.path.join(gt_dir, book_id + '_gt_sents.pkl'), 'rb') as f: 113 | gt = pickle.load(f) 114 | with open(os.path.join(gt_dir, book_id + '_max_sent_num.pkl'), 'rb') as f: 115 | max_sent_num = int(pickle.load(f)) 116 | 117 | num_preds = len(gt) 118 | 119 | for alpha in [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]: 120 | preds = get_predictions(peaks, prominences, num_preds, max_sent_num, alpha) 121 | with open(os.path.join(output_dir, book_id + '_alpha_' + str(int(alpha * 100)) + '.pkl'), 'wb') as f: 122 | pickle.dump(preds, f) 123 | 124 | print(book_id, 'success') 125 | return book_id, 'Success' 126 | 127 | 128 | 129 | 130 | if __name__ == "__main__": 131 | 132 | # Use appropriate locations 133 | test_books_list_file = 'test_books.txt' 134 | with open(test_books_list_file, 'r') as f: 135 | test_book_ids = [x.strip() for x in f.readlines()] 136 | 137 | gt_dir = 'test_gt_sentences/' 138 | 139 | break_probs_dir = 'test_preds/' 140 | 141 | para_to_sent_dir = 'test_books_para_to_sent/' 142 | 143 | output_dir = 'thresh_0_9/' 144 | 145 | if not os.path.exists(break_probs_dir): 146 | print('Invalid break probs dir') 147 | exit() 148 | 149 | if not os.path.exists(gt_dir): 150 | print('Invalid ground truth dir') 151 | exit() 152 | 153 | if not os.path.exists(output_dir): 154 | os.makedirs(output_dir) 155 | 156 | func = partial(process_book, break_probs_dir, para_to_sent_dir, gt_dir, output_dir) 157 | 158 | pool = multiprocessing.Pool(processes=32) 159 | data = pool.map(func, test_book_ids) 160 | pool.close() 161 | pool.join() 162 | 163 | df = 
pd.DataFrame(data) 164 | df.rename(columns={0:'book_id', 1:'status'}, inplace=True) 165 | df.to_csv('./log_file_dp_single_para.csv', index=False) 166 | print('Done!') 167 | 168 | -------------------------------------------------------------------------------- /segmentation/bert_tokenize_test_books.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import lxml.etree as ET 3 | import random 4 | import os 5 | 6 | import pickle 7 | import multiprocessing 8 | from functools import partial 9 | 10 | from transformers import BertTokenizer 11 | 12 | def process_book(sent_dir, bert_tok_dir, tokenizer, book_id): 13 | try: 14 | filename = os.path.join(sent_dir, book_id + '.xml') 15 | parser = ET.XMLParser(huge_tree=True) 16 | tree = ET.parse(filename, parser=parser) 17 | book = tree.getroot() 18 | b = book.find('.//body') 19 | 20 | paragraphs = b.findall('.//p') 21 | 22 | d = dict() 23 | for p in paragraphs: 24 | n = int(p.attrib['num']) 25 | text = ' '.join([x.text for x in p.findall('.//s')]) 26 | 27 | tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(str(text))) 28 | 29 | d[n] = tokens 30 | 31 | with open(os.path.join(bert_tok_dir, book_id + '.pkl'), 'wb') as f: 32 | pickle.dump(d, f) 33 | 34 | return book_id, 0 35 | except: 36 | return book_id, -1 37 | 38 | 39 | if __name__ == '__main__': 40 | # Use appropriate locations 41 | test_books_list_file = 'test_books.txt' 42 | 43 | sent_dir = 'use_books_sentencized/' 44 | bert_tok_dir = 'test_books_bert_tok/' 45 | 46 | with open(test_books_list_file, 'r') as f: 47 | test_book_ids = [x.strip() for x in f.readlines()] 48 | 49 | print(len(test_book_ids), 'books') 50 | 51 | tokenizer=BertTokenizer.from_pretrained('bert-base-uncased') 52 | 53 | 54 | 55 | func = partial(process_book, sent_dir, bert_tok_dir, tokenizer) 56 | 57 | pool = multiprocessing.Pool(processes=32) 58 | data = pool.map(func, test_book_ids) 59 | pool.close() 60 | pool.join() 61 | 62 | df = pd.DataFrame(data) 63 | df.rename(columns={0:'book_id', 1:'status'}, inplace=True) 64 | df.to_csv('log_test_tok_bert.csv', index=False) 65 | print('Done!') 66 | 67 | -------------------------------------------------------------------------------- /segmentation/generate_ground_truth.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import lxml.etree as ET 3 | import random 4 | import os 5 | 6 | import pickle 7 | import multiprocessing 8 | from functools import partial 9 | 10 | def get_examples_book(input_dir, output_dir, book_id): 11 | filename = input_dir + str(book_id) + '.xml' 12 | parser = ET.XMLParser(huge_tree=True) 13 | tree = ET.parse(filename, parser=parser) 14 | book = tree.getroot() 15 | b = book.find('.//body') 16 | 17 | headers = b.findall('.//header') 18 | 19 | start_para_nums = list() 20 | for h in headers: 21 | t = h.getnext() 22 | if t.tag == 'p': 23 | start_para_nums.append(int(t.attrib['num'])) 24 | 25 | gt = list() 26 | for para_num in start_para_nums: 27 | last_p = b.find('.//p[@num=\'' + str(para_num - 1) + '\']') 28 | if last_p is None: 29 | continue 30 | sents = last_p.findall('.//s') 31 | num = int(sents[-1].attrib['num']) 32 | gt.append(num) 33 | 34 | max_sent_num = b.findall('.//s')[-1].attrib['num'] 35 | 36 | with open(os.path.join(output_dir, book_id + '_gt_sents.pkl'), 'wb') as f: 37 | pickle.dump(gt, f) 38 | with open(os.path.join(output_dir, book_id + '_max_sent_num.pkl'), 'wb') as f: 39 | pickle.dump(max_sent_num, f) 40 | 41 | return 
book_id, 0 42 | 43 | if __name__ == '__main__': 44 | # Use appropriate locations 45 | test_books_list_file = 'test_books.txt' 46 | 47 | sent_dir = 'use_books_sentencized/' 48 | output_dir = 'test_gt_sentences/' 49 | 50 | with open(test_books_list_file, 'r') as f: 51 | test_book_ids = [x.strip() for x in f.readlines()] 52 | 53 | 54 | print(len(test_book_ids), 'books') 55 | 56 | func = partial(get_examples_book, sent_dir, output_dir) 57 | 58 | pool = multiprocessing.Pool(processes=32) 59 | data = pool.map(func, test_book_ids) 60 | pool.close() 61 | pool.join() 62 | 63 | df = pd.DataFrame(data) 64 | df.rename(columns={0:'book_id', 1:'status'}, inplace=True) 65 | df.to_csv('log.csv', index=False) 66 | print('Done!') 67 | -------------------------------------------------------------------------------- /segmentation/metrics/generate_metrics.py: -------------------------------------------------------------------------------- 1 | from nltk.metrics import segmentation as seg 2 | import pickle 3 | import pandas as pd 4 | import os 5 | 6 | import multiprocessing 7 | from functools import partial 8 | 9 | def get_standard_metrics(gt, pred, msn): 10 | gt_segs = ''.join(['1' if i in gt else '0' for i in range(msn)]) 11 | pred_segs = ''.join(['1' if i in pred else '0' for i in range(msn)]) 12 | k_val = int(round(len(gt_segs) / (gt_segs.count('1') * 2.0))) 13 | k_val = k_val // 4 14 | return seg.pk(gt_segs, pred_segs, k=k_val), seg.windowdiff(gt_segs, pred_segs, k=k_val) 15 | 16 | def get_prec_rec_f1(gt, pred): 17 | tp = len([x for x in pred if x in gt]) 18 | fp = len([x for x in pred if x not in gt]) 19 | fn = len([x for x in gt if x not in pred]) 20 | 21 | precision, recall, f1 = None, None, None 22 | 23 | try: 24 | precision = tp / (tp + fp) 25 | except: 26 | pass 27 | try: 28 | recall = tp / (tp + fn) 29 | except: 30 | pass 31 | try: 32 | f1 = 2 * precision * recall / (precision + recall) 33 | except: 34 | pass 35 | 36 | return precision, recall, f1 37 | 38 | 39 | def get_gt_msn(gt_dir, book_id): 40 | with open(os.path.join(gt_dir, book_id + '_gt_sents.pkl'), 'rb') as f: 41 | gt = pickle.load(f) 42 | with open(os.path.join(gt_dir, book_id + '_max_sent_num.pkl'), 'rb') as f: 43 | max_sent_num = int(pickle.load(f)) 44 | return gt, max_sent_num 45 | 46 | def get_pred(pred_dir, book_id, alpha): 47 | with open(os.path.join(pred_dir, book_id + '_alpha_' + str(alpha) + '.pkl'), 'rb') as f: 48 | preds = pickle.load(f) 49 | return preds 50 | 51 | def get_metrics(gt_loc, pred_loc, book_id): 52 | gt, msn = get_gt_msn(gt_loc, book_id) 53 | 54 | ans = list() 55 | 56 | for alpha in [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]: 57 | pk, wd, p, r, f1 = None, None, None, None, None 58 | 59 | try: 60 | pred = get_pred(pred_loc, book_id, alpha) 61 | except: 62 | ans.append([book_id, pk, wd, p, r, f1]) 63 | continue 64 | 65 | try: 66 | pk, wd = get_standard_metrics(gt, pred, msn) 67 | except: 68 | pass 69 | try: 70 | p, r, f1 = get_prec_rec_f1(gt, pred) 71 | except: 72 | pass 73 | ans.append([book_id, pk, wd, p, r, f1]) 74 | return ans 75 | 76 | 77 | if __name__ == "__main__": 78 | 79 | # Use appropriate locations 80 | gt_loc = 'test_gt_sentences/' 81 | 82 | pred_loc = 'woc/window_size_200/dp/' 83 | output_dir = 'woc/window_size_200/results/' 84 | 85 | if not os.path.exists(output_dir): 86 | os.makedirs(output_dir) 87 | 88 | test_books_list_file = 'test_books.txt' 89 | with open(test_books_list_file, 'r') as f: 90 | test_book_ids = [x.strip() for x in f.readlines()] 91 | 92 | d = dict() 93 | for idx in [0, 10, 20, 30, 40, 
50, 60, 70, 80, 90, 100]: 94 | d[idx] = list() 95 | 96 | func = partial(get_metrics, gt_loc, pred_loc) 97 | 98 | pool = multiprocessing.Pool(processes=32) 99 | data = pool.map(func, test_book_ids) 100 | pool.close() 101 | pool.join() 102 | 103 | 104 | for metrics_b in data: 105 | for idx, elem in enumerate(metrics_b): 106 | d[idx * 10].append(elem) 107 | 108 | for alpha in d: 109 | print(alpha) 110 | df = pd.DataFrame(d[alpha], columns=['book_id', 'pk', 'wd', 'precision', 'recall', 'f1']) 111 | df.to_csv(os.path.join(output_dir, 'alpha_' + str(alpha) + '.csv')) 112 | 113 | print('Done!') 114 | 115 | -------------------------------------------------------------------------------- /segmentation/paragraph_to_sentence.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import lxml.etree as ET 3 | import random 4 | import os 5 | 6 | import pickle 7 | import multiprocessing 8 | from functools import partial 9 | 10 | def process_book(sent_dir, para_to_sent_dir, book_id): 11 | try: 12 | filename = os.path.join(sent_dir, book_id + '.xml') 13 | parser = ET.XMLParser(huge_tree=True) 14 | tree = ET.parse(filename, parser=parser) 15 | book = tree.getroot() 16 | b = book.find('.//body') 17 | 18 | paragraphs = b.findall('.//p') 19 | 20 | d = dict() 21 | for p in paragraphs: 22 | n = int(p.attrib['num']) 23 | 24 | sents = p.findall('.//s') 25 | num = int(sents[-1].attrib['num']) 26 | d[n] = num 27 | 28 | with open(os.path.join(para_to_sent_dir, book_id + '.pkl'), 'wb') as f: 29 | pickle.dump(d, f) 30 | 31 | return book_id, 0 32 | except: 33 | return book_id, -1 34 | 35 | 36 | if __name__ == '__main__': 37 | # Use appropriate locations 38 | test_books_list_file = 'test_books.txt' 39 | 40 | sent_dir = 'use_books_sentencized/' 41 | para_to_sent_dir = 'test_books_para_to_sent/' 42 | 43 | with open(test_books_list_file, 'r') as f: 44 | test_book_ids = [x.strip() for x in f.readlines()] 45 | 46 | 47 | #test_book_ids = test_book_ids[:10] 48 | print(len(test_book_ids), 'books') 49 | 50 | 51 | func = partial(process_book, sent_dir, para_to_sent_dir) 52 | 53 | pool = multiprocessing.Pool(processes=32) 54 | data = pool.map(func, test_book_ids) 55 | pool.close() 56 | pool.join() 57 | 58 | df = pd.DataFrame(data) 59 | df.rename(columns={0:'book_id', 1:'status'}, inplace=True) 60 | df.to_csv('log_test_tok_para_to_sent.csv', index=False) 61 | print('Done!') 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /segmentation/tokenize_books.py: -------------------------------------------------------------------------------- 1 | import lxml.etree as ET 2 | import os 3 | import glob 4 | 5 | from stanfordnlp.server import CoreNLPClient 6 | 7 | import pickle 8 | 9 | import pandas as pd 10 | 11 | from functools import partial 12 | import multiprocessing 13 | 14 | 15 | def sentencize(header_annot_dir, client, book_id): 16 | 17 | filename = os.path.join(header_annot_dir, str(book_id) + '.xml') 18 | 19 | parser = ET.XMLParser(huge_tree=True) 20 | tree = ET.parse(filename, parser=parser) 21 | book = tree.getroot() 22 | 23 | para_end_sentences = list() 24 | lemma_dict = dict() 25 | 26 | start_sentence_number = 0 27 | 28 | b = book.find('.//body') 29 | 30 | header_elems = [x for idx, x in enumerate(b)] 31 | 32 | for idx, element in enumerate(header_elems): 33 | content = element.tail 34 | element.tail = "" 35 | if content is None: 36 | continue 37 | 38 | ann = client.annotate(content) 39 | 40 | init_offset = 
content.index(ann.sentence[0].token[0].originalText) 41 | 42 | prev = element 43 | 44 | for idx_2, sent in enumerate(ann.sentence): 45 | 46 | sentence_tag = ET.Element("s") 47 | sentence_tag.text = content[init_offset + sent.characterOffsetBegin:init_offset + sent.characterOffsetEnd] 48 | num = start_sentence_number + idx_2 49 | sentence_tag.set('num', str(num)) 50 | 51 | if sent.token[-1].after.startswith('\n\n') or idx_2 == len(ann.sentence) - 1: 52 | para_end_sentences.append(num) 53 | 54 | lemma_dict[num] = [tok.lemma for tok in sent.token] 55 | 56 | prev.addnext(sentence_tag) 57 | prev = sentence_tag 58 | 59 | start_sentence_number += len(ann.sentence) 60 | 61 | if len(para_end_sentences) == 0: 62 | para_end_sentences = [num] 63 | if para_end_sentences[-1] != num: 64 | para_end_sentences.append(num) 65 | 66 | tree = ET.ElementTree(book) 67 | 68 | return tree, para_end_sentences, lemma_dict 69 | 70 | 71 | def paragraphize(tree, para_end_sentences): 72 | book = tree.getroot() 73 | 74 | body = book.find('.//body') 75 | 76 | elems = [x for x in body] 77 | 78 | new_body = ET.Element('body') 79 | 80 | para_num = 0 81 | start = 0 82 | end = para_end_sentences[0] 83 | 84 | for elem in elems: 85 | if elem.tag == 'header': 86 | new_body.append(elem) 87 | continue 88 | num = int(elem.get('num')) 89 | if num == start: 90 | current_para = ET.Element('p') 91 | current_para.set('num', str(para_num)) 92 | if num >= start and num <= end: 93 | current_para.append(elem) 94 | if num == end: 95 | new_body.append(current_para) 96 | para_num += 1 97 | start = end + 1 98 | if para_num < len(para_end_sentences): 99 | end = para_end_sentences[para_num] 100 | else: 101 | end = None 102 | 103 | idx = [idx for idx, elem in enumerate(book) if elem.tag == 'body'][0] 104 | book[idx] = new_body 105 | 106 | tree = ET.ElementTree(book) 107 | 108 | return tree 109 | 110 | 111 | def process_book(header_annot_dir, lemma_dir, tree_dir, book_id): 112 | 113 | if os.path.exists(os.path.join(tree_dir, book_id + '.xml')) and os.path.exists(os.path.join(lemma_dir, book_id + '.pkl')): 114 | return book_id, 'Exists' 115 | 116 | os.environ["CORENLP_HOME"] = "~/stanford_corenlp/stanford-corenlp-full-2018-10-05" 117 | 118 | try: 119 | with CoreNLPClient(annotators=['tokenize','lemma'], timeout=30000, max_char_length=100000000, be_quiet=True, start_server=False) as client: 120 | tree, para_end_sentences, lemma_dict = sentencize(header_annot_dir, client, book_id) 121 | 122 | tree2 = paragraphize(tree, para_end_sentences) 123 | 124 | filename = os.path.join(tree_dir, book_id + '.xml') 125 | tree2.write(filename, pretty_print=True) 126 | 127 | with open(os.path.join(lemma_dir, book_id + '.pkl'), 'wb') as f: 128 | pickle.dump(lemma_dict, f) 129 | except Exception as e: 130 | print(book_id, e) 131 | return book_id, e 132 | 133 | print(book_id, 'Success!') 134 | return book_id, 'Success' 135 | 136 | 137 | if __name__ == "__main__": 138 | 139 | # Use appropriate locations 140 | header_annot_dir = 'annot_header_dir/' 141 | 142 | lemma_dir = 'lemmas/' 143 | tree_dir = 'sentencized/' 144 | 145 | if not os.path.exists(lemma_dir): 146 | os.makedirs(lemma_dir) 147 | if not os.path.exists(tree_dir): 148 | os.makedirs(tree_dir) 149 | 150 | with open('train_book_ids.txt', 'r') as f: 151 | books = f.read().splitlines() 152 | 153 | with open('test_book_ids_seg.txt', 'r') as f: 154 | books += f.read().splitlines() 155 | 156 | func = partial(process_book, header_annot_dir, lemma_dir, tree_dir) 157 | 158 | pool = multiprocessing.Pool(processes=50) 159 | 
data = pool.map(func, books) 160 | pool.close() 161 | pool.join() 162 | 163 | df = pd.DataFrame(data) 164 | df.rename(columns={0:'book_id', 1:'status'}, inplace=True) 165 | df.to_csv('./log_file.csv', index=False) 166 | 167 | print('Done!') 168 | 169 | 170 | 171 | -------------------------------------------------------------------------------- /segmentation/weighted_overlap/compute_densities.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import pickle 4 | from functools import partial 5 | import multiprocessing 6 | import pandas as pd 7 | from nltk.corpus import stopwords 8 | 9 | def build_graph(lemma_dict, N): 10 | edges = list() 11 | for idx in range(max(lemma_dict.keys()) + 1): 12 | for n in range(1, 1 + N): 13 | if idx + n not in lemma_dict: 14 | continue 15 | common_lemmas = lemma_dict[idx].intersection(lemma_dict[idx + n]) 16 | new_common_lemmas = set() 17 | for x in common_lemmas: 18 | if x not in "!\"#$%&'()*+, -./:;<=>?@[\]^_`{|}~": 19 | new_common_lemmas.add(x) 20 | common_lemmas = new_common_lemmas 21 | for i in range(len(common_lemmas)): 22 | edges.append([idx, idx + n]) 23 | return edges 24 | 25 | def compute_density(edges, max_sent_num): 26 | density = {x: 0 for x in range(0, max_sent_num)} 27 | for x, y in edges: 28 | for i in range(x, y): 29 | left_dist = i - x + 1 30 | right_dist = y - i 31 | density[i] += 1 / (left_dist + right_dist) 32 | return dict(density) 33 | 34 | def process_book(lemma_dir, output_dir, N, book_id): 35 | with open(os.path.join(lemma_dir, book_id + '.pkl'), 'rb') as f: 36 | lemmas = pickle.load(f) 37 | 38 | for k in lemmas: 39 | lemmas[k] = set(lemmas[k]) 40 | lemmas[k] = lemmas[k].difference(stop_words) 41 | 42 | edges = build_graph(lemmas, N) 43 | 44 | density = compute_density(edges, max(lemmas.keys())) 45 | # Save density to pkl file 46 | with open(os.path.join(output_dir, book_id + '.pkl'), 'wb') as f: 47 | pickle.dump(density, f) 48 | print(book_id, 'success') 49 | return book_id, 'Success' 50 | 51 | 52 | if __name__ == "__main__": 53 | lemma_dir = 'use_books_lemmas/' 54 | output_dir = 'window_size_200/test_books_densities/' 55 | if not os.path.exists(lemma_dir): 56 | print('Invalid lemma dir') 57 | exit() 58 | 59 | if not os.path.exists(output_dir): 60 | os.makedirs(output_dir) 61 | 62 | test_books_list_file = 'test_books.txt' 63 | with open(test_books_list_file, 'r') as f: 64 | test_book_ids = [x.strip() for x in f.readlines()] 65 | 66 | print(len(test_book_ids), 'books') 67 | 68 | stop_words = set(stopwords.words('english')) 69 | 70 | func = partial(process_book, lemma_dir, output_dir, 200) 71 | 72 | pool = multiprocessing.Pool(processes=48) 73 | data = pool.map(func, test_book_ids) 74 | pool.close() 75 | pool.join() 76 | 77 | df = pd.DataFrame(data) 78 | df.rename(columns={0:'book_id', 1:'status'}, inplace=True) 79 | df.to_csv('./log_file.csv', index=False) 80 | print('Done!') 81 | -------------------------------------------------------------------------------- /segmentation/weighted_overlap/compute_peaks_prominences.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import pickle 4 | import lxml.etree as ET 5 | import pandas as pd 6 | from scipy import signal 7 | 8 | from functools import partial 9 | import multiprocessing 10 | 11 | 12 | def process_book(para_to_sent_dir, density_dir, output_dir, book_id): 13 | 14 | 15 | # Read densities file 16 | with open(os.path.join(density_dir, book_id + '.pkl'), 
'rb') as f: 17 | densities = pickle.load(f) 18 | 19 | # Read para_to_sent file 20 | with open(os.path.join(para_to_sent_dir, book_id + '.pkl'), 'rb') as f: 21 | para_to_sent = pickle.load(f) 22 | 23 | # Get valid sentence numbers (that come at ends of paragraphs) 24 | valid_sent_nums = list(para_to_sent.values()) 25 | 26 | # Corresponding densities 27 | valid_densities = [densities[x] for x in sorted(valid_sent_nums[:-1])] 28 | 29 | # Get peak indices and prominences 30 | peaks, _ = signal.find_peaks([-x for x in valid_densities]) 31 | prominences = signal.peak_prominences([-x for x in valid_densities], peaks)[0] 32 | 33 | # Get sentence numbers corresponding to peak indices 34 | peak_sent_nums = [valid_sent_nums[idx] for idx in peaks] 35 | peak_sent_proms = prominences 36 | 37 | with open(os.path.join(output_dir, book_id + '.pkl'), 'wb') as f: 38 | pickle.dump([list(peak_sent_nums), list(peak_sent_proms)], f) 39 | 40 | print(book_id, ' success!') 41 | return book_id, 'Success' 42 | 43 | 44 | 45 | if __name__ == "__main__": 46 | para_to_sent_dir = 'test_books_para_to_sent/' 47 | density_dir = 'window_size_200/test_books_densities/' 48 | 49 | output_dir = 'window_size_200/test_books_peaks_proms/' 50 | 51 | if not os.path.exists(density_dir): 52 | print('Invalid lemma dir') 53 | exit() 54 | 55 | if not os.path.exists(output_dir): 56 | os.makedirs(output_dir) 57 | 58 | test_books_list_file = 'test_books.txt' 59 | with open(test_books_list_file, 'r') as f: 60 | test_book_ids = [x.strip() for x in f.readlines()] 61 | 62 | print(len(test_book_ids), 'books') 63 | 64 | func = partial(process_book, para_to_sent_dir, density_dir, output_dir) 65 | 66 | pool = multiprocessing.Pool(processes=32) 67 | data = pool.map(func, test_book_ids) 68 | pool.close() 69 | pool.join() 70 | 71 | df = pd.DataFrame(data) 72 | df.rename(columns={0:'book_id', 1:'status'}, inplace=True) 73 | df.to_csv('./log_file_peaks.csv', index=False) 74 | print('Done!') 75 | -------------------------------------------------------------------------------- /segmentation/weighted_overlap/get_predictions_dp.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import pickle 4 | import pandas as pd 5 | 6 | from functools import partial 7 | import multiprocessing 8 | 9 | 10 | class DPSolver(): 11 | def __init__(self, peaks, prominences, num_sentences, breaks_to_insert, alpha): 12 | self.peaks = list(peaks) 13 | self.peaks.append(num_sentences) 14 | self.prominences = prominences 15 | self.prominences.append(0) 16 | self.num_sentences = num_sentences 17 | self.k = breaks_to_insert 18 | self.dp_dict = dict() 19 | self.N = len(self.peaks) - 1 20 | self.ideal_length = num_sentences / (self.k + 1) 21 | self.prev_dict = dict() 22 | self.alpha = alpha 23 | 24 | def get_distance(self, idx1, idx2): 25 | if idx1 == -1: 26 | return self.peaks[idx2] / self.ideal_length 27 | sent_1 = self.peaks[idx1] 28 | sent_2 = self.peaks[idx2] 29 | return (abs((sent_2 - sent_1) - self.ideal_length)) / self.ideal_length 30 | 31 | def get_mins(self, costs): 32 | keys = list(costs.keys()) 33 | min_key = keys[0] 34 | min_val = costs[keys[0]] 35 | for k in keys[1:]: 36 | if costs[k] < min_val: 37 | min_val = costs[k] 38 | min_key = k 39 | return min_key, min_val 40 | 41 | def dp_func(self, N, k): 42 | if k > N: 43 | return None 44 | # Memoized 45 | if (N, k) in self.dp_dict: 46 | return self.dp_dict[(N, k)] 47 | # Base case 48 | if k == 0: 49 | self.dp_dict[(N, k)] = -(self.prominences[N] * self.alpha) + 
(self.get_distance(-1, N) * (1 - self.alpha)) 50 | return self.dp_dict[(N, k)] 51 | 52 | # Recursive call 53 | costs = dict() 54 | for i in range(0, N): 55 | c = self.dp_func(i, k - 1) 56 | if c: 57 | costs[i] = c + (self.get_distance(i, N) * (1 - self.alpha)) 58 | if len(costs) == 0: 59 | self.dp_dict[(N, k)] = None 60 | return None 61 | 62 | min_N, min_cost = self.get_mins(costs) 63 | 64 | ans = min_cost - (self.prominences[N] * self.alpha) 65 | self.dp_dict[(N, k)] = ans 66 | self.prev_dict[(N, k)] = min_N 67 | return ans 68 | 69 | def solve(self): 70 | x = self.dp_func(self.N, self.k) 71 | return x 72 | 73 | def get_best_sequence(self): 74 | x = self.solve() 75 | ans_seq = list() 76 | N = self.N 77 | k = self.k 78 | while True: 79 | if (N, k) not in self.prev_dict: 80 | break 81 | previous = self.prev_dict[(N, k)] 82 | ans_seq.append(previous) 83 | N = previous 84 | k -= 1 85 | return ans_seq[::-1] 86 | 87 | 88 | 89 | def get_predictions(peaks, prominences, num_preds, max_sent_num, alpha): 90 | 91 | dps = DPSolver(peaks, prominences, max_sent_num + 1, num_preds, alpha) 92 | preds = dps.get_best_sequence() 93 | dp_predictions = [peaks[x] for x in preds] 94 | return dp_predictions 95 | 96 | def process_book(peaks_dir, density_dir, gt_dir, output_dir, book_id): 97 | if os.path.exists(os.path.join(output_dir, book_id + '_alpha_100.pkl')): 98 | return book_id, 'exists' 99 | try: 100 | with open(os.path.join(peaks_dir, book_id + '.pkl'), 'rb') as f: 101 | peaks, prominences = pickle.load(f) 102 | 103 | with open(os.path.join(density_dir, book_id + '.pkl'), 'rb') as f: 104 | density = pickle.load(f) 105 | 106 | if len(density) < 2: 107 | print(book_id, 'empty') 108 | return book_id, 'empty' 109 | 110 | maximum, minimum = max(prominences), min(prominences) 111 | denom = maximum - minimum 112 | prominences = [(x - minimum) / denom for x in prominences] 113 | 114 | with open(os.path.join(gt_dir, book_id + '_gt_sents.pkl'), 'rb') as f: 115 | gt = pickle.load(f) 116 | with open(os.path.join(gt_dir, book_id + '_max_sent_num.pkl'), 'rb') as f: 117 | max_sent_num = int(pickle.load(f)) 118 | 119 | 120 | num_preds = len(gt) 121 | 122 | for alpha in [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]: 123 | preds = get_predictions(peaks, prominences, num_preds, max_sent_num, alpha) 124 | with open(os.path.join(output_dir, book_id + '_alpha_' + str(int(alpha * 100)) + '.pkl'), 'wb') as f: 125 | pickle.dump(preds, f) 126 | 127 | print(book_id, 'success') 128 | return book_id, 'Success' 129 | except Exception as e: 130 | print(book_id, 'failed') 131 | return book_id, e 132 | 133 | if __name__ == "__main__": 134 | 135 | test_books_list_file = 'test_books.txt' 136 | with open(test_books_list_file, 'r') as f: 137 | test_book_ids = [x.strip() for x in f.readlines()] 138 | 139 | 140 | window_size = 150 141 | 142 | # Use appropriate locations 143 | peaks_dir = 'window_size_' + str(window_size) + '/test_books_peaks_proms/' 144 | density_dir = 'window_size_' + str(window_size) + '/test_books_densities/' 145 | gt_dir = 'test_gt_sentences/' 146 | output_dir = 'window_size_' + str(window_size) + '/dp/' 147 | 148 | if not os.path.exists(peaks_dir): 149 | print('Invalid peaks dir') 150 | exit() 151 | 152 | if not os.path.exists(gt_dir): 153 | print('Invalid ground truth dir') 154 | exit() 155 | 156 | if not os.path.exists(output_dir): 157 | os.makedirs(output_dir) 158 | 159 | func = partial(process_book, peaks_dir, density_dir, gt_dir, output_dir) 160 | 161 | pool = multiprocessing.Pool(processes=48) 162 | data = 
pool.map(func, test_book_ids) 163 | pool.close() 164 | pool.join() 165 | 166 | print('Done!') 167 | --------------------------------------------------------------------------------
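Usage note (not part of the repository): a minimal sketch of how the DPSolver dynamic program defined in segmentation/weighted_overlap/get_predictions_dp.py (and duplicated in segmentation/bert_single_para/05_generate_predictions_dp.py) places a fixed number of chapter breaks. It assumes the class can be imported from the script's directory, and the peaks, prominences, and counts below are made up purely for illustration.

from get_predictions_dp import DPSolver   # assumes the script's directory is on sys.path

# Hypothetical candidate break points: sentence numbers at paragraph ends,
# with prominences already normalized to [0, 1] as in process_book().
peaks = [40, 95, 160, 210, 290, 355]
prominences = [0.2, 0.9, 0.4, 0.8, 0.1, 0.7]

num_sentences = 400      # max_sent_num + 1 for this (imaginary) book
breaks_to_insert = 3     # number of breaks to place (len(gt) in the scripts)
alpha = 0.5              # 1.0 = prominence only, 0.0 = even spacing only

# DPSolver appends a sentinel to the prominence list it receives, so pass a copy.
solver = DPSolver(peaks, list(prominences), num_sentences, breaks_to_insert, alpha)
chosen = solver.get_best_sequence()            # indices into `peaks`
predicted_breaks = [peaks[i] for i in chosen]
print(predicted_breaks)                        # three sentence numbers chosen as breaks

The alpha parameter is the trade-off both prediction scripts sweep from 0.0 to 1.0: each selected peak contributes its (negated) prominence weighted by alpha, plus the deviation of the resulting segment length from num_sentences / (breaks_to_insert + 1) weighted by (1 - alpha), so alpha near 1 trusts the peak scores and alpha near 0 enforces evenly sized chapters.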