├── README.md ├── conll2003.py ├── conll2003dep.py ├── conll2003dep ├── parse.sh └── task.pbtxt ├── dependency_utils.py ├── evaluate.py ├── load_conll_2012 ├── coreference_reading.py ├── head_finder.py ├── load_conll.py ├── parse_errors.py ├── pstree.py └── treebanks.py ├── ontochinese.py ├── ontochinese ├── ne.txt └── pos.txt ├── ontonotes.py ├── ontonotes ├── ne.txt └── pos.txt └── rnn.py /README.md: -------------------------------------------------------------------------------- 1 | # TF_RNN 2 | This repository contains a special Bidirectional Recursive Neural Network implemented with Tensorflow described in [1](#leveraging-linguistic-structures-for-named-entity-recognition-with-bidirectional-recursive-neural-networks). 3 | ```python 4 | rnn.py # containing the RNN model class 5 | evaluate.py # the training and testing script 6 | ontonotes.py # utilities to extract the OntoNotes 5.0 dataset 7 | ``` 8 | 9 | ## How to set up the OntoNotes 5.0 dataset 10 | ### 1. Get data 11 | 12 | Download OntoNotes 5.0 from CoNLL-2012 website. 13 | 14 | Download SENNA from Collobert's website. 15 | 16 | Set their custom paths in ontonotes.py 17 | ```python 18 | data_path_prefix = "/home/danniel/Desktop/CONLL2012-intern/conll-2012/v4/data" 19 | test_auto_data_path_prefix = "/home/danniel/Downloads/wu_conll_test/v9/data" 20 | senna_path = "/home/danniel/Downloads/senna/hash" 21 | ``` 22 | 23 | ### 2. Get the load data helpers 24 | 25 | The "load_conll_2012/" directory contains libraries to read the CoNLL-2012 format of OntoNotes. They are provided by Jheng-Long Wu (jlwu@iis.sinica.edu.tw) and Canasai (https://github.com/canasai/mps). 26 | 27 | Set the custom path to import them in ontonotes.py 28 | ```python 29 | sys.path.append("/home/danniel/Desktop/CONLL2012-intern") 30 | from load_conll import load_data 31 | from pstree import PSTree 32 | ``` 33 | 34 | ### 3. Get pre-trained GloVe embeddings 35 | 36 | Download them from the GloVe website. 37 | 38 | Set the custom path in ontonotes.py 39 | ```python 40 | glove_file = "/home/danniel/Downloads/glove.840B.300d.txt" 41 | ``` 42 | 43 | ### 4. Extract the alphabet, vocabulary, and embeddings 44 | 45 | Modify and run ontonotes.py 46 | ```python 47 | if __name__ == "__main__": 48 | extract_vocabulary_and_alphabet() 49 | extract_glove_embeddings() 50 | # read_dataset() 51 | exit() 52 | ``` 53 | 54 | ## How to train and test 55 | ### 1. Train a model on OntoNotes 5.0. 56 | 57 | ``` 58 | python evaluate.py 2> tmp.txt 59 | ``` 60 | This generates model files tmp.model.* 61 | 62 | ### 2. Test the model on the test split of OntoNotes 5.0 63 | 64 | ``` 65 | python evaluate.py -m evaluate -s test 2> tmp.txt 66 | ``` 67 | 68 | ### 3. Options 69 | 70 | To see all options, run 71 | ``` 72 | python evaluate.py -h 73 | ``` 74 | 75 | ## References 76 | The high-level description of the project and the evaluation results can be found in [1](#leveraging-linguistic-structures-for-named-entity-recognition-with-bidirectional-recursive-neural-networks). 
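The scores reported there are exact-match, span-level precision/recall/F1, as computed by `evaluate_prediction()` in `evaluate.py`. For reference, a minimal sketch of that metric, assuming hypothetical `gold` and `pred` dicts mapping `(start, end)` spans to entity labels for a single sentence (the per-sentence format of `ner_list` in this repo); `evaluate.py` accumulates the same counts over a whole split:

```python
def span_f1(gold, pred):
    # gold / pred: {(start, end): label} for one sentence
    tp = float(sum(1 for span, ne in pred.items() if gold.get(span) == ne))
    precision = tp / len(pred) if pred else 1.0
    recall = tp / len(gold) if gold else 1.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision * 100, recall * 100, f1 * 100  # percentages, as in evaluate.py
```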
77 | 78 | [1] PH Li, RP Dong, YS Wang, JC Chou, WY Ma, [*Leveraging Linguistic Structures for Named Entity Recognition with Bidirectional Recursive Neural Networks*](https://www.aclweb.org/anthology/D17-1282/) 79 | 80 | ``` 81 | @InProceedings{li-EtAl:2017:EMNLP20177, 82 | author = {Li, Peng-Hsuan and Dong, Ruo-Ping and Wang, Yu-Siang and Chou, Ju-Chieh and Ma, Wei-Yun}, 83 | title = {Leveraging Linguistic Structures for Named Entity Recognition with Bidirectional Recursive Neural Networks}, 84 | booktitle = {Proceedings of the 2017 Conference on Empirical Methods in Natural Language Processing}, 85 | year = {2017}, 86 | publisher = {Association for Computational Linguistics}, 87 | pages = {2654--2659} 88 | } 89 | ``` 90 | -------------------------------------------------------------------------------- /conll2003.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import sys 4 | import codecs 5 | import subprocess 6 | from collections import defaultdict 7 | 8 | import numpy as np 9 | 10 | sys.path.append("/home/danniel/Desktop/CONLL2012-intern") 11 | import pstree 12 | import head_finder 13 | 14 | from rnn import Node 15 | 16 | dataset = "conll2003" 17 | character_file = os.path.join(dataset, "character.txt") 18 | word_file = os.path.join(dataset, "word.txt") 19 | pos_file = os.path.join(dataset, "pos.txt") 20 | ne_file = os.path.join(dataset, "ne.txt") 21 | pretrained_word_file = os.path.join(dataset, "word.npy") 22 | pretrained_embedding_file = os.path.join(dataset, "embedding.npy") 23 | 24 | senna_path = "/home/danniel/Downloads/senna/hash" 25 | lexicon_meta_list = [ 26 | {"ne": "PER", "path": os.path.join(dataset, "senna_per.txt"), "senna": os.path.join(senna_path, "ner.per.lst")}, 27 | {"ne": "ORG", "path": os.path.join(dataset, "senna_org.txt"), "senna": os.path.join(senna_path, "ner.org.lst")}, 28 | {"ne": "LOC", "path": os.path.join(dataset, "senna_loc.txt"), "senna": os.path.join(senna_path, "ner.loc.lst")}, 29 | {"ne": "MISC", "path": os.path.join(dataset, "senna_misc.txt"), "senna": os.path.join(senna_path, "ner.misc.lst")}] 30 | 31 | data_path = "/home/danniel/Downloads/ner" 32 | glove_file = "/home/danniel/Downloads/glove.840B.300d.txt" 33 | parser_classpath = "/home/danniel/Downloads/stanford-parser-full-2015-12-09/*:" 34 | 35 | split_raw = {"train": "eng.train", "validate": "eng.testa", "test": "eng.testb"} 36 | split_sentence = {"train": "sentence_train.txt", "validate": "sentence_validate.txt", "test": "sentence_test.txt"} 37 | split_parse = {"train": "parse_train.txt", "validate": "parse_validate.txt", "test": "parse_test.txt"} 38 | 39 | def log(msg): 40 | sys.stdout.write(msg) 41 | sys.stdout.flush() 42 | return 43 | 44 | def read_list_file(file_path, encoding="utf8"): 45 | log("Read %s..." 
% file_path) 46 | 47 | with codecs.open(file_path, "r", encoding=encoding) as f: 48 | line_list = f.read().splitlines() 49 | line_to_index = {line: index for index, line in enumerate(line_list)} 50 | 51 | log(" %d lines\n" % len(line_to_index)) 52 | return line_list, line_to_index 53 | 54 | def group_sequential_label(seq_ne_list): 55 | span_ne_dict = {} 56 | 57 | start, ne = -1, None 58 | for index, label in enumerate(seq_ne_list + ["O"]): 59 | if (label[0]=="O" or label[0]=="B") and ne: 60 | span_ne_dict[(start, index)] = ne 61 | start, ne = -1, None 62 | 63 | if label[0]=="B" or (label[0]=="I" and not ne): 64 | start, ne = index, label[2:] 65 | 66 | return span_ne_dict 67 | 68 | def extract_ner(split): 69 | #with open("tmp.txt", "r") as f: 70 | with open(os.path.join(data_path, split_raw[split]), "r") as f: 71 | line_list = f.read().splitlines() 72 | 73 | sentence_list = [] 74 | ner_list = [] 75 | 76 | sentence = [] 77 | ner = [] 78 | for line in line_list[2:]: 79 | if line[:10] == "-DOCSTART-": continue 80 | if not line: 81 | if sentence: 82 | sentence_list.append(sentence) 83 | ner_list.append(group_sequential_label(ner)) 84 | sentence = [] 85 | ner = [] 86 | continue 87 | word, _, _, sequential_label = line.split() 88 | sentence.append(word) 89 | ner.append(sequential_label) 90 | """ 91 | for i, j in enumerate(sentence_list): 92 | print "" 93 | print j 94 | print ner_list[i] 95 | """ 96 | return sentence_list, ner_list 97 | """ 98 | def get_parse_tree(parse_string, pos_set=None): 99 | node = Node() 100 | 101 | # get POS 102 | header, parse_string = parse_string.split(" ", 1) 103 | node.pos = header[1:] 104 | if pos_set is not None: pos_set.add(node.pos) 105 | 106 | # bottom condition: hit a word 107 | if parse_string[0] != "(": 108 | node.word, parse_string = parse_string.split(")", 1) 109 | return node, parse_string 110 | node.word = None 111 | #node.word = "" 112 | 113 | # Process children 114 | while True: 115 | child, parse_string = get_parse_tree(parse_string, pos_set) 116 | node.add_child(child) 117 | delimiter, parse_string = parse_string[0], parse_string[1:] 118 | if delimiter == ")": 119 | return node, parse_string 120 | 121 | def print_parse_tree(node, indent): 122 | print indent, node.pos, node.word 123 | for child in node.child_list: 124 | print_parse_tree(child, indent+" ") 125 | return 126 | """ 127 | def extract_pos_from_pstree(tree, pos_set): 128 | pos_set.add(tree.label) 129 | 130 | for child in tree.subtrees: 131 | extract_pos_from_pstree(child, pos_set) 132 | return 133 | 134 | def print_pstree(node, indent): 135 | word = node.word if node.word else "" 136 | print indent + node.label + " "+ word 137 | 138 | for child in node.subtrees: 139 | print_pstree(child, indent+" ") 140 | return 141 | 142 | def prepare_dataset(): 143 | ne_set = set() 144 | word_set = set() 145 | character_set = set() 146 | pos_set = set() 147 | 148 | for split in split_raw: 149 | sentence_list, ner_list = extract_ner(split) 150 | 151 | # Procecss raw NER 152 | for ner in ner_list: 153 | for ne in ner.itervalues(): 154 | ne_set.add(ne) 155 | 156 | # Procecss raw sentences 157 | split_sentence_file = os.path.join(dataset, split_sentence[split]) 158 | with open(split_sentence_file, "w") as f: 159 | for sentence in sentence_list: 160 | f.write(" ".join(sentence)+"\n") 161 | for word in sentence: 162 | word_set.add(word) 163 | for character in word: 164 | character_set.add(character) 165 | word_set |= set(["``", "''", "-LSB-", "-RSB-", 166 | "-LRB-", "25.49,-LRB-3-yr", "6-7-LRB-3-7", 
"Videoton-LRB-*", 167 | "-RRB-", "1-RRB-266", "12.177-RRB-.", "53.04-RRB-.", "Austria-RRB-118"]) 168 | 169 | split_parse_file = os.path.join(dataset, split_parse[split]) 170 | 171 | # Generate parses 172 | with open(split_parse_file, "w") as f: 173 | subprocess.call(["java", "-cp", parser_classpath, 174 | "edu.stanford.nlp.parser.lexparser.LexicalizedParser", 175 | "-outputFormat", "oneline", "-sentences", "newline", "-tokenized", 176 | "-escaper", "edu.stanford.nlp.process.PTBEscapingProcessor", 177 | "edu/stanford/nlp/models/lexparser/englishRNN.ser.gz", 178 | split_sentence_file], stdout=f) 179 | 180 | # Process parses 181 | with open(split_parse_file, "r") as f: 182 | line_list = f.read().splitlines() 183 | for line in line_list: 184 | #get_parse_tree(line, pos_set) 185 | tree = pstree.tree_from_text(line) 186 | extract_pos_from_pstree(tree, pos_set) 187 | 188 | with open(ne_file, "w") as f: 189 | for ne in sorted(ne_set): 190 | f.write(ne + '\n') 191 | 192 | with open(word_file, "w") as f: 193 | for word in sorted(word_set): 194 | f.write(word + '\n') 195 | 196 | with open(character_file, "w") as f: 197 | for character in sorted(character_set): 198 | f.write(character + '\n') 199 | 200 | with open(pos_file, "w") as f: 201 | for pos in sorted(pos_set): 202 | f.write(pos + '\n') 203 | return 204 | 205 | def extract_glove_embeddings(): 206 | log("extract_glove_embeddings()...") 207 | 208 | _, word_to_index = read_list_file(word_file) 209 | word_list = [] 210 | embedding_list = [] 211 | with open(glove_file, "r") as f: 212 | for line in f: 213 | line = line.strip().split() 214 | word = line[0] 215 | if word not in word_to_index: continue 216 | embedding = np.array([float(i) for i in line[1:]]) 217 | word_list.append(word) 218 | embedding_list.append(embedding) 219 | 220 | np.save(pretrained_word_file, word_list) 221 | np.save(pretrained_embedding_file, embedding_list) 222 | 223 | log(" %d pre-trained words\n" % len(word_list)) 224 | return 225 | 226 | def construct_node(node, tree, ner_raw_data, head_raw_data, text_raw_data, 227 | character_to_index, word_to_index, pos_to_index, index_to_lexicon, 228 | pos_count, ne_count, pos_ne_count, lexicon_hits, span_to_node): 229 | pos = tree.label 230 | word = tree.word 231 | span = tree.span 232 | head = tree.head if hasattr(tree, "head") else head_raw_data[(span, pos)][1] 233 | ne = ner_raw_data[span] if span in ner_raw_data else "NONE" 234 | constituent = " ".join(text_raw_data[span[0]:span[1]]).lower() 235 | 236 | # Process pos info 237 | node.pos_index = pos_to_index[pos] 238 | pos_count[pos] += 1 239 | 240 | # Process word info 241 | node.word_split = [character_to_index[character] for character in word] if word else [] 242 | node.word_index = word_to_index[word] if word else -1 243 | 244 | # Process head info 245 | node.head_split = [character_to_index[character] for character in head] 246 | #if head == "-LSB-": print text_raw_data 247 | node.head_index = word_to_index[head] 248 | 249 | # Process ne info 250 | node.ne = ne 251 | if ne != "NONE": 252 | if not node.parent or node.parent.span!=span: 253 | ne_count[ne] += 1 254 | pos_ne_count[pos] += 1 255 | 256 | # Process span info 257 | node.span = span 258 | span_to_node[span] = node 259 | 260 | # Process lexicon info 261 | node.lexicon_hit = [0] * len(index_to_lexicon) 262 | hits = 0 263 | for index, lexicon in index_to_lexicon.iteritems(): 264 | if constituent in lexicon: 265 | node.lexicon_hit[index] = 1 266 | hits = 1 267 | lexicon_hits[0] += hits 268 | 269 | # Binarize children 270 | if 
len(tree.subtrees) > 2: 271 | side_child_pos = tree.subtrees[-1].label 272 | side_child_span = tree.subtrees[-1].span 273 | side_child_head = head_raw_data[(side_child_span, side_child_pos)][1] 274 | if side_child_head != head: 275 | sub_subtrees = tree.subtrees[:-1] 276 | else: 277 | sub_subtrees = tree.subtrees[1:] 278 | new_span = (sub_subtrees[0].span[0], sub_subtrees[-1].span[1]) 279 | new_tree = pstree.PSTree(label=pos, span=new_span, subtrees=sub_subtrees) 280 | new_tree.head = head 281 | if side_child_head != head: 282 | tree.subtrees = [new_tree, tree.subtrees[-1]] 283 | else: 284 | tree.subtrees = [tree.subtrees[0], new_tree] 285 | 286 | # Process children 287 | nodes = 1 288 | for subtree in tree.subtrees: 289 | child = Node() 290 | node.add_child(child) 291 | child_nodes = construct_node(child, subtree, ner_raw_data, head_raw_data, text_raw_data, 292 | character_to_index, word_to_index, pos_to_index, index_to_lexicon, 293 | pos_count, ne_count, pos_ne_count, lexicon_hits, span_to_node) 294 | nodes += child_nodes 295 | return nodes 296 | 297 | def create_dense_nodes(ner_raw_data, text_raw_data, pos_to_index, index_to_lexicon, 298 | pos_count, ne_count, pos_ne_count, lexicon_hits, span_to_node): 299 | node_list = [] 300 | max_dense_span = 3 301 | # Start from bigram, since all unigrams are already covered by parses 302 | for span_length in range(2, 1+max_dense_span): 303 | for span_start in range(0, 1+len(text_raw_data)-span_length): 304 | span = (span_start, span_start+span_length) 305 | if span in span_to_node: continue 306 | pos = "NONE" 307 | ne = ner_raw_data[span] if span in ner_raw_data else "NONE" 308 | constituent = " ".join(text_raw_data[span[0]:span[1]]).lower() 309 | 310 | # span, child 311 | # TODO: sibling 312 | node = Node() 313 | node_list.append(node) 314 | node.span = span 315 | span_to_node[span] = node 316 | node.child_list = [span_to_node[(span[0],span[1]-1)], span_to_node[(span[0]+1,span[1])]] 317 | 318 | # word, head, pos 319 | node.pos_index = pos_to_index[pos] 320 | pos_count[pos] += 1 321 | node.word_split = [] 322 | node.word_index = -1 323 | node.head_split = [] 324 | node.head_index = -1 325 | 326 | # ne 327 | node.ne = ne 328 | if ne != "NONE": 329 | ne_count[ne] += 1 330 | pos_ne_count[pos] += 1 331 | 332 | # lexicon 333 | node.lexicon_hit = [0] * len(index_to_lexicon) 334 | hits = 0 335 | for index, lexicon in index_to_lexicon.iteritems(): 336 | if constituent in lexicon: 337 | node.lexicon_hit[index] = 1 338 | hits = 1 339 | lexicon_hits[0] += hits 340 | 341 | return node_list 342 | 343 | def get_tree_data(sentence_list, parse_list, ner_list, 344 | character_to_index, word_to_index, pos_to_index, index_to_lexicon): 345 | log("get_tree_data()...") 346 | """ Get tree structured data from CoNLL-2003 347 | 348 | Stores into Node data structure 349 | """ 350 | tree_pyramid_list = [] 351 | word_count = 0 352 | pos_count = defaultdict(lambda: 0) 353 | ne_count = defaultdict(lambda: 0) 354 | pos_ne_count = defaultdict(lambda: 0) 355 | lexicon_hits = [0] 356 | 357 | for index, parse in enumerate(parse_list): 358 | text_raw_data = sentence_list[index] 359 | word_count += len(text_raw_data) 360 | span_to_node = {} 361 | head_raw_data = head_finder.collins_find_heads(parse) 362 | 363 | root_node = Node() 364 | nodes = construct_node( 365 | root_node, parse, ner_list[index], head_raw_data, text_raw_data, 366 | character_to_index, word_to_index, pos_to_index, index_to_lexicon, 367 | pos_count, ne_count, pos_ne_count, lexicon_hits, span_to_node) 368 | 
root_node.nodes = nodes 369 | root_node.tokens = len(text_raw_data) 370 | 371 | additional_node_list = create_dense_nodes( 372 | ner_list[index], text_raw_data, 373 | pos_to_index, index_to_lexicon, 374 | pos_count, ne_count, pos_ne_count, lexicon_hits, span_to_node) 375 | 376 | tree_pyramid_list.append((root_node, additional_node_list)) 377 | 378 | log(" %d sentences\n" % len(tree_pyramid_list)) 379 | return tree_pyramid_list, word_count, pos_count, ne_count, pos_ne_count, lexicon_hits[0] 380 | 381 | def label_tree_data(node, pos_to_index, ne_to_index): 382 | node.y = ne_to_index[node.ne] 383 | # node.y = ne_to_index[":".join(node.ner)] 384 | 385 | for child in node.child_list: 386 | label_tree_data(child, pos_to_index, ne_to_index) 387 | return 388 | 389 | def read_dataset(data_split_list = ["train", "validate", "test"]): 390 | # Read all raw data 391 | sentence_data = {} 392 | ner_data = {} 393 | parse_data = {} 394 | for split in data_split_list: 395 | sentence_data[split], ner_data[split] = extract_ner(split) 396 | 397 | split_parse_file = os.path.join(dataset, split_parse[split]) 398 | with open(split_parse_file, "r") as f: 399 | line_list = f.read().splitlines() 400 | parse_data[split] = [pstree.tree_from_text(line) for line in line_list] 401 | 402 | # Read lists of annotations 403 | character_list, character_to_index = read_list_file(character_file) 404 | word_list, word_to_index = read_list_file(word_file) 405 | pos_list, pos_to_index = read_list_file(pos_file) 406 | ne_list, ne_to_index = read_list_file(ne_file) 407 | 408 | pos_to_index["NONE"] = len(pos_to_index) 409 | 410 | # Read lexicon 411 | index_to_lexicon = {} 412 | for index, meta in enumerate(lexicon_meta_list): 413 | _, index_to_lexicon[index] = read_list_file(meta["senna"], "iso8859-15") 414 | 415 | # Build a tree structure for each sentence 416 | data = {} 417 | word_count = {} 418 | pos_count = {} 419 | ne_count = {} 420 | pos_ne_count = {} 421 | lexicon_hits = {} 422 | for split in data_split_list: 423 | (tree_pyramid_list, 424 | word_count[split], pos_count[split], ne_count[split], pos_ne_count[split], 425 | lexicon_hits[split]) = get_tree_data( 426 | sentence_data[split], parse_data[split], ner_data[split], 427 | character_to_index, word_to_index, pos_to_index, index_to_lexicon) 428 | data[split] = {"tree_pyramid_list": tree_pyramid_list, "ner_list": ner_data[split]} 429 | 430 | # Show statistics of each data split 431 | print "-" * 80 432 | print "%10s%10s%9s%9s%7s%12s%13s" % ("split", "sentence", "token", "node", "NE", "spanned_NE", 433 | "lexicon_hit") 434 | print "-" * 80 435 | for split in data_split_list: 436 | print "%10s%10d%9d%9d%7d%12d%13d" % (split, 437 | len(data[split]["tree_pyramid_list"]), 438 | word_count[split], 439 | sum(pos_count[split].itervalues()), 440 | sum(len(ner) for ner in data[split]["ner_list"]), 441 | sum(ne_count[split].itervalues()), 442 | lexicon_hits[split]) 443 | 444 | # Show POS distribution 445 | total_pos_count = defaultdict(lambda: 0) 446 | for split in data_split_list: 447 | for pos in pos_count[split]: 448 | total_pos_count[pos] += pos_count[split][pos] 449 | nodes = sum(total_pos_count.itervalues()) 450 | print "\nTotal %d nodes" % nodes 451 | print "-"*80 + "\n POS count ratio\n" + "-"*80 452 | for pos, count in sorted(total_pos_count.iteritems(), key=lambda x: x[1], reverse=True)[:10]: 453 | print "%6s %7d %5.1f%%" % (pos, count, count*100./nodes) 454 | 455 | # Show NE distribution in [train, validate] 456 | total_ne_count = defaultdict(lambda: 0) 457 | for split in 
data_split_list: 458 | if split == "test": continue 459 | for ne in ne_count[split]: 460 | total_ne_count[ne] += ne_count[split][ne] 461 | nes = sum(total_ne_count.itervalues()) 462 | print "\nTotal %d spanned named entities in [train, validate]" % nes 463 | print "-"*80 + "\n NE count ratio\n" + "-"*80 464 | for ne, count in sorted(total_ne_count.iteritems(), key=lambda x: x[1], reverse=True): 465 | print "%12s %6d %5.1f%%" % (ne, count, count*100./nes) 466 | 467 | # Show POS-NE distribution 468 | total_pos_ne_count = defaultdict(lambda: 0) 469 | for split in data_split_list: 470 | if split == "test": continue 471 | for pos in pos_ne_count[split]: 472 | total_pos_ne_count[pos] += pos_ne_count[split][pos] 473 | print "-"*80 + "\n POS NE total ratio\n" + "-"*80 474 | for pos, count in sorted(total_pos_ne_count.iteritems(), key=lambda x: x[1], reverse=True)[:10]: 475 | total = total_pos_count[pos] 476 | print "%6s %6d %7d %5.1f%%" % (pos, count, total, count*100./total) 477 | 478 | # Compute the mapping to labels 479 | ne_to_index["NONE"] = len(ne_to_index) 480 | 481 | # Add label to nodes 482 | for split in data_split_list: 483 | for tree, pyramid in data[split]["tree_pyramid_list"]: 484 | label_tree_data(tree, pos_to_index, ne_to_index) 485 | for node in pyramid: 486 | node.y = ne_to_index[node.ne] 487 | 488 | return (data, word_list, ne_list, 489 | len(character_to_index), len(pos_to_index), len(ne_to_index), len(index_to_lexicon)) 490 | 491 | if __name__ == "__main__": 492 | #prepare_dataset() 493 | """ 494 | print "" 495 | parse_string = "(ROOT (S (NP (NNP EU)) (VP (VBZ rejects) (NP (JJ German) (NN call)) (PP (TO to) (NP (NN boycott) (JJ British) (NN lamb)))) (. .)))" 496 | root = pstree.tree_from_text(parse_string) 497 | print_pstree(root, "") 498 | print "" 499 | for i, j in head_finder.collins_find_heads(root).iteritems(): print i, j 500 | """ 501 | #extract_glove_embeddings() 502 | read_dataset() 503 | exit() 504 | 505 | 506 | 507 | 508 | 509 | 510 | 511 | 512 | 513 | 514 | 515 | -------------------------------------------------------------------------------- /conll2003dep.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import sys 4 | import time 5 | import codecs 6 | import difflib 7 | import subprocess 8 | from collections import defaultdict 9 | 10 | import numpy as np 11 | 12 | sys.path.append("/home/danniel/Desktop/CONLL2012-intern") 13 | import pstree 14 | import head_finder 15 | 16 | from rnn import Node 17 | import dependency_utils 18 | 19 | dataset = "conll2003dep" 20 | character_file = os.path.join(dataset, "character.txt") 21 | word_file = os.path.join(dataset, "word.txt") 22 | pos_file = os.path.join(dataset, "pos.txt") 23 | ne_file = os.path.join(dataset, "ne.txt") 24 | pretrained_word_file = os.path.join(dataset, "word.npy") 25 | pretrained_embedding_file = os.path.join(dataset, "embedding.npy") 26 | 27 | lexicon_phrase_file = os.path.join(dataset, "lexicon_phrase.npy") 28 | lexicon_embedding_file = os.path.join(dataset, "lexicon_embedding.npy") 29 | senna_path = "/home/danniel/Downloads/senna/hash" 30 | dbpedia_path = "/home/danniel/Desktop/dbpedia_lexicon" 31 | lexicon_meta_list = [ 32 | {"ne": "PER", "encoding": "iso8859-15", "clean": os.path.join(dataset, "senna_per.txt"), "raw": os.path.join(senna_path, "ner.per.lst")}, 33 | {"ne": "ORG", "encoding": "iso8859-15", "clean": os.path.join(dataset, "senna_org.txt"), "raw": os.path.join(senna_path, "ner.org.lst")}, 34 | {"ne": "LOC", "encoding": 
"iso8859-15", "clean": os.path.join(dataset, "senna_loc.txt"), "raw": os.path.join(senna_path, "ner.loc.lst")}, 35 | {"ne": "MISC", "encoding": "iso8859-15", "clean": os.path.join(dataset, "senna_misc.txt"), "raw": os.path.join(senna_path, "ner.misc.lst")} 36 | #{"ne": "PER", "encoding": "utf8", "clean": os.path.join(dataset, "dbpedia_per.txt"), "raw": os.path.join(dbpedia_path, "dbpedia_person.txt")}, 37 | #{"ne": "ORG", "encoding": "utf8", "clean": os.path.join(dataset, "dbpedia_org.txt"), "raw": os.path.join(dbpedia_path, "dbpedia_organisation.txt")}, 38 | #{"ne": "LOC", "encoding": "utf8", "clean": os.path.join(dataset, "dbpedia_loc.txt"), "raw": os.path.join(dbpedia_path, "dbpedia_place.txt")}, 39 | #{"ne": "MISC", "encoding": "utf8", "clean": os.path.join(dataset, "dbpedia_misc.txt"), "raw": os.path.join(dbpedia_path, "dbpedia_work.txt")} 40 | ] 41 | 42 | project_path = "/home/danniel/Desktop/rnn_ner" 43 | parse_script = os.path.join(project_path, dataset, "parse.sh") 44 | 45 | data_path = "/home/danniel/Downloads/ner" 46 | glove_file = "/home/danniel/Downloads/glove.840B.300d.txt" 47 | syntaxnet_path = "/home/danniel/Downloads/tf_models/syntaxnet" 48 | 49 | split_raw = {"train": "eng.train", "validate": "eng.testa", "test": "eng.testb"} 50 | split_sentence = {"train": "sentence_train.conllu", "validate": "sentence_validate.conllu", "test": "sentence_test.conllu"} 51 | split_dependency = {"train": "dependency_train.conllu", "validate": "dependency_validate.conllu", "test": "dependency_test.conllu"} 52 | split_constituency = {"train": "constituency_train.txt", "validate": "constituency_validate.txt", "test": "constituency_test.txt"} 53 | 54 | def log(msg): 55 | sys.stdout.write(msg) 56 | sys.stdout.flush() 57 | return 58 | 59 | def read_list_file(file_path, encoding="utf8"): 60 | log("Read %s..." 
% file_path) 61 | 62 | with codecs.open(file_path, "r", encoding=encoding) as f: 63 | line_list = f.read().splitlines() 64 | line_to_index = {line: index for index, line in enumerate(line_list)} 65 | 66 | log(" %d lines\n" % len(line_to_index)) 67 | return line_list, line_to_index 68 | 69 | def group_sequential_label(seq_ne_list): 70 | span_ne_dict = {} 71 | 72 | start, ne = -1, None 73 | for index, label in enumerate(seq_ne_list + ["O"]): 74 | if (label[0]=="O" or label[0]=="B") and ne: 75 | span_ne_dict[(start, index)] = ne 76 | start, ne = -1, None 77 | 78 | if label[0]=="B" or (label[0]=="I" and not ne): 79 | start, ne = index, label[2:] 80 | 81 | return span_ne_dict 82 | 83 | def extract_ner(split): 84 | with open(os.path.join(data_path, split_raw[split]), "r") as f: 85 | line_list = f.read().splitlines() 86 | 87 | sentence_list = [] 88 | ner_list = [] 89 | 90 | sentence = [] 91 | ner = [] 92 | for line in line_list[2:]: 93 | if line[:10] == "-DOCSTART-": continue 94 | if not line: 95 | if sentence: 96 | sentence_list.append(sentence) 97 | ner_list.append(group_sequential_label(ner)) 98 | sentence = [] 99 | ner = [] 100 | continue 101 | word, _, _, sequential_label = line.split() 102 | sentence.append(word) 103 | ner.append(sequential_label) 104 | 105 | return sentence_list, ner_list 106 | """ 107 | def get_parse_tree(parse_string, pos_set=None): 108 | node = Node() 109 | 110 | # get POS 111 | header, parse_string = parse_string.split(" ", 1) 112 | node.pos = header[1:] 113 | if pos_set is not None: pos_set.add(node.pos) 114 | 115 | # bottom condition: hit a word 116 | if parse_string[0] != "(": 117 | node.word, parse_string = parse_string.split(")", 1) 118 | return node, parse_string 119 | node.word = None 120 | #node.word = "" 121 | 122 | # Process children 123 | while True: 124 | child, parse_string = get_parse_tree(parse_string, pos_set) 125 | node.add_child(child) 126 | delimiter, parse_string = parse_string[0], parse_string[1:] 127 | if delimiter == ")": 128 | return node, parse_string 129 | 130 | def print_parse_tree(node, indent): 131 | print indent, node.pos, node.word 132 | for child in node.child_list: 133 | print_parse_tree(child, indent+" ") 134 | return 135 | """ 136 | def extract_pos_from_tree(tree, pos_set): 137 | pos_set.add(tree.pos) 138 | 139 | for child in tree.child_list: 140 | extract_pos_from_tree(child, pos_set) 141 | return 142 | 143 | def prepare_dataset(): 144 | ne_set = set() 145 | word_set = set() 146 | character_set = set() 147 | pos_set = set() 148 | 149 | for split in split_raw: 150 | sentence_list, ner_list = extract_ner(split) 151 | 152 | # Procecss raw NER 153 | for ner in ner_list: 154 | for ne in ner.itervalues(): 155 | ne_set.add(ne) 156 | 157 | # Procecss raw sentences and store into conllu format 158 | sentence_file = os.path.join(dataset, split_sentence[split]) 159 | with open(sentence_file, "w") as f: 160 | for sentence in sentence_list: 161 | f.write("#" + " ".join(sentence) + "\n") 162 | for i, word in enumerate(sentence): 163 | f.write("%d\t"%(i+1) + word + "\t_"*8 + "\n") 164 | word_set.add(word) 165 | for character in word: 166 | character_set.add(character) 167 | f.write("\n") 168 | 169 | # Generate dependency parses 170 | subprocess.call([parse_script, split], cwd=syntaxnet_path) 171 | 172 | # Transform dependency parses to constituency parses 173 | dependency_file = os.path.join(dataset, split_dependency[split]) 174 | dependency_list = dependency_utils.read_conllu(dependency_file) 175 | for dependency_parse in dependency_list: 176 | 
constituency_parse = dependency_utils.dependency_to_constituency(*dependency_parse) 177 | extract_pos_from_tree(constituency_parse, pos_set) 178 | 179 | with open(ne_file, "w") as f: 180 | for ne in sorted(ne_set): 181 | f.write(ne + '\n') 182 | 183 | with open(word_file, "w") as f: 184 | for word in sorted(word_set): 185 | f.write(word + '\n') 186 | 187 | with open(character_file, "w") as f: 188 | for character in sorted(character_set): 189 | f.write(character + '\n') 190 | 191 | with open(pos_file, "w") as f: 192 | for pos in sorted(pos_set): 193 | f.write(pos + '\n') 194 | return 195 | 196 | def traverse_tree(node, ner_raw_data, text_raw_data, lexicon_list, span_set): 197 | span_set.add(node.span) 198 | node.ne = ner_raw_data[node.span] if node.span in ner_raw_data else "NONE" 199 | node.constituent = " ".join(text_raw_data[node.span[0]:node.span[1]]).lower() 200 | 201 | for index, lexicon in enumerate(lexicon_list): 202 | #if node.constituent in lexicon and node.ne != lexicon_meta_list[index]["ne"]: del lexicon[node.constituent] 203 | #difflib.get_close_matches(node.constituent, lexicon[ne].iterkeys(), 1, 0.8) 204 | #all(difflib.SequenceMatcher(a=node.constituent, b=phrase).ratio() < 0.8 for phrase in lexicon[ne]) 205 | if node.constituent in lexicon: 206 | lexicon[node.constituent][0] += 1 207 | if node.ne == lexicon_meta_list[index]["ne"]: 208 | lexicon[node.constituent][1] += 1 209 | 210 | # Process children 211 | for child in node.child_list: 212 | traverse_tree(child, ner_raw_data, text_raw_data, lexicon_list, span_set) 213 | return 214 | 215 | def traverse_pyramid(ner_raw_data, text_raw_data, lexicon_list, span_set): 216 | max_dense_span = 3 217 | # Start from bigram, since all unigrams are already covered by parses 218 | for span_length in range(2, 1+max_dense_span): 219 | for span_start in range(0, 1+len(text_raw_data)-span_length): 220 | span = (span_start, span_start+span_length) 221 | if span in span_set: continue 222 | ne = ner_raw_data[span] if span in ner_raw_data else "NONE" 223 | constituent = " ".join(text_raw_data[span[0]:span[1]]).lower() 224 | 225 | for index, lexicon in enumerate(lexicon_list): 226 | if constituent in lexicon: 227 | lexicon[constituent][0] += 1 228 | if ne == lexicon_meta_list[index]["ne"]: 229 | lexicon[constituent][1] += 1 230 | return 231 | 232 | def extract_clean_lexicon(): 233 | lexicon_list = [] 234 | 235 | print "\nReading raw lexicons..." 236 | for meta in lexicon_meta_list: 237 | lexicon_list.append(read_list_file(meta["raw"], encoding=meta["encoding"])[1]) 238 | 239 | print "-"*50 + "\n ne phrases shortest\n" + "-"*50 240 | for index, lexicon in enumerate(lexicon_list): 241 | for phrase in lexicon: 242 | lexicon[phrase] = [0.,0.] 
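        # Each lexicon phrase now maps to a pair of counters:
        #   count[0] = times the phrase occurs as a parse/pyramid span in the train/validate splits,
        #   count[1] = times such an occurrence is annotated with this lexicon's NE type.
        # traverse_tree() and traverse_pyramid() fill these in below; phrases with
        # count[0] > 0 and count[1]/count[0] < 0.1 are then deleted, i.e. the lexicon is
        # pruned to entries with at least 10% empirical precision on train+validate.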
243 | shortest_phrase = min(lexicon.iterkeys(), key=lambda phrase: len(phrase)) 244 | print "%5s %8d %s" % (lexicon_meta_list[index]["ne"], len(lexicon), shortest_phrase) 245 | 246 | log("\nReading training data...") 247 | data_split_list = ["train", "validate"] 248 | sentence_data = {} 249 | ner_data = {} 250 | parse_data = {} 251 | for split in data_split_list: 252 | sentence_data[split], ner_data[split] = extract_ner(split) 253 | 254 | dependency_file = os.path.join(dataset, split_dependency[split]) 255 | dependency_parse_list = dependency_utils.read_conllu(dependency_file) 256 | parse_data[split] = [dependency_utils.dependency_to_constituency(*parse) 257 | for parse in dependency_parse_list] 258 | log(" done\n") 259 | 260 | log("\nCleaning lexicon by training data...") 261 | for split in data_split_list: 262 | for index, parse in enumerate(parse_data[split]): 263 | span_set = set() 264 | traverse_tree(parse, ner_data[split][index], sentence_data[split][index], lexicon_list, 265 | span_set) 266 | traverse_pyramid(ner_data[split][index], sentence_data[split][index], lexicon_list, 267 | span_set) 268 | log(" done\n") 269 | 270 | print "-"*50 + "\n ne phrases shortest\n" + "-"*50 271 | for index, lexicon in enumerate(lexicon_list): 272 | for phrase, count in lexicon.items(): 273 | if count[0]>0 and count[1]/count[0]<0.1: 274 | del lexicon[phrase] 275 | shortest_phrase = min(lexicon.iterkeys(), key=lambda phrase: len(phrase)) 276 | print "%5s %8d %s" % (lexicon_meta_list[index]["ne"], len(lexicon), shortest_phrase) 277 | 278 | for index, lexicon in enumerate(lexicon_list): 279 | meta = lexicon_meta_list[index] 280 | with codecs.open(meta["clean"], "w", encoding=meta["encoding"]) as f: 281 | for phrase in sorted(lexicon.iterkeys()): 282 | f.write("%s\n" % phrase) 283 | return 284 | 285 | def extract_lexicon_embeddings(): 286 | log("extract_lexicon_embeddings()...") 287 | 288 | # Read senna lexicon 289 | lexicon = defaultdict(lambda: [0]*len(lexicon_meta_list)) 290 | for index, meta in enumerate(lexicon_meta_list): 291 | for phrase in read_list_file(meta["path"], "iso8859-15")[0]: 292 | lexicon[phrase][index] = 1 293 | 294 | # Create embeddings 295 | phrase_list, embedding_list = zip(*lexicon.iteritems()) 296 | np.save(lexicon_phrase_file, phrase_list) 297 | np.save(lexicon_embedding_file, embedding_list) 298 | 299 | log(" %d phrases in lexicon\n" % len(phrase_list)) 300 | return 301 | 302 | def extract_glove_embeddings(): 303 | log("extract_glove_embeddings()...") 304 | 305 | _, word_to_index = read_list_file(word_file) 306 | word_list = [] 307 | embedding_list = [] 308 | with open(glove_file, "r") as f: 309 | for line in f: 310 | line = line.strip().split() 311 | word = line[0] 312 | if word not in word_to_index: continue 313 | embedding = np.array([float(i) for i in line[1:]]) 314 | word_list.append(word) 315 | embedding_list.append(embedding) 316 | 317 | np.save(pretrained_word_file, word_list) 318 | np.save(pretrained_embedding_file, embedding_list) 319 | 320 | log(" %d pre-trained words\n" % len(word_list)) 321 | return 322 | 323 | def construct_node(node, ner_raw_data, text_raw_data, 324 | character_to_index, word_to_index, pos_to_index, lexicon_list, 325 | pos_count, ne_count, pos_ne_count, lexicon_hits, span_to_node): 326 | pos = node.pos 327 | word = node.word 328 | head = node.head 329 | span = node.span 330 | ne = ner_raw_data[span] if span in ner_raw_data else "NONE" 331 | constituent = " ".join(text_raw_data[span[0]:span[1]]).lower() 332 | 333 | # Process pos info 334 | 
node.pos_index = pos_to_index[pos] 335 | pos_count[pos] += 1 336 | 337 | # Process word info 338 | node.word_split = [character_to_index[character] for character in word] if word else [] 339 | node.word_index = word_to_index[word] if word else -1 340 | 341 | # Process head info 342 | node.head_split = [character_to_index[character] for character in head] 343 | node.head_index = word_to_index[head] 344 | 345 | # Process ne info 346 | node.ne = ne 347 | if ne != "NONE": 348 | if not node.parent or node.parent.span!=span: 349 | ne_count[ne] += 1 350 | pos_ne_count[pos] += 1 351 | 352 | # Process span info 353 | span_to_node[span] = node 354 | 355 | # Process lexicon info 356 | node.lexicon_hit = [0] * len(lexicon_list) 357 | hits = 0 358 | for index, lexicon in enumerate(lexicon_list): 359 | if constituent in lexicon: 360 | lexicon[constituent] += 1 361 | node.lexicon_hit[index] = 1 362 | hits = 1 363 | lexicon_hits[0] += hits 364 | 365 | # Process children 366 | nodes = 1 367 | for child in node.child_list: 368 | child_nodes = construct_node(child, ner_raw_data, text_raw_data, 369 | character_to_index, word_to_index, pos_to_index, lexicon_list, 370 | pos_count, ne_count, pos_ne_count, lexicon_hits, span_to_node) 371 | nodes += child_nodes 372 | return nodes 373 | 374 | def create_dense_nodes(ner_raw_data, text_raw_data, pos_to_index, lexicon_list, 375 | pos_count, ne_count, pos_ne_count, lexicon_hits, span_to_node): 376 | node_list = [] 377 | max_dense_span = 3 378 | # Start from bigram, since all unigrams are already covered by parses 379 | for span_length in range(2, 1+max_dense_span): 380 | for span_start in range(0, 1+len(text_raw_data)-span_length): 381 | span = (span_start, span_start+span_length) 382 | if span in span_to_node: continue 383 | pos = "NONE" 384 | ne = ner_raw_data[span] if span in ner_raw_data else "NONE" 385 | constituent = " ".join(text_raw_data[span[0]:span[1]]).lower() 386 | 387 | # span, child 388 | # TODO: sibling 389 | node = Node(family=1) 390 | node_list.append(node) 391 | node.span = span 392 | span_to_node[span] = node 393 | node.child_list = [span_to_node[(span[0],span[1]-1)], span_to_node[(span[0]+1,span[1])]] 394 | 395 | # word, head, pos 396 | node.pos_index = pos_to_index[pos] 397 | pos_count[pos] += 1 398 | node.word_split = [] 399 | node.word_index = -1 400 | node.head_split = [] 401 | node.head_index = -1 402 | 403 | # ne 404 | node.ne = ne 405 | if ne != "NONE": 406 | ne_count[ne] += 1 407 | pos_ne_count[pos] += 1 408 | 409 | # lexicon 410 | node.lexicon_hit = [0] * len(lexicon_list) 411 | hits = 0 412 | for index, lexicon in enumerate(lexicon_list): 413 | if constituent in lexicon: 414 | lexicon[constituent] += 1 415 | node.lexicon_hit[index] = 1 416 | hits = 1 417 | lexicon_hits[0] += hits 418 | 419 | return node_list 420 | 421 | def get_tree_data(sentence_list, parse_list, ner_list, 422 | character_to_index, word_to_index, pos_to_index, lexicon_list): 423 | log("get_tree_data()...") 424 | """ Get tree structured data from CoNLL-2003 425 | 426 | Stores into Node data structure 427 | """ 428 | tree_pyramid_list = [] 429 | word_count = 0 430 | pos_count = defaultdict(lambda: 0) 431 | ne_count = defaultdict(lambda: 0) 432 | pos_ne_count = defaultdict(lambda: 0) 433 | lexicon_hits = [0] 434 | 435 | for index, parse in enumerate(parse_list): 436 | text_raw_data = sentence_list[index] 437 | word_count += len(text_raw_data) 438 | span_to_node = {} 439 | 440 | nodes = construct_node( 441 | parse, ner_list[index], text_raw_data, 442 | character_to_index, 
word_to_index, pos_to_index, lexicon_list, 443 | pos_count, ne_count, pos_ne_count, lexicon_hits, span_to_node) 444 | parse.nodes = nodes 445 | 446 | #additional_node_list = [] 447 | additional_node_list = create_dense_nodes( 448 | ner_list[index], text_raw_data, 449 | pos_to_index, lexicon_list, 450 | pos_count, ne_count, pos_ne_count, lexicon_hits, span_to_node) 451 | 452 | tree_pyramid_list.append((parse, additional_node_list)) 453 | 454 | log(" %d sentences\n" % len(tree_pyramid_list)) 455 | return tree_pyramid_list, word_count, pos_count, ne_count, pos_ne_count, lexicon_hits[0] 456 | 457 | def label_tree_data(node, pos_to_index, ne_to_index): 458 | node.y = ne_to_index[node.ne] 459 | 460 | for child in node.child_list: 461 | label_tree_data(child, pos_to_index, ne_to_index) 462 | return 463 | 464 | def read_dataset(data_split_list = ["train", "validate", "test"]): 465 | # Read all raw data 466 | sentence_data = {} 467 | ner_data = {} 468 | parse_data = {} 469 | for split in data_split_list: 470 | sentence_data[split], ner_data[split] = extract_ner(split) 471 | 472 | dependency_file = os.path.join(dataset, split_dependency[split]) 473 | dependency_parse_list = dependency_utils.read_conllu(dependency_file) 474 | parse_data[split] = [dependency_utils.dependency_to_constituency(*parse) 475 | for parse in dependency_parse_list] 476 | 477 | # Read lists of annotations 478 | character_list, character_to_index = read_list_file(character_file) 479 | word_list, word_to_index = read_list_file(word_file) 480 | pos_list, pos_to_index = read_list_file(pos_file) 481 | ne_list, ne_to_index = read_list_file(ne_file) 482 | 483 | pos_to_index["NONE"] = len(pos_to_index) 484 | 485 | # Read lexicon 486 | lexicon_list = [] 487 | for meta in lexicon_meta_list: 488 | lexicon_list.append(read_list_file(meta["raw"], encoding=meta["encoding"])[1]) 489 | #lexicon_list.append(read_list_file(meta["clean"], encoding=meta["encoding"])[1]) 490 | 491 | for lexicon in lexicon_list: 492 | for phrase in lexicon: 493 | lexicon[phrase] = 0 494 | 495 | # Build a tree structure for each sentence 496 | data = {} 497 | word_count = {} 498 | pos_count = {} 499 | ne_count = {} 500 | pos_ne_count = {} 501 | lexicon_hits = {} 502 | for split in data_split_list: 503 | (tree_pyramid_list, 504 | word_count[split], pos_count[split], ne_count[split], pos_ne_count[split], 505 | lexicon_hits[split]) = get_tree_data( 506 | sentence_data[split], parse_data[split], ner_data[split], 507 | character_to_index, word_to_index, pos_to_index, lexicon_list) 508 | data[split] = {"tree_pyramid_list": tree_pyramid_list, "ner_list": ner_data[split]} 509 | 510 | for index, lexicon in enumerate(lexicon_list): 511 | with codecs.open("tmp_%d.txt" % index, "w", encoding="utf8") as f: 512 | for name, count in sorted(lexicon.iteritems(), key=lambda x: (-x[1], x[0])): 513 | if count == 0: break 514 | f.write("%9d %s\n" % (count, name)) 515 | 516 | # Show statistics of each data split 517 | print "-" * 80 518 | print "%10s%10s%9s%9s%7s%12s%13s" % ("split", "sentence", "token", "node", "NE", "spanned_NE", 519 | "lexicon_hit") 520 | print "-" * 80 521 | for split in data_split_list: 522 | print "%10s%10d%9d%9d%7d%12d%13d" % (split, 523 | len(data[split]["tree_pyramid_list"]), 524 | word_count[split], 525 | sum(pos_count[split].itervalues()), 526 | sum(len(ner) for ner in data[split]["ner_list"]), 527 | sum(ne_count[split].itervalues()), 528 | lexicon_hits[split]) 529 | 530 | # Show POS distribution 531 | total_pos_count = defaultdict(lambda: 0) 532 | for split in 
data_split_list: 533 | for pos in pos_count[split]: 534 | total_pos_count[pos] += pos_count[split][pos] 535 | nodes = sum(total_pos_count.itervalues()) 536 | print "\nTotal %d nodes" % nodes 537 | print "-"*80 + "\n POS count ratio\n" + "-"*80 538 | for pos, count in sorted(total_pos_count.iteritems(), key=lambda x: x[1], reverse=True)[:10]: 539 | print "%6s %7d %5.1f%%" % (pos, count, count*100./nodes) 540 | 541 | # Show NE distribution in [train, validate] 542 | total_ne_count = defaultdict(lambda: 0) 543 | for split in data_split_list: 544 | if split == "test": continue 545 | for ne in ne_count[split]: 546 | total_ne_count[ne] += ne_count[split][ne] 547 | nes = sum(total_ne_count.itervalues()) 548 | print "\nTotal %d spanned named entities in [train, validate]" % nes 549 | print "-"*80 + "\n NE count ratio\n" + "-"*80 550 | for ne, count in sorted(total_ne_count.iteritems(), key=lambda x: x[1], reverse=True): 551 | print "%12s %6d %5.1f%%" % (ne, count, count*100./nes) 552 | 553 | # Show POS-NE distribution in [train, validate] 554 | total_pos_ne_count = defaultdict(lambda: 0) 555 | for split in data_split_list: 556 | if split == "test": continue 557 | for pos in pos_ne_count[split]: 558 | total_pos_ne_count[pos] += pos_ne_count[split][pos] 559 | print "-"*80 + "\n POS NE total ratio\n" + "-"*80 560 | for pos, count in sorted(total_pos_ne_count.iteritems(), key=lambda x: x[1], reverse=True)[:10]: 561 | total = total_pos_count[pos] 562 | print "%6s %6d %7d %5.1f%%" % (pos, count, total, count*100./total) 563 | 564 | # Compute the mapping to labels 565 | ne_to_index["NONE"] = len(ne_to_index) 566 | 567 | # Add label to nodes 568 | for split in data_split_list: 569 | for tree, pyramid in data[split]["tree_pyramid_list"]: 570 | label_tree_data(tree, pos_to_index, ne_to_index) 571 | for node in pyramid: 572 | node.y = ne_to_index[node.ne] 573 | 574 | return (data, word_list, ne_list, 575 | len(character_to_index), len(pos_to_index), len(ne_to_index), len(lexicon_list)) 576 | 577 | if __name__ == "__main__": 578 | #prepare_dataset() 579 | #extract_glove_embeddings() 580 | #extract_clean_lexicon() 581 | #extract_lexicon_embeddings() 582 | read_dataset() 583 | exit() 584 | 585 | 586 | 587 | 588 | 589 | 590 | 591 | 592 | 593 | 594 | 595 | -------------------------------------------------------------------------------- /conll2003dep/parse.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | PARSER_EVAL=/home/danniel/Downloads/tf_models/syntaxnet/bazel-bin/syntaxnet/parser_eval 4 | MODEL_DIR=/home/danniel/Downloads/tf_models/syntaxnet/syntaxnet/models/parsey_mcparseface 5 | TASK_SPEC=/home/danniel/Desktop/rnn_ner/conll2003dep/task.pbtxt 6 | 7 | $PARSER_EVAL \ 8 | --input=conll2003_sentence_$1 \ 9 | --output=stdout-conll \ 10 | --hidden_layer_sizes=64 \ 11 | --arg_prefix=brain_tagger \ 12 | --graph_builder=structured \ 13 | --task_context=$TASK_SPEC \ 14 | --model_path=$MODEL_DIR/tagger-params \ 15 | --slim_model \ 16 | --batch_size=1024 \ 17 | | \ 18 | $PARSER_EVAL \ 19 | --input=stdin-conll \ 20 | --output=conll2003_parse_$1 \ 21 | --hidden_layer_sizes=512,512 \ 22 | --arg_prefix=brain_parser \ 23 | --graph_builder=structured \ 24 | --task_context=$TASK_SPEC \ 25 | --model_path=$MODEL_DIR/parser-params \ 26 | --slim_model \ 27 | --batch_size=1024 \ 28 | 29 | 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /conll2003dep/task.pbtxt: 
-------------------------------------------------------------------------------- 1 | Parameter { 2 | name: "brain_parser_embedding_dims" 3 | value: "32;32;64" 4 | } 5 | Parameter { 6 | name: "brain_parser_embedding_names" 7 | value: "labels;tags;words" 8 | } 9 | Parameter { 10 | name: 'brain_parser_scoring' 11 | value: 'default' 12 | } 13 | Parameter { 14 | name: "brain_parser_features" 15 | value: 16 | 'stack.child(1).label ' 17 | 'stack.child(1).sibling(-1).label ' 18 | 'stack.child(-1).label ' 19 | 'stack.child(-1).sibling(1).label ' 20 | 'stack.child(2).label ' 21 | 'stack.child(-2).label ' 22 | 'stack(1).child(1).label ' 23 | 'stack(1).child(1).sibling(-1).label ' 24 | 'stack(1).child(-1).label ' 25 | 'stack(1).child(-1).sibling(1).label ' 26 | 'stack(1).child(2).label ' 27 | 'stack(1).child(-2).label; ' 28 | 'input.token.tag ' 29 | 'input(1).token.tag ' 30 | 'input(2).token.tag ' 31 | 'input(3).token.tag ' 32 | 'stack.token.tag ' 33 | 'stack.child(1).token.tag ' 34 | 'stack.child(1).sibling(-1).token.tag ' 35 | 'stack.child(-1).token.tag ' 36 | 'stack.child(-1).sibling(1).token.tag ' 37 | 'stack.child(2).token.tag ' 38 | 'stack.child(-2).token.tag ' 39 | 'stack(1).token.tag ' 40 | 'stack(1).child(1).token.tag ' 41 | 'stack(1).child(1).sibling(-1).token.tag ' 42 | 'stack(1).child(-1).token.tag ' 43 | 'stack(1).child(-1).sibling(1).token.tag ' 44 | 'stack(1).child(2).token.tag ' 45 | 'stack(1).child(-2).token.tag ' 46 | 'stack(2).token.tag ' 47 | 'stack(3).token.tag; ' 48 | 'input.token.word ' 49 | 'input(1).token.word ' 50 | 'input(2).token.word ' 51 | 'input(3).token.word ' 52 | 'stack.token.word ' 53 | 'stack.child(1).token.word ' 54 | 'stack.child(1).sibling(-1).token.word ' 55 | 'stack.child(-1).token.word ' 56 | 'stack.child(-1).sibling(1).token.word ' 57 | 'stack.child(2).token.word ' 58 | 'stack.child(-2).token.word ' 59 | 'stack(1).token.word ' 60 | 'stack(1).child(1).token.word ' 61 | 'stack(1).child(1).sibling(-1).token.word ' 62 | 'stack(1).child(-1).token.word ' 63 | 'stack(1).child(-1).sibling(1).token.word ' 64 | 'stack(1).child(2).token.word ' 65 | 'stack(1).child(-2).token.word ' 66 | 'stack(2).token.word ' 67 | 'stack(3).token.word ' 68 | } 69 | Parameter { 70 | name: "brain_parser_transition_system" 71 | value: "arc-standard" 72 | } 73 | 74 | Parameter { 75 | name: "brain_tagger_embedding_dims" 76 | value: "8;16;16;16;16;64" 77 | } 78 | Parameter { 79 | name: "brain_tagger_embedding_names" 80 | value: "other;prefix2;prefix3;suffix2;suffix3;words" 81 | } 82 | Parameter { 83 | name: "brain_tagger_features" 84 | value: 85 | 'input.digit ' 86 | 'input.hyphen; ' 87 | 'input.prefix(length="2") ' 88 | 'input(1).prefix(length="2") ' 89 | 'input(2).prefix(length="2") ' 90 | 'input(3).prefix(length="2") ' 91 | 'input(-1).prefix(length="2") ' 92 | 'input(-2).prefix(length="2") ' 93 | 'input(-3).prefix(length="2") ' 94 | 'input(-4).prefix(length="2"); ' 95 | 'input.prefix(length="3") ' 96 | 'input(1).prefix(length="3") ' 97 | 'input(2).prefix(length="3") ' 98 | 'input(3).prefix(length="3") ' 99 | 'input(-1).prefix(length="3") ' 100 | 'input(-2).prefix(length="3") ' 101 | 'input(-3).prefix(length="3") ' 102 | 'input(-4).prefix(length="3"); ' 103 | 'input.suffix(length="2") ' 104 | 'input(1).suffix(length="2") ' 105 | 'input(2).suffix(length="2") ' 106 | 'input(3).suffix(length="2") ' 107 | 'input(-1).suffix(length="2") ' 108 | 'input(-2).suffix(length="2") ' 109 | 'input(-3).suffix(length="2") ' 110 | 'input(-4).suffix(length="2"); ' 111 | 'input.suffix(length="3") ' 112 | 
'input(1).suffix(length="3") ' 113 | 'input(2).suffix(length="3") ' 114 | 'input(3).suffix(length="3") ' 115 | 'input(-1).suffix(length="3") ' 116 | 'input(-2).suffix(length="3") ' 117 | 'input(-3).suffix(length="3") ' 118 | 'input(-4).suffix(length="3"); ' 119 | 'input.token.word ' 120 | 'input(1).token.word ' 121 | 'input(2).token.word ' 122 | 'input(3).token.word ' 123 | 'input(-1).token.word ' 124 | 'input(-2).token.word ' 125 | 'input(-3).token.word ' 126 | 'input(-4).token.word ' 127 | } 128 | Parameter { 129 | name: "brain_tagger_transition_system" 130 | value: "tagger" 131 | } 132 | 133 | input { 134 | name: "tag-map" 135 | Part { 136 | file_pattern: "syntaxnet/models/parsey_mcparseface/tag-map" 137 | } 138 | } 139 | input { 140 | name: "tag-to-category" 141 | Part { 142 | file_pattern: "syntaxnet/models/parsey_mcparseface/fine-to-universal.map" 143 | } 144 | } 145 | input { 146 | name: "word-map" 147 | Part { 148 | file_pattern: "syntaxnet/models/parsey_mcparseface/word-map" 149 | } 150 | } 151 | input { 152 | name: "label-map" 153 | Part { 154 | file_pattern: "syntaxnet/models/parsey_mcparseface/label-map" 155 | } 156 | } 157 | input { 158 | name: "prefix-table" 159 | Part { 160 | file_pattern: "syntaxnet/models/parsey_mcparseface/prefix-table" 161 | } 162 | } 163 | input { 164 | name: "suffix-table" 165 | Part { 166 | file_pattern: "syntaxnet/models/parsey_mcparseface/suffix-table" 167 | } 168 | } 169 | input { 170 | name: 'stdin' 171 | record_format: 'english-text' 172 | Part { 173 | file_pattern: '-' 174 | } 175 | } 176 | input { 177 | name: 'stdin-conll' 178 | record_format: 'conll-sentence' 179 | Part { 180 | file_pattern: '-' 181 | } 182 | } 183 | input { 184 | name: 'stdout-conll' 185 | record_format: 'conll-sentence' 186 | Part { 187 | file_pattern: '-' 188 | } 189 | } 190 | 191 | input { 192 | name: 'conll2003_sentence_train' 193 | record_format: 'conll-sentence' 194 | Part { 195 | file_pattern: '/home/danniel/Desktop/rnn_ner/conll2003dep/sentence_train.conllu' 196 | } 197 | } 198 | input { 199 | name: 'conll2003_sentence_validate' 200 | record_format: 'conll-sentence' 201 | Part { 202 | file_pattern: '/home/danniel/Desktop/rnn_ner/conll2003dep/sentence_validate.conllu' 203 | } 204 | } 205 | input { 206 | name: 'conll2003_sentence_test' 207 | record_format: 'conll-sentence' 208 | Part { 209 | file_pattern: '/home/danniel/Desktop/rnn_ner/conll2003dep/sentence_test.conllu' 210 | } 211 | } 212 | 213 | input { 214 | name: 'conll2003_parse_train' 215 | record_format: 'conll-sentence' 216 | Part { 217 | file_pattern: '/home/danniel/Desktop/rnn_ner/conll2003dep/dependency_train.conllu' 218 | } 219 | } 220 | input { 221 | name: 'conll2003_parse_validate' 222 | record_format: 'conll-sentence' 223 | Part { 224 | file_pattern: '/home/danniel/Desktop/rnn_ner/conll2003dep/dependency_validate.conllu' 225 | } 226 | } 227 | input { 228 | name: 'conll2003_parse_test' 229 | record_format: 'conll-sentence' 230 | Part { 231 | file_pattern: '/home/danniel/Desktop/rnn_ner/conll2003dep/dependency_test.conllu' 232 | } 233 | } 234 | 235 | 236 | 237 | -------------------------------------------------------------------------------- /dependency_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import sys 4 | import codecs 5 | import subprocess 6 | from collections import deque, defaultdict 7 | 8 | sys.path.append("/home/danniel/Desktop/CONLL2012-intern") 9 | import pstree 10 | 11 | from rnn import Node 12 | 13 | def 
get_reversed_head_list(head_list): 14 | words = len(head_list) 15 | left_list = [[] for i in head_list] 16 | right_list = [[] for i in head_list] 17 | root = -1 18 | 19 | for index, head in enumerate(head_list): 20 | if head == -1: 21 | root = index 22 | elif index < head: 23 | left_list[head].append(index) 24 | elif index > head: 25 | right_list[head].append(index) 26 | 27 | reversed_head_list = [] 28 | for index in xrange(len(head_list)): 29 | reversed_head_list.append(left_list[index][::-1] + right_list[index]) 30 | return reversed_head_list, root 31 | 32 | def read_conllu(conllu_path): 33 | with open(conllu_path, "r") as f: 34 | line_list = f.read().splitlines() 35 | 36 | sentence_list = [] 37 | 38 | word_list, pos_list, head_list, relation_list = [], [], [], [] 39 | for line in line_list: 40 | if not line: 41 | head_list, root = get_reversed_head_list(head_list) 42 | sentence_list.append([word_list, pos_list, head_list, relation_list, root]) 43 | word_list, pos_list, head_list, relation_list = [], [], [], [] 44 | continue 45 | _, word, _, _, pos, _, head, relation, _, _ = line.split("\t") 46 | word_list.append(word) 47 | pos_list.append(pos) 48 | head_list.append(int(head)-1) 49 | relation_list.append(relation) 50 | 51 | return sentence_list 52 | 53 | def dependency_to_constituency(word_list, pos_list, head_list, relation_list, index): 54 | leaf = Node() 55 | leaf.word = word_list[index] 56 | leaf.pos = pos_list[index] 57 | leaf.span = (index, index+1) 58 | leaf.head = leaf.word 59 | 60 | root = leaf 61 | for child_index in head_list[index]: 62 | child_root = dependency_to_constituency(word_list, pos_list, head_list, relation_list, 63 | child_index) 64 | new_root = Node() 65 | new_root.word = None 66 | new_root.pos = relation_list[child_index] 67 | if child_index < index: 68 | new_root.span = (child_root.span[0], root.span[1]) 69 | new_root.add_child(child_root) 70 | new_root.add_child(root) 71 | else: 72 | new_root.span = (root.span[0], child_root.span[1]) 73 | new_root.add_child(root) 74 | new_root.add_child(child_root) 75 | new_root.head = root.head 76 | root = new_root 77 | 78 | return root 79 | 80 | def show_tree(node, indent): 81 | word = node.word if node.word else "" 82 | print indent + node.pos + " " + word + " " + repr(node.span) 83 | 84 | for child in node.child_list: 85 | show_tree(child, indent+" ") 86 | return 87 | 88 | if __name__ == "__main__": 89 | #read_conllu("conll2003dep/dependency_train.conllu") 90 | 91 | """ 92 | rlist, root = get_reversed_head_list([3,3,3,-1,6,6,3,6,11,10,11,6]) 93 | print "root", root 94 | for i, j in enumerate(rlist): 95 | print i, j 96 | """ 97 | 98 | sentence_list = read_conllu("conll2003dep/tmp.conllu") 99 | for word_list, pos_list, head_list, relation_list, root in sentence_list: 100 | print word_list 101 | print pos_list 102 | print head_list 103 | print relation_list 104 | print root 105 | sentence = sentence_list[0] 106 | root = dependency_to_constituency(*sentence) 107 | show_tree(root, "") 108 | 109 | exit() -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import random 5 | import codecs 6 | import argparse 7 | 8 | import numpy as np 9 | import tensorflow as tf 10 | 11 | import rnn 12 | 13 | batch_nodes = 500 14 | batch_trees = 16 15 | patience = 20 16 | max_epoches = 100 17 | 18 | def load_embedding(model, word_list, dataset): 19 | """ Load pre-trained word 
embeddings into the dictionary of model 20 | 21 | word.npy: an array of pre-trained words 22 | embedding.npy: a 2d array of pre-trained word vectors 23 | word_list: a list of words in the dictionary of model 24 | """ 25 | # load pre-trained word embeddings from file 26 | word_array = np.load(os.path.join(dataset, "word.npy")) 27 | embedding_array = np.load(os.path.join(dataset, "embedding.npy")) 28 | word_to_embedding = {} 29 | for i, word in enumerate(word_array): 30 | word_to_embedding[word] = embedding_array[i] 31 | 32 | # Store pre-trained word embeddings into the dictionary of model 33 | L = model.sess.run(model.L) 34 | for index, word in enumerate(word_list): 35 | if word in word_to_embedding: 36 | L[index] = word_to_embedding[word] 37 | model.sess.run(model.L.assign(L)) 38 | return 39 | 40 | def load_lexicon_embedding(model, dataset): 41 | """ Load lexicon embeddings into the lexicon dictionary of model 42 | 43 | lexicon_embedding.npy: a 2d array of phrase vectors 44 | """ 45 | embedding_array = np.load(os.path.join(dataset, "lexicon_embedding.npy")) 46 | model.sess.run(model.L_phrase.assign(embedding_array)) 47 | return 48 | 49 | def load_data_and_initialize_model(dataset, split_list=["train", "validate", "test"], 50 | use_pretrained_embedding=True): 51 | """ Get tree data and initialize a model 52 | 53 | #data: a dictionary; key-value example: "train"-(tree_list, ner_list) 54 | data: a dictionary; key-value example: 55 | "train"-{"tree_pyramid_list": tree_pyramid_list, "ner_list": ner_list} 56 | tree_pyramid_list: a list of (tree, pyramid) tuples 57 | ner_list: a list of dictionaries; key-value example: (3,5)-"PERSON" 58 | ne_list: a list of distinct string labels, e.g. "PERSON" 59 | """ 60 | # Select the implementation of loading data according to dataset 61 | if dataset == "ontonotes": 62 | import ontonotes as data_utils 63 | elif dataset == "ontochinese": 64 | import ontochinese as data_utils 65 | elif dataset == "conll2003": 66 | import conll2003 as data_utils 67 | elif dataset == "conll2003dep": 68 | import conll2003dep as data_utils 69 | 70 | # Load data and determine dataset related hyperparameters 71 | config = rnn.Config() 72 | (data, word_list, ne_list, 73 | config.alphabet_size, config.pos_dimension, config.output_dimension, config.lexicons 74 | ) = data_utils.read_dataset(split_list) 75 | config.vocabulary_size = len(word_list) 76 | 77 | # Initialize a model 78 | model = rnn.RNN(config) 79 | """ 80 | tf_config = tf.ConfigProto() 81 | tf_config.gpu_options.allow_growth = True 82 | model.sess = tf.Session(config=tf_config) 83 | """ 84 | model.sess = tf.Session() 85 | model.sess.run(tf.global_variables_initializer()) 86 | if use_pretrained_embedding: load_embedding(model, word_list, dataset) 87 | return data, ne_list, model 88 | 89 | def make_batch_list(tree_pyramid_list): 90 | """ Create a list of batches of (tree, pyramid) tuples 91 | 92 | The (tree, pyramid) tuples in the same batch have similar numbers of nodes, 93 | so later padding can be minimized. 
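    A batch is closed once adding the next tuple would exceed batch_trees tuples,
    or once the padded-size estimate (tuples in the batch times the node count of
    the incoming tuple, which is the largest so far) would exceed batch_nodes.
    The finished list of batches is shuffled before being returned.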
94 | """ 95 | index_tree_pyramid_list = sorted(enumerate(tree_pyramid_list), 96 | key=lambda x: x[1][0].nodes+len(x[1][1])) 97 | 98 | batch_list = [] 99 | batch = [] 100 | for index, tree_pyramid in index_tree_pyramid_list: 101 | nodes = tree_pyramid[0].nodes + len(tree_pyramid[1]) 102 | if len(batch)+1 > batch_trees or (len(batch)+1)*nodes > batch_nodes: 103 | batch_list.append(batch) 104 | batch = [] 105 | batch.append((index, tree_pyramid)) 106 | batch_list.append(batch) 107 | 108 | random.shuffle(batch_list) 109 | #batch_list = batch_list[::-1] 110 | return batch_list 111 | 112 | def train_an_epoch(model, tree_pyramid_list): 113 | """ Update model parameters for every tree once 114 | """ 115 | batch_list = make_batch_list(tree_pyramid_list) 116 | 117 | total_trees = len(tree_pyramid_list) 118 | trees = 0 119 | loss = 0. 120 | for i, batch in enumerate(batch_list): 121 | _, tree_pyramid_list = zip(*batch) 122 | #print "YOLO %d %d" % (tree_pyramid_list[-1][0].nodes, len(tree_pyramid_list[-1][1])) 123 | loss += model.train(tree_pyramid_list) 124 | trees += len(batch) 125 | sys.stdout.write("\r(%5d/%5d) average loss %.3f " % (trees, total_trees, loss/trees)) 126 | sys.stdout.flush() 127 | 128 | sys.stdout.write("\r" + " "*64 + "\r") 129 | return loss / total_trees 130 | 131 | def predict_dataset(model, tree_pyramid_list, ne_list): 132 | """ Get dictionarues of predicted positive spans and their labels for every tree 133 | """ 134 | batch_list = make_batch_list(tree_pyramid_list) 135 | 136 | ner_list = [None] * len(tree_pyramid_list) 137 | for batch in batch_list: 138 | index_list, tree_pyramid_list = zip(*batch) 139 | for i, span_y in enumerate(model.predict(tree_pyramid_list)): 140 | ner_list[index_list[i]] = {span: ne_list[y] for span, y in span_y.iteritems()} 141 | return ner_list 142 | 143 | def evaluate_prediction(ner_list, ner_hat_list): 144 | """ Compute the score of the prediction of trees 145 | """ 146 | reals = 0. 147 | positives = 0. 148 | true_positives = 0. 149 | for index, ner in enumerate(ner_list): 150 | ner_hat = ner_hat_list[index] 151 | reals += len(ner) 152 | positives += len(ner_hat) 153 | for span in ner_hat.iterkeys(): 154 | if span not in ner: continue 155 | if ner[span] == ner_hat[span]: 156 | true_positives += 1 157 | 158 | try: 159 | precision = true_positives / positives 160 | except ZeroDivisionError: 161 | precision = 1. 162 | 163 | try: 164 | recall = true_positives / reals 165 | except ZeroDivisionError: 166 | recall = 1. 167 | 168 | try: 169 | f1 = 2*precision*recall / (precision + recall) 170 | except ZeroDivisionError: 171 | f1 = 0. 
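    # Scores are span-level and micro-averaged over all sentences: a predicted span
    # counts as a true positive only when both its boundaries and its label match
    # a gold entity, and precision/recall default to 1 when nothing is
    # predicted/annotated.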
172 | 173 | return precision*100, recall*100, f1*100 174 | 175 | def train_model(dataset, pretrain): 176 | """ Update model parameters until it converges or reaches maximum epochs 177 | """ 178 | data, ne_list, model = load_data_and_initialize_model(dataset, 179 | use_pretrained_embedding=not pretrain) 180 | 181 | saver = tf.train.Saver() 182 | if pretrain: 183 | saver.restore(model.sess, "./tmp.model") 184 | 185 | best_epoch = 0 186 | best_score = (-1, -1, -1) 187 | best_loss = float("inf") 188 | for epoch in xrange(1, max_epoches+1): 189 | print "\nepoch %d" % epoch 190 | 191 | start_time = time.time() 192 | loss = train_an_epoch(model, data["train"]["tree_pyramid_list"]) 193 | print "[train] average loss %.3f; elapsed %.0fs" % (loss, time.time()-start_time) 194 | 195 | start_time = time.time() 196 | ner_hat_list = predict_dataset(model, data["validate"]["tree_pyramid_list"], ne_list) 197 | score = evaluate_prediction(data["validate"]["ner_list"], ner_hat_list) 198 | print "[validate] precision=%.1f%% recall=%.1f%% f1=%.3f%%; elapsed %.0fs;" % (score+(time.time()-start_time,)), 199 | 200 | if best_score[2] < score[2]: 201 | print "best" 202 | best_epoch = epoch 203 | best_score = score 204 | best_loss = loss 205 | saver.save(model.sess, "tmp.model") 206 | else: 207 | print "worse #%d" % (epoch-best_epoch) 208 | if epoch-best_epoch >= patience: break 209 | 210 | print "\nbest epoch %d" % best_epoch 211 | print "[train] average loss %.3f" % best_loss 212 | print "[validate] precision=%.1f%% recall=%.1f%% f1=%.3f%%" % best_score 213 | saver.restore(model.sess, "./tmp.model") 214 | ner_hat_list = predict_dataset(model, data["test"]["tree_pyramid_list"], ne_list) 215 | score = evaluate_prediction(data["test"]["ner_list"], ner_hat_list) 216 | print "[test] precision=%.1f%% recall=%.1f%% f1=%.3f%%" % score 217 | return 218 | 219 | def ner_diff(ner_a_list, ner_b_list): 220 | """ 221 | Compute the differences of two ner predictions 222 | 223 | ner_list: a list of the ner prediction of each sentence 224 | ner: a dict of span-ne pairs 225 | """ 226 | sentences = len(ner_a_list) 227 | print "%d sentences" % sentences 228 | print "a: %d nes" % sum(len(ner) for ner in ner_a_list) 229 | print "b: %d nes" % sum(len(ner) for ner in ner_b_list) 230 | 231 | ner_aa_list = [] 232 | ner_bb_list = [] 233 | ner_ab_list = [] 234 | for i in xrange(sentences): 235 | ner_aa = {span: ne for span, ne in ner_a_list[i].iteritems()} 236 | ner_bb = {span: ne for span, ne in ner_b_list[i].iteritems()} 237 | ner_ab = {} 238 | for span, ne in ner_aa.items(): 239 | if span in ner_bb and ner_aa[span] == ner_bb[span]: 240 | del ner_aa[span] 241 | del ner_bb[span] 242 | ner_ab[span] = ne 243 | ner_aa_list.append(ner_aa) 244 | ner_bb_list.append(ner_bb) 245 | ner_ab_list.append(ner_ab) 246 | 247 | return ner_aa_list, ner_bb_list, ner_ab_list 248 | 
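# Illustrative sketch (not from the original source): how ner_diff splits two
# predictions into a-only, b-only, and agreed span dictionaries. The helper name
# and the toy spans/labels below are made up for illustration.
def _ner_diff_example():
    ner_a_list = [{(0, 2): "PER", (3, 5): "ORG"}]
    ner_b_list = [{(0, 2): "PER", (3, 5): "LOC"}]
    only_a, only_b, agreed = ner_diff(ner_a_list, ner_b_list)
    # only_a == [{(3, 5): "ORG"}], only_b == [{(3, 5): "LOC"}], agreed == [{(0, 2): "PER"}]
    return only_a, only_b, agreed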
249 | def write_ner(target_file, text_raw_data, ner_list): 250 | """ 251 | Write the ner prediction of each sentence to file, 252 | indexing the sentences from 0. 253 | 254 | ner_list: a list of the ner prediction of each sentence 255 | ner: a dict of span-ne pairs 256 | """ 257 | print "" 258 | print target_file 259 | sentences = len(text_raw_data) 260 | print "%d sentences" % sentences 261 | 262 | with codecs.open(target_file, "w", encoding="utf8") as f: 263 | for i in xrange(sentences): 264 | if len(ner_list[i]) == 0: continue 265 | f.write("\n%d\n" % i) 266 | f.write("%s\n" % " ".join(text_raw_data[i])) 267 | for span, ne in ner_list[i].iteritems(): 268 | text_chunk = " ".join(text_raw_data[i][span[0]:span[1]]) 269 | f.write("%d %d %s <%s>\n" % (span[0], span[1], ne, text_chunk)) 270 | 271 | print "%d nes" % sum(len(ner) for ner in ner_list) 272 | return 273 | 274 | def read_ner(source_file): 275 | """ 276 | Read the ner prediction of each sentence from file. 277 | 278 | index_ner: a dict of sentence index-ner pairs 279 | ner: a dict of span-ne pairs 280 | """ 281 | with codecs.open(source_file, "r", encoding="utf8") as f: 282 | line_list = f.readlines() 283 | 284 | index_ner = {} 285 | sentence_index = -1 286 | line_index = -1 287 | while line_index+1 < len(line_list): 288 | line_index += 1 289 | line = line_list[line_index].strip().split() 290 | if not line: continue 291 | if len(line) == 1: 292 | sentence_index = int(line[0]) 293 | index_ner[sentence_index] = {} 294 | line_index += 1 295 | else: 296 | l, r, ne = line[:3] 297 | index_ner[sentence_index][(int(l),int(r))] = ne 298 | return index_ner 299 | 300 | def evaluate_model(dataset, split): 301 | """ Compute the scores of an existing model 302 | """ 303 | data, ne_list, model = load_data_and_initialize_model(dataset, split_list=[split]) 304 | 305 | saver = tf.train.Saver() 306 | saver.restore(model.sess, "./tmp.model") 307 | ner_hat_list = predict_dataset(model, data[split]["tree_pyramid_list"], ne_list) 308 | score = evaluate_prediction(data[split]["ner_list"], ner_hat_list) 309 | print "[%s]" % split + " precision=%.1f%% recall=%.1f%% f1=%.3f%%" % score 310 | """ 311 | # YOLO 312 | text_raw_data = [tree.text_raw_data for tree, pyramid in data[split]["tree_pyramid_list"]] 313 | false_negative, false_positive, correct = ner_diff(data[split]["ner_list"], ner_hat_list) 314 | write_ner("bi_fn.txt", text_raw_data, false_negative) 315 | write_ner("bi_fp.txt", text_raw_data, false_positive) 316 | """ 317 | return 318 | 319 | def compare_model(dataset, split, bad_ner_file, good_ner_file, diff_file): 320 | data, ne_list, model = load_data_and_initialize_model(dataset, split_list=[split]) 321 | text_raw_data = [tree.text_raw_data for tree, pyramid in data[split]["tree_pyramid_list"]] 322 | gold_ner_list = data[split]["ner_list"] 323 | 324 | bad_index_ner = read_ner(bad_ner_file) 325 | print "bad: %d nes" % sum(len(ner) for index, ner in bad_index_ner.iteritems()) 326 | good_index_ner = read_ner(good_ner_file) 327 | print "good: %d nes" % sum(len(ner) for index, ner in good_index_ner.iteritems()) 328 | 329 | index_span = {} 330 | for sentence_index in bad_index_ner: 331 | bad_ner = bad_index_ner[sentence_index] 332 | good_ner = good_index_ner[sentence_index] if sentence_index in good_index_ner else {} 333 | gold_ner = gold_ner_list[sentence_index] 334 | span_set = set() 335 | for span, ne in bad_ner.iteritems(): 336 | if span not in good_ner: 337 | span_set.add(span) 338 | if span_set: 339 | index_span[sentence_index] = span_set 340 | 341 | with codecs.open(diff_file, "w", encoding="utf8") as f: 342 | for sentence_index, span_set in sorted(index_span.iteritems()): 343 | f.write("\n%d\n" % 
sentence_index) 344 | f.write("%s\n" % " ".join(text_raw_data[sentence_index])) 345 | for span in span_set: 346 | text_chunk = " ".join(text_raw_data[sentence_index][span[0]:span[1]]) 347 | bad_ne = bad_index_ner[sentence_index][span] 348 | gold_ne = gold_ner_list[sentence_index][span] if span in gold_ner_list[sentence_index] else "NONE" 349 | f.write("%d %d %s->%s <%s>\n" % (span[0], span[1], bad_ne, gold_ne, text_chunk)) 350 | return 351 | 352 | def show_tree(dataset, split, tree_index): 353 | data, _, _ = load_data_and_initialize_model(dataset, split_list=[split]) 354 | tree = data[split]["tree_pyramid_list"][tree_index][0] 355 | tree.show_tree() 356 | return 357 | 358 | def main(): 359 | parser = argparse.ArgumentParser() 360 | parser.add_argument("-m", dest="mode", default="train", 361 | choices=["train", "evaluate", "compare", "showtree"]) 362 | parser.add_argument("-s", dest="split", default="validate", 363 | choices=["train", "validate", "test"]) 364 | parser.add_argument("-d", dest="dataset", default="ontonotes", 365 | choices=["ontonotes", "ontochinese", "conll2003", "conll2003dep"]) 366 | parser.add_argument("-p", dest="pretrain", action="store_true") 367 | parser.add_argument("-i", dest="tree_index", default="24") 368 | arg = parser.parse_args() 369 | 370 | if arg.mode == "train": 371 | train_model(arg.dataset, arg.pretrain) 372 | elif arg.mode == "evaluate": 373 | evaluate_model(arg.dataset, arg.split) 374 | elif arg.mode == "compare": 375 | compare_model(arg.dataset, arg.split, "bot_fp.txt", "bi_fp.txt", "bot_to_bi.txt") 376 | elif arg.mode == "showtree": 377 | show_tree(arg.dataset, arg.split, int(arg.tree_index)) 378 | return 379 | 380 | if __name__ == "__main__": 381 | main() 382 | 383 | 384 | 385 | 386 | 387 | 388 | -------------------------------------------------------------------------------- /load_conll_2012/coreference_reading.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import sys, os 5 | import pstree, treebanks, head_finder 6 | from collections import defaultdict 7 | from StringIO import StringIO 8 | import re 9 | import codecs 10 | 11 | 12 | def get_parse_spans(parses, word ,word_index): 13 | if len(parses.subtrees) > 0: 14 | for parse in parses.subtrees: 15 | index = get_parse_spans(parse, word ,word_index) 16 | if index: 17 | return index 18 | else: 19 | if parses.word == word and parses.span[0] == word_index[0] and parses.span[1] == word_index[1]: 20 | return parses.parent.span 21 | pass 22 | 23 | def read_conll_parses(lines): 24 | in_file = StringIO(''.join(lines)) 25 | return treebanks.read_trees(in_file, treebanks.conll_read_tree) 26 | 27 | def read_conll_text(lines): 28 | text = [[]] 29 | for line in lines: 30 | line = line.strip() 31 | fields = re.split(r'\s+', line) 32 | if len(line) == 0: 33 | text.append([]) 34 | else: 35 | text[-1].append(fields[3]) 36 | if len(text[-1]) == 0: 37 | text.pop() 38 | return text 39 | 40 | def read_conll_ner(lines): 41 | info = {} 42 | word = 0 43 | sentence = 0 44 | cur = [] 45 | for line in lines: 46 | line = line.strip() 47 | fields = re.split(r'\s+', line) 48 | if len(fields) >= 11: 49 | ner_info = fields[10] 50 | if '(' in ner_info and '*' in ner_info: 51 | cur.append((ner_info[1:-1], sentence, word)) 52 | elif '(' in ner_info and ')' in ner_info: 53 | info[sentence, word, word +1] = ner_info[1:-1] 54 | elif ')' in ner_info and '*' in ner_info: 55 | start = cur.pop() 56 | if sentence != start[1]: 57 | print >> 
sys.stderr, "Something mucked up", sentence, word, start 58 | info[sentence, start[2], word +1] = start[0] 59 | word += 1 60 | if len(line) == 0: 61 | sentence += 1 62 | word = 0 63 | return info 64 | 65 | 66 | def read_conll_speakers(lines): 67 | info = {} 68 | word = 0 69 | sentence = 0 70 | for line in lines: 71 | line = line.strip() 72 | fields = re.split(r'\s+', line) 73 | if len(fields) >= 10: 74 | spk_info = fields[9] 75 | if spk_info != '-' and len(spk_info) > 1: 76 | if sentence not in info: 77 | info[sentence] = {} 78 | info[sentence][sentence, word, word + 1] = spk_info 79 | word += 1 80 | if len(line) == 0: 81 | sentence += 1 82 | word = 0 83 | return info 84 | 85 | def read_conll_fcol(lines): 86 | info = [[]] 87 | for line in lines: 88 | line = line.strip() 89 | fields = re.split(r'\s+', line) 90 | if len(line) == 0: 91 | info.append([]) 92 | else: 93 | info[-1].append(fields[0]) 94 | if len(info[-1]) == 0: 95 | info.pop() 96 | return info 97 | 98 | 99 | def read_conll_coref(lines): 100 | # Assumes: 101 | # - Reading a single part 102 | # - If duplicate mentions occur, use the first 103 | regex = "([(][0-9]*[)])|([(][0-9]*)|([0-9]*[)])|([|])" 104 | mentions = {} # (sentence, start, end+1) -> ID 105 | clusters = {} # ID -> list of (sentence, start, end+1)s 106 | unmatched_mentions = defaultdict(lambda: []) 107 | sentence = 0 108 | word = 0 109 | line_no = 0 110 | for line in lines: 111 | line_no += 1 112 | if len(line) > 0 and line[0] =='#': 113 | continue 114 | line = line.strip() 115 | if len(line) == 0: 116 | sentence += 1 117 | word = 0 118 | unmatched_mentions = defaultdict(lambda: []) 119 | continue 120 | # Canasai's comment out: fields = line.strip().split() 121 | fields = re.split(r'\s+', line.strip()) 122 | for triple in re.findall(regex, fields[-1]): 123 | if triple[1] != '': 124 | val = int(triple[1][1:]) 125 | unmatched_mentions[(sentence, val)].append(word) 126 | elif triple[0] != '' or triple[2] != '': 127 | start = word 128 | val = -1 129 | if triple[0] != '': 130 | val = int(triple[0][1:-1]) 131 | else: 132 | val = int(triple[2][:-1]) 133 | if (sentence, val) not in unmatched_mentions: 134 | print >> sys.stderr, "Ignoring a mention with no start", str(val), line.strip(), line_no 135 | continue 136 | if len(unmatched_mentions[(sentence, val)]) == 0: 137 | print >> sys.stderr, "No other start available", str(val), line.strip(), line_no 138 | continue 139 | start = unmatched_mentions[(sentence, val)].pop() 140 | end = word + 1 141 | if (sentence, start, end) in mentions: 142 | print >> sys.stderr, "Duplicate mention", sentence, start, end, val, mentions[sentence, start, end] 143 | else: 144 | mentions[sentence, start, end] = val 145 | if val not in clusters: 146 | clusters[val] = [] 147 | clusters[val].append((sentence, start, end)) 148 | word += 1 149 | for key in unmatched_mentions: 150 | if len(unmatched_mentions[key]) > 0: 151 | print >> sys.stderr, "Mention started, but did not end ", str(unmatched_mentions[key]) 152 | return mentions, clusters 153 | 154 | 155 | def read_conll_doc(filename, ans=None, rtext=True, rparses=True, rheads=True, rclusters=True, rner=True, rspeakers=True, rfcol=False): 156 | if ans is None: 157 | ans = {} 158 | cur = [] 159 | keys = None 160 | for line in codecs.open(filename, 'r', 'utf-8'): 161 | if len(line) > 0 and line.startswith('#begin') or line.startswith('#end'): 162 | if 'begin' in line: 163 | desc = line.split() 164 | location = desc[2].strip('();') 165 | keys = (location, desc[-1]) 166 | if len(cur) > 0: 167 | if keys is 
None: 168 | print >> sys.stderr, "Error reading conll file - invalid #begin statemen\n", line 169 | else: 170 | info = {} 171 | if rtext: 172 | info['text'] = read_conll_text(cur) 173 | if rparses: 174 | info['parses'] = read_conll_parses(cur) 175 | if rheads: 176 | info['heads'] = [head_finder.collins_find_heads(parse) for parse in info['parses']] 177 | if rclusters: 178 | info['mentions'], info['clusters'] = read_conll_coref(cur) 179 | if rner: 180 | info['ner'] = read_conll_ner(cur) 181 | if rspeakers: 182 | info['speakers'] = read_conll_speakers(cur) 183 | if rfcol: 184 | info['fcol'] = read_conll_fcol(cur) 185 | if keys[0] not in ans: 186 | ans[keys[0]] = {} 187 | ans[keys[0]][keys[1]] = info 188 | keys = None 189 | cur = [] 190 | else: 191 | cur.append(line) 192 | return ans 193 | -------------------------------------------------------------------------------- /load_conll_2012/head_finder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # vim: set ts=2 sw=2 noet: 4 | 5 | import sys 6 | 7 | #TODO: Handle other langauges 8 | 9 | collins_mapping_table = { 10 | 'ADJP': ('right', ['NNS', 'QP', 'NN', '$', 'ADVP', 'JJ', 'VBN', 'VBG', 'ADJP', 'JJR', 'NP', 'JJS', 'DT', 'FW', 'RBR', 'RBS', 'SBAR', 'RB']), 11 | 'ADVP': ('left', ['RB', 'RBR', 'RBS', 'FW', 'ADVP', 'TO', 'CD', 'JJR', 'JJ', 'IN', 'NP', 'JJS', 'NN']), 12 | 'CONJP': ('left', ['CC', 'RB', 'IN']), 13 | 'FRAG': ('left', []), 14 | 'INTJ': ('right', []), 15 | 'LST': ('left', ['LS', ':']), 16 | 'NAC': ('right', ['NN', 'NNS', 'NNP', 'NNPS', 'NP', 'NAC', 'EX', '$', 'CD', 'QP', 'PRP', 'VBG', 'JJ', 'JJS', 'JJR', 'ADJP', 'FW']), 17 | 'PP': ('left', ['IN', 'TO', 'VBG', 'VBN', 'RP', 'FW']), 18 | 'PRN': ('right', []), 19 | 'PRT': ('left', ['RP']), 20 | 'QP': ('right', ['$', 'IN', 'NNS', 'NN', 'JJ', 'RB', 'DT', 'CD', 'NCD', 'QP', 'JJR', 'JJS']), 21 | 'RRC': ('left', ['VP', 'NP', 'ADVP', 'ADJP', 'PP']), 22 | 'S': ('right', ['TO', 'IN', 'VP', 'S', 'SBAR', 'ADJP', 'UCP', 'NP']), 23 | 'SBAR': ('right', ['WHNP', 'WHPP', 'WHADVP', 'WHADJP', 'IN', 'DT', 'S', 'SQ', 'SINV', 'SBAR', 'FRAG']), 24 | 'SBARQ': ('right', ['SQ', 'S', 'SINV', 'SBARQ', 'FRAG']), 25 | 'SINV': ('right', ['VBZ', 'VBD', 'VBP', 'VB', 'MD', 'VP', 'S', 'SINV', 'ADJP', 'NP']), 26 | 'SQ': ('right', ['VBZ', 'VBD', 'VBP', 'VB', 'MD', 'VP', 'SQ']), 27 | 'UCP': ('left', []), 28 | 'VP': ('right', ['TO', 'VBD', 'VBN', 'MD', 'VBZ', 'VB', 'VBG', 'VBP', 'VP', 'ADJP', 'NN', 'NNS', 'NP']), 29 | 'WHADJP': ('right', ['CC', 'WRB', 'JJ', 'ADJP']), 30 | 'WHADVP': ('left', ['CC', 'WRB']), 31 | 'WHNP': ('right', ['WDT', 'WP', 'WP$', 'WHADJP', 'WHPP', 'WHNP']), 32 | 'WHPP': ('left', ['IN', 'TO', 'FW']), 33 | # Added by me: 34 | 'NX': ('right', ['NN', 'NNS', 'NNP', 'NNPS', 'NP', 'NAC', 'EX', '$', 'CD', 'QP', 'PRP', 'VBG', 'JJ', 'JJS', 'JJR', 'ADJP', 'FW']), 35 | 'X': ('right', ['NN', 'NNS', 'NNP', 'NNPS', 'NP', 'NAC', 'EX', '$', 'CD', 'QP', 'PRP', 'VBG', 'JJ', 'JJS', 'JJR', 'ADJP', 'FW']), 36 | 'META': ('right', []) 37 | } 38 | 39 | def add_head(head_map, tree, head): 40 | tree_repr = (tree.span, tree.label) 41 | head_map[tree_repr] = head 42 | 43 | def get_head(head_map, tree): 44 | tree_repr = (tree.span, tree.label) 45 | return head_map[tree_repr] 46 | 47 | def first_search(tree, options, head_map): 48 | for subtree in tree.subtrees: 49 | if get_head(head_map, subtree)[2] in options or subtree.label in options: 50 | add_head(head_map, tree, get_head(head_map, subtree)) 51 | return True 52 | return False 53 | 54 | 
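# Illustrative usage (a sketch, not part of the original file): collins_find_heads,
# defined below, maps every (span, label) pair of a tree to its lexical head, e.g.
# for a hypothetical tree built with pstree.tree_from_text:
#
#   tree = pstree.tree_from_text("(S (NP (DT The) (NN cat)) (VP (VBZ sleeps)))")
#   heads = collins_find_heads(tree)
#   heads[((0, 2), 'NP')]   # -> ((1, 2), 'cat', 'NN')
#   heads[((0, 3), 'S')]    # -> ((2, 3), 'sleeps', 'VBZ')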
def last_search(tree, options, head_map): 55 | for i in xrange(len(tree.subtrees) - 1, -1, -1): 56 | subtree = tree.subtrees[i] 57 | if get_head(head_map, subtree)[2] in options or subtree.label in options: 58 | add_head(head_map, tree, get_head(head_map, subtree)) 59 | return True 60 | return False 61 | 62 | # Canasai's addition begin 63 | def coordinated_search(tree, idx, options, head_map): 64 | subtree = tree.subtrees[idx] 65 | if get_head(head_map, subtree)[2] in options or subtree.label in options: 66 | add_head(head_map, tree, get_head(head_map, subtree)) 67 | return True 68 | return False 69 | # Canasai's addition end 70 | 71 | def collins_NP(tree, head_map): 72 | for subtree in tree.subtrees: 73 | collins_find_heads(subtree, head_map) 74 | #TODO:todo Extra special cases for NPs 75 | ### Ignore the row for NPs -- I use a special set of rules for this. For these 76 | ### I initially remove ADJPs, QPs, and also NPs which dominate a possesive 77 | ### (tagged POS, e.g. (NP (NP the man 's) telescope ) becomes 78 | ### (NP the man 's telescope)). These are recovered as a post-processing stage 79 | ### after parsing. The following rules are then used to recover the NP head: 80 | 81 | #TODO:todo handle NML properly 82 | # Canasai's addition begin 83 | pos_to_look = set(['NN', 'NNP', 'NNPS', 'NNS', 'NX', 'POS', 'JJR']) 84 | # NP -> NP , NP , 85 | if (len(tree.subtrees) == 4 and 86 | tree.subtrees[0].label == 'NP' and 87 | tree.subtrees[1].label == ',' and 88 | tree.subtrees[2].label == 'NP' and 89 | tree.subtrees[3].label == ','): 90 | if coordinated_search(tree, 0, pos_to_look, head_map): 91 | return 92 | # NP -> NP CC NP 93 | # NP -> NNP CC NNP 94 | # NP -> NP , NP 95 | #if tree.word_yield().startswith("Anderson"): 96 | # print tree, len(tree.subtrees) 97 | 98 | if (len(tree.subtrees) == 3 and 99 | tree.subtrees[0].label in {'NP', 'NNP'} and 100 | tree.subtrees[1].label in {'CC', ','} and 101 | tree.subtrees[2].label in {'NP', 'NNP'}): 102 | if coordinated_search(tree, 0, pos_to_look, head_map): 103 | return 104 | # NP -> NP NP 105 | if (len(tree.subtrees) == 2 and 106 | tree.subtrees[0].label == 'NP' and 107 | tree.subtrees[1].label == 'NP'): 108 | if coordinated_search(tree, 0, pos_to_look, head_map): 109 | return 110 | # Canasai's addition end 111 | 112 | if get_head(head_map, tree.subtrees[-1])[2] == 'POS': 113 | # Canasai's comment out: add_head(head_map, tree, get_head(head_map, tree.subtrees[-1])) 114 | if len(tree.subtrees) > 1: 115 | add_head(head_map, tree, get_head(head_map, tree.subtrees[-2])) 116 | else: 117 | add_head(head_map, tree, get_head(head_map, tree.subtrees[-1])) 118 | return 119 | if last_search(tree, set(['NN', 'NNP', 'NNPS', 'NNS', 'NX', 'POS', 'JJR']), head_map): 120 | return 121 | if first_search(tree, set(['NP', 'NML']), head_map): 122 | return 123 | if last_search(tree, set(['$', 'ADJP', 'PRN']), head_map): 124 | return 125 | if last_search(tree, set(['CD']), head_map): 126 | return 127 | if last_search(tree, set(['JJ', 'JJS', 'RB', 'QP']), head_map): 128 | return 129 | add_head(head_map, tree, get_head(head_map, tree.subtrees[-1])) 130 | 131 | def collins_find_heads(tree, head_map=None): 132 | if head_map is None: 133 | head_map = {} 134 | for subtree in tree.subtrees: 135 | collins_find_heads(subtree, head_map) 136 | 137 | # A word is it's own head 138 | if tree.word is not None: 139 | head = (tree.span, tree.word, tree.label) 140 | add_head(head_map, tree, head) 141 | return head_map 142 | 143 | # If the label for this node is not in the table we are 
either at the bottom, 144 | # at an NP, or have an error 145 | if tree.label not in collins_mapping_table: 146 | if tree.label in ['NP', 'NML']: 147 | collins_NP(tree, head_map) 148 | else: 149 | # TODO: Consider alternative error announcement means 150 | ### if tree.label not in ['ROOT', 'TOP', 'S1', '']: 151 | ### print >> sys.stderr, "Unknown Label: %s" % tree.label 152 | ### print >> sys.stderr, "In tree:", tree.root() 153 | add_head(head_map, tree, get_head(head_map, tree.subtrees[-1])) 154 | return head_map 155 | 156 | # Look through and take the first/last occurrence that matches 157 | info = collins_mapping_table[tree.label] 158 | for label in info[1]: 159 | for i in xrange(len(tree.subtrees)): 160 | if info[0] == 'right': 161 | i = len(tree.subtrees) - i - 1 162 | subtree = tree.subtrees[i] 163 | if subtree.label == label or get_head(head_map, subtree)[2] == label: 164 | add_head(head_map, tree, get_head(head_map, subtree)) 165 | return head_map 166 | 167 | # Final fallback 168 | if info[0] == 'left': 169 | add_head(head_map, tree, get_head(head_map, tree.subtrees[0])) 170 | else: 171 | add_head(head_map, tree, get_head(head_map, tree.subtrees[-1])) 172 | 173 | return head_map 174 | 175 | '''Text from Collins' website: 176 | 177 | This file describes the table used to identify head-words in the papers 178 | 179 | Three Generative, Lexicalised Models for Statistical Parsing (ACL/EACL97) 180 | A New Statistical Parser Based on Bigram Lexical Dependencies (ACL96) 181 | 182 | There are two parts to this file: 183 | 184 | [1] an email from David Magerman describing the head-table used in 185 | D. Magerman. 1995. Statistical Decision-Tree Models for Parsing. 186 | {\it Proceedings of the 33rd Annual Meeting of 187 | the Association for Computational Linguistics}, pages 276-283. 188 | 189 | [2] A modified version of David's head-table which I used in my experiments. 190 | 191 | Many thanks to David Magerman for allowing me to distribute his table. 192 | 193 | 194 | [1] 195 | 196 | From magerman@bbn.com Thu May 25 13:48 EDT 1995 197 | Posted-Date: Thu, 25 May 1995 13:48:07 -0400 198 | Received-Date: Thu, 25 May 1995 13:48:43 +0500 199 | Message-Id: <199505251748.NAA02892@thane.bbn.com> 200 | To: mcollins@gradient.cis.upenn.edu, robertm@unagi.cis.upenn.edu, 201 | mitch@linc.cis.upenn.edu 202 | Cc: magerman@bbn.com 203 | Subject: Re: Head words table 204 | In-Reply-To: Your message of "Thu, 25 May 1995 13:17:14 EDT." 205 | <9505251717.AA17874@gradient.cis.upenn.edu> 206 | Date: Thu, 25 May 1995 13:48:07 -0400 207 | From: David Magerman 208 | Content-Type: text 209 | Content-Length: 2972 210 | 211 | 212 | Hi all. Mike and Robert asked me for the Tree Head Table, so I 213 | thought I'd pass it along to everyone in one shot. Feel free to 214 | distribute it to whomever at Penn wants it. 215 | 216 | Note that it's not complete, and that I've invented a tag (% for the 217 | symbol %) and a label (NP$ for NP's that end in POS). I also have 218 | some optional mapping mechanisms that: (a) convert to_TO -> to_IN when 219 | in a prepositional phrase and (b) translate (PRT x_RP) -> (ADVP x_RB), 220 | thus mapping away the distinction between particles and adverbs. I 221 | currently use transformation (b) in my parser, but don't use (a). 222 | These facts may or may not be relevant, depending on how you want to 223 | use this table. 224 | 225 | Cheers, 226 | -- David 227 | 228 | Tree Head Table 229 | --------------- 230 | 231 | Instructions: 232 | 233 | 1. The first column is the non-terminal. 
The second column indicates 234 | where you start when you are looking for a head (left is for 235 | head-initial categories, right is for head-final categories). The 236 | rest of the line is a list of non-terminal and pre-terminal categories 237 | which represent the head rule. 238 | 239 | 2. ** is a wildcard value. Any non-terminal with ** in its rule means 240 | that anything can be its head. So, for a head-initial category, ** 241 | means the first word is always the head, and for a head-final 242 | category, ** means the last word is always the head. In most cases, 243 | ** means I didn't investigate good head rules for that category, so it 244 | might be worthwhile to do so yourself. 245 | 246 | 3. The Tree Head Table is used as follows: 247 | 248 | a. Use tree head rule based on NT category of constituent 249 | b. For each category X in tree head rule, scan the children of 250 | the constituent for the first (or last, for head-final) 251 | occurrence of category X. If 252 | X occurs, that child is the head. 253 | c. If no child matches any category in the list, use the first 254 | (or last, for head-final) child as the head. 255 | 256 | 4. I treat the NP category as a special case. Before consulting the 257 | head rule for NP, I look for the rightmost child with a label 258 | beginning with the letter N. If one exists, I use that child as the 259 | head. If no child's tag begins with N, I use the tree head rule. 260 | 261 | ADJP right % QP JJ VBN VBG ADJP $ JJR JJS DT FW **** RBR RBS RB 262 | ADVP left RBR RB RBS FW ADVP CD **** JJR JJS JJ 263 | CONJP left CC RB IN 264 | FRAG left ** 265 | INTJ right ** 266 | LST left LS : 267 | NAC right NN NNS NNP NNPS NP NAC EX $ CD QP PRP VBG JJ JJS JJR ADJP FW 268 | NP right EX $ CD QP PRP VBG JJ JJS JJR ADJP DT FW RB SYM PRP$ 269 | NP$ right NN NNS NNP NNPS NP NAC EX $ CD QP PRP VBG JJ JJS JJR ADJP FW SYM 270 | PNP right ** 271 | PP left IN TO FW 272 | PRN left ** 273 | PRT left RP 274 | QP right CD NCD % QP JJ JJR JJS DT 275 | RRC left VP NP ADVP ADJP PP 276 | S right VP SBAR ADJP UCP NP 277 | SBAR right S SQ SINV SBAR FRAG X 278 | SBARQ right SQ S SINV SBARQ FRAG X 279 | SINV right S VP VBZ VBD VBP VB SINV ADJP NP 280 | SQ right VP VBZ VBD VBP VB MD SQ 281 | UCP left ** 282 | VP left VBD VBN MD VBZ TO VB VP VBG VBP ADJP NP 283 | WHADJP right JJ ADJP 284 | WHADVP left WRB 285 | WHNP right WDT WP WP$ WHADJP WHPP WHNP 286 | WHPP left IN TO FW 287 | X left ** 288 | 289 | 290 | [2] 291 | 292 | Here's the head table which I used in my experiments below. The first column 293 | is just the number of fields on that line. Otherwise, the format is the same 294 | as David's. 295 | 296 | Ignore the row for NPs -- I use a special set of rules for this. For these 297 | I initially remove ADJPs, QPs, and also NPs which dominate a possesive 298 | (tagged POS, e.g. (NP (NP the man 's) telescope ) becomes 299 | (NP the man 's telescope)). These are recovered as a post-processing stage 300 | after parsing. 
The following rules are then used to recover the NP head: 301 | 302 | If the last word is tagged POS, return (last-word); 303 | 304 | Else search from right to left for the first child which is an NN, NNP, NNPS, NNS, NX, POS, or JJR 305 | 306 | Else search from left to right for first child which is an NP 307 | 308 | Else search from right to left for the first child which is a $, ADJP or PRN 309 | 310 | Else search from right to left for the first child which is a CD 311 | 312 | Else search from right to left for the first child which is a JJ, JJS, RB or QP 313 | 314 | Else return the last word 315 | 316 | 317 | 20 ADJP 0 NNS QP NN $ ADVP JJ VBN VBG ADJP JJR NP JJS DT FW RBR RBS SBAR RB 318 | 15 ADVP 1 RB RBR RBS FW ADVP TO CD JJR JJ IN NP JJS NN 319 | 5 CONJP 1 CC RB IN 320 | 2 FRAG 1 321 | 2 INTJ 0 322 | 4 LST 1 LS : 323 | 19 NAC 0 NN NNS NNP NNPS NP NAC EX $ CD QP PRP VBG JJ JJS JJR ADJP FW 324 | 8 PP 1 IN TO VBG VBN RP FW 325 | 2 PRN 0 326 | 3 PRT 1 RP 327 | 14 QP 0 $ IN NNS NN JJ RB DT CD NCD QP JJR JJS 328 | 7 RRC 1 VP NP ADVP ADJP PP 329 | 10 S 0 TO IN VP S SBAR ADJP UCP NP 330 | 13 SBAR 0 WHNP WHPP WHADVP WHADJP IN DT S SQ SINV SBAR FRAG 331 | 7 SBARQ 0 SQ S SINV SBARQ FRAG 332 | 12 SINV 0 VBZ VBD VBP VB MD VP S SINV ADJP NP 333 | 9 SQ 0 VBZ VBD VBP VB MD VP SQ 334 | 2 UCP 1 335 | 15 VP 0 TO VBD VBN MD VBZ VB VBG VBP VP ADJP NN NNS NP 336 | 6 WHADJP 0 CC WRB JJ ADJP 337 | 4 WHADVP 1 CC WRB 338 | 8 WHNP 0 WDT WP WP$ WHADJP WHPP WHNP 339 | 5 WHPP 1 IN TO FW''' 340 | 341 | 342 | if __name__ == "__main__": 343 | print "Running doctest" 344 | import doctest 345 | doctest.testmod() 346 | 347 | -------------------------------------------------------------------------------- /load_conll_2012/load_conll.py: -------------------------------------------------------------------------------- 1 | import os, sys, fnmatch 2 | import coreference_reading 3 | def load_data(config): 4 | suffix = config["file_suffix"] 5 | dir_prefix = config["dir_prefix"] 6 | print "Load conll documents from:", dir_prefix, " with suffix = ", suffix 7 | data = None 8 | count = 0 9 | source = "" 10 | for root, dirnames, filenames in os.walk(dir_prefix): 11 | #if lang not in root or sets not in root: 12 | #continue 13 | for filename in fnmatch.filter(filenames, '*' + suffix): 14 | file_path = os.path.join(root, filename) 15 | 16 | index = filename.find("_") 17 | if index == -1: 18 | source2 = filename 19 | else: 20 | source2 = filename[:index] 21 | if source != source2: 22 | source = source2 23 | print " <%s>" % source 24 | #print ' ' + filename 25 | 26 | data = coreference_reading.read_conll_doc(file_path, data) 27 | count += 1 28 | if data is None or len(data) == 0: 29 | print ("Cannot load data in '%s' with suffix '%s'" % 30 | (dir_prefix, suffix)) 31 | sys.exit(1) 32 | print "Total doc.: " + str(count) 33 | 34 | return data 35 | 36 | if __name__ == '__main__': 37 | config = {"file_suffix": "gold_conll", 38 | "dir_prefix": "conll-2012/v4/data/train/data/english/annotations/bc/cnn"} 39 | data = load_data(config) 40 | for doc in data: 41 | print 'document:', doc 42 | for part in data[doc]: 43 | yolo = False 44 | for text in data[doc][part]["text"]: 45 | if "Rumsfeld" in text: 46 | yolo = True 47 | break 48 | if not yolo: continue 49 | 50 | print 'part:', part 51 | print 'attrs.:', data[doc][part].keys() 52 | 53 | print "\narrtr: " 54 | text = data[doc][part]["text"] 55 | print type(text) 56 | print len(text) 57 | print text[0] 58 | 59 | print "\narrtr: " 60 | parses = data[doc][part]["parses"] 61 | print type(parses) 62 
| print len(parses) 63 | print parses[0] 64 | 65 | print "\narrtr: " 66 | ner = data[doc][part]["ner"] 67 | print type(ner) 68 | print len(ner) 69 | print ner 70 | 71 | print "\narrtr: " 72 | heads = data[doc][part]["heads"] 73 | print type(heads) 74 | print len(heads) 75 | for i, j in heads[0].iteritems(): print i,j 76 | print len(heads[0]) 77 | print heads[0][((5,13), u"VP")] 78 | print heads[0][((0,14), u"S")] 79 | exit() 80 | -------------------------------------------------------------------------------- /load_conll_2012/parse_errors.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # vim: set ts=2 sw=2 noet: 4 | 5 | import pstree 6 | 7 | class Parse_Error_Set: 8 | def __init__(self, gold=None, test=None, include_terminals=False): 9 | self.missing = [] 10 | self.crossing = [] 11 | self.extra = [] 12 | self.POS = [] 13 | self.spans = {} 14 | 15 | if gold is not None and test is not None: 16 | errors = get_errors(test, gold, include_terminals) 17 | for error in errors: 18 | self.add_error(error[0], error[1], error[2], error[3]) 19 | 20 | def add_error(self, etype, span, label, node): 21 | error = (etype, span, label, node) 22 | if span not in self.spans: 23 | self.spans[span] = {} 24 | if label not in self.spans[span]: 25 | self.spans[span][label] = [] 26 | self.spans[span][label].append(error) 27 | if etype == 'missing': 28 | self.missing.append(error) 29 | elif etype == 'crossing': 30 | self.crossing.append(error) 31 | elif etype == 'extra': 32 | self.extra.append(error) 33 | elif etype == 'diff POS': 34 | self.POS.append(error) 35 | 36 | def is_extra(self, node): 37 | if node.span in self.spans: 38 | if node.label in self.spans[node.span]: 39 | for error in self.spans[node.span][node.label]: 40 | if error[0] == 'extra': 41 | return True 42 | return False 43 | 44 | def __len__(self): 45 | return len(self.missing) + len(self.extra) + len(self.crossing) + (2*len(self.POS)) 46 | 47 | def get_errors(test, gold, include_terminals=False): 48 | ans = [] 49 | 50 | # Different POS 51 | if include_terminals: 52 | for tnode in test: 53 | if tnode.word is not None: 54 | for gnode in gold: 55 | if gnode.word is not None and gnode.span == tnode.span: 56 | if gnode.label != tnode.label: 57 | ans.append(('diff POS', tnode.span, tnode.label, tnode, gnode.label)) 58 | 59 | test_spans = [(span.span[0], span.span[1], span) for span in test] 60 | test_spans.sort() 61 | test_span_set = {} 62 | to_remove = [] 63 | for span in test_spans: 64 | if span[2].is_terminal(): 65 | to_remove.append(span) 66 | continue 67 | key = (span[0], span[1], span[2].label) 68 | if key not in test_span_set: 69 | test_span_set[key] = 0 70 | test_span_set[key] += 1 71 | for span in to_remove: 72 | test_spans.remove(span) 73 | 74 | gold_spans = [(span.span[0], span.span[1], span) for span in gold] 75 | gold_spans.sort() 76 | gold_span_set = {} 77 | to_remove = [] 78 | for span in gold_spans: 79 | if span[2].is_terminal(): 80 | to_remove.append(span) 81 | continue 82 | key = (span[0], span[1], span[2].label) 83 | if key not in gold_span_set: 84 | gold_span_set[key] = 0 85 | gold_span_set[key] += 1 86 | for span in to_remove: 87 | gold_spans.remove(span) 88 | 89 | # Extra 90 | for span in test_spans: 91 | key = (span[0], span[1], span[2].label) 92 | if key in gold_span_set and gold_span_set[key] > 0: 93 | gold_span_set[key] -= 1 94 | else: 95 | ans.append(('extra', span[2].span, span[2].label, span[2])) 96 | 97 | # Missing and crossing 98 | 
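  # A gold span with no exact (start, end, label) match among the test spans is
  # reported as 'missing', or as 'crossing' when some test span partially
  # overlaps it (checked below).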
for span in gold_spans: 99 | key = (span[0], span[1], span[2].label) 100 | if key in test_span_set and test_span_set[key] > 0: 101 | test_span_set[key] -= 1 102 | else: 103 | name = 'missing' 104 | for tspan in test_span_set: 105 | if tspan[0] < span[0] < tspan[1] < span[1]: 106 | name = 'crossing' 107 | break 108 | if span[0] < tspan[0] < span[1] < tspan[1]: 109 | name = 'crossing' 110 | break 111 | ans.append((name, span[2].span, span[2].label, span[2])) 112 | return ans 113 | 114 | def counts_for_prf(test, gold, include_root=False, include_terminals=False): 115 | # Note - currently assumes the roots match 116 | tcount = 0 117 | for node in test: 118 | if node.is_terminal() and not include_terminals: 119 | continue 120 | if node.parent is None and not include_root: 121 | continue 122 | tcount += 1 123 | gcount = 0 124 | for node in gold: 125 | if node.is_terminal() and not include_terminals: 126 | continue 127 | if node.parent is None and not include_root: 128 | continue 129 | gcount += 1 130 | match = tcount 131 | errors = Parse_Error_Set(gold, test, True) 132 | match = tcount - len(errors.extra) 133 | if include_terminals: 134 | match -= len(errors.POS) 135 | return match, gcount, tcount, len(errors.crossing), len(errors.POS) 136 | 137 | if __name__ == '__main__': 138 | print "No unit testing implemented for Error_Set" 139 | -------------------------------------------------------------------------------- /load_conll_2012/pstree.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from collections import defaultdict 5 | 6 | DEFAULT_LABEL = 'label_not_set' 7 | TRACE_LABEL = '-NONE-' 8 | 9 | class TreeIterator: 10 | '''Iterator for traversal of a tree. 11 | 12 | PSTree uses pre-order traversal by default, but this supports post-order too, e.g.: 13 | >>> tree = tree_from_text("(ROOT (S (NP-SBJ (NNP Ms.) (NNP Haag) ) (VP (VBZ plays) (NP (NNP Elianti) )) (. .) ))") 14 | >>> for node in TreeIterator(tree, 'post'): 15 | ... print node 16 | (NNP Ms.) 17 | (NNP Haag) 18 | (NP-SBJ (NNP Ms.) (NNP Haag)) 19 | (VBZ plays) 20 | (NNP Elianti) 21 | (NP (NNP Elianti)) 22 | (VP (VBZ plays) (NP (NNP Elianti))) 23 | (. .) 24 | (S (NP-SBJ (NNP Ms.) (NNP Haag)) (VP (VBZ plays) (NP (NNP Elianti))) (. .)) 25 | (ROOT (S (NP-SBJ (NNP Ms.) (NNP Haag)) (VP (VBZ plays) (NP (NNP Elianti))) (. .))) 26 | ''' 27 | def __init__(self, tree, order='pre'): 28 | self.tree = tree 29 | self.pos = [0] 30 | self.order = order 31 | 32 | def __iter__(self): 33 | return self 34 | 35 | def next(self): 36 | while True: 37 | if len(self.pos) == 0: 38 | raise StopIteration 39 | 40 | # For pre-order traversal, return nodes when first reached 41 | ans = None 42 | if self.order == 'pre' and self.pos[-1] == 0: 43 | ans = self.tree 44 | 45 | # Update internal state to point at the next node in the tree 46 | if self.pos[-1] < len(self.tree.subtrees): 47 | self.tree = self.tree.subtrees[self.pos[-1]] 48 | self.pos[-1] += 1 49 | self.pos.append(0) 50 | else: 51 | if self.order == 'post': 52 | ans = self.tree 53 | self.tree = self.tree.parent 54 | self.pos.pop() 55 | 56 | if ans is not None: 57 | return ans 58 | 59 | class PSTree(): 60 | '''Phrase Structure Tree 61 | 62 | >>> tree = tree_from_text("(ROOT (NP (NNP Newspaper)))") 63 | >>> print tree 64 | (ROOT (NP (NNP Newspaper))) 65 | >>> tree = tree_from_text("(ROOT (S (NP-SBJ (NNP Ms.) (NNP Haag) ) (VP (VBZ plays) (NP (NNP Elianti) )) (. .) ))") 66 | >>> print tree 67 | (ROOT (S (NP-SBJ (NNP Ms.) 
(NNP Haag)) (VP (VBZ plays) (NP (NNP Elianti))) (. .))) 68 | >>> print tree.word_yield() 69 | Ms. Haag plays Elianti . 70 | >>> tree = tree_from_text("(ROOT (NFP ...))") 71 | >>> print tree 72 | (ROOT (NFP ...)) 73 | >>> tree.word_yield() 74 | '...' 75 | >>> tree = tree_from_text("(VP (VBD was) (VP (VBN named) (S (NP-SBJ (-NONE- *-1) ) (NP-PRD (NP (DT a) (JJ nonexecutive) (NN director) ) (PP (IN of) (NP (DT this) (JJ British) (JJ industrial) (NN conglomerate) ))))))") 76 | >>> print tree 77 | (VP (VBD was) (VP (VBN named) (S (NP-SBJ (-NONE- *-1)) (NP-PRD (NP (DT a) (JJ nonexecutive) (NN director)) (PP (IN of) (NP (DT this) (JJ British) (JJ industrial) (NN conglomerate))))))) 78 | >>> tree.word_yield() 79 | 'was named *-1 a nonexecutive director of this British industrial conglomerate' 80 | ''' 81 | def __init__(self, word=None, label=DEFAULT_LABEL, span=(0, 0), parent=None, subtrees=None): 82 | self.word = word 83 | self.label = label 84 | self.span = span 85 | self.parent = parent 86 | self.subtrees = [] 87 | if subtrees is not None: 88 | self.subtrees = subtrees 89 | for subtree in subtrees: 90 | subtree.parent = self 91 | 92 | def __iter__(self): 93 | return TreeIterator(self, 'pre') 94 | 95 | def clone(self): 96 | ans = PSTree(self.word, self.label, self.span) 97 | for subtree in self.subtrees: 98 | subclone = subtree.clone() 99 | subclone.parent = ans 100 | ans.subtrees.append(subclone) 101 | return ans 102 | 103 | def is_terminal(self): 104 | '''Check if the tree has no children.''' 105 | return len(self.subtrees) == 0 106 | 107 | def is_trace(self): 108 | '''Check if this tree is the end of a trace.''' 109 | return self.label == TRACE_LABEL 110 | 111 | def root(self): 112 | '''Follow parents until a node is reached that has no parent.''' 113 | if self.parent is not None: 114 | return self.parent.root() 115 | else: 116 | return self 117 | 118 | def __repr__(self): 119 | '''Return a bracket notation style representation of the tree.''' 120 | ans = '(' 121 | if self.is_trace(): 122 | ans += TRACE_LABEL + ' ' + self.word 123 | elif self.is_terminal(): 124 | ans += self.label + ' ' + self.word 125 | else: 126 | ans += self.label 127 | for subtree in self.subtrees: 128 | ans += ' ' + subtree.__repr__() 129 | ans += ')' 130 | return ans 131 | 132 | def calculate_spans(self, left=0): 133 | '''Update the spans for every node in this tree.''' 134 | right = left 135 | if self.is_terminal(): 136 | right += 1 137 | for subtree in self.subtrees: 138 | right = subtree.calculate_spans(right) 139 | self.span = (left, right) 140 | return right 141 | 142 | def check_consistency(self): 143 | '''Check that the parents and spans are consistent with the tree 144 | structure.''' 145 | ans = True 146 | if len(self.subtrees) > 0: 147 | for i in xrange(len(self.subtrees)): 148 | subtree = self.subtrees[i] 149 | if subtree.parent != self: 150 | print "bad parent link" 151 | ans = False 152 | if i > 0 and self.subtrees[i - 1].span[1] != subtree.span[0]: 153 | print "Subtree spans don't match" 154 | ans = False 155 | ans = ans and subtree.check_consistency() 156 | if self.span != (self.subtrees[0].span[0], self.subtrees[-1].span[1]): 157 | print "Span doesn't match subtree spans" 158 | ans = False 159 | return ans 160 | 161 | def production_list(self, ans=None): 162 | '''Get a list of productions as: 163 | (node label, node span, ((subtree1, end1), (subtree2, end2)...))''' 164 | if ans is None: 165 | ans = [] 166 | if len(self.subtrees) > 0: 167 | cur = (self.label, self.span, tuple([(sub.label, sub.span[1]) for 
sub in self.subtrees])) 168 | ans.append(cur) 169 | for sub in self.subtrees: 170 | sub.production_list(ans) 171 | return ans 172 | 173 | def word_yield(self, span=None, as_list=False): 174 | '''Return the set of words at terminal nodes, either as a space separated 175 | string, or as a list.''' 176 | if self.is_terminal(): 177 | if span is None or span[0] <= self.span[0] < span[1]: 178 | if self.word is None: 179 | return None 180 | if as_list: 181 | return [self.word] 182 | else: 183 | return self.word 184 | else: 185 | return None 186 | else: 187 | ans = [] 188 | for subtree in self.subtrees: 189 | words = subtree.word_yield(span, as_list) 190 | if words is not None: 191 | if as_list: 192 | ans += words 193 | else: 194 | ans.append(words) 195 | if not as_list: 196 | ans = ' '.join(ans) 197 | return ans 198 | 199 | def node_dict(self, depth=0, node_dict=None): 200 | '''Get a dictionary of labelled nodes. Note that we use a dictionary to 201 | take into consideration unaries like (NP (NP ...))''' 202 | if node_dict is None: 203 | node_dict = defaultdict(lambda: []) 204 | for subtree in self.subtrees: 205 | subtree.node_dict(depth + 1, node_dict) 206 | node_dict[(self.label, self.span[0], self.span[1])].append(depth) 207 | return node_dict 208 | 209 | def get_nodes(self, request='all', start=-1, end=-1, node_list=None): 210 | '''Get the node(s) that have a given span. Unspecified endpoints are 211 | treated as wildcards. The request can be 'lowest', 'highest', or 'all'. 212 | For 'all', the list of nodes is in order from the highest first.''' 213 | if request not in ['highest', 'lowest', 'all']: 214 | raise Exception("%s is not a valid request" % str(request)) 215 | if request == 'lowest' and start < 0 and end < 0: 216 | raise Exception("Lowest is not well defined when both ends are wildcards") 217 | 218 | if request == 'all' and node_list is None: 219 | node_list = [] 220 | if request == 'highest': 221 | if self.span[0] == start or start < 0: 222 | if self.span[1] == end or end < 0: 223 | return self 224 | 225 | for subtree in self.subtrees: 226 | # Skip subtrees with no overlapping range 227 | if 0 < end <= subtree.span[0] or subtree.span[1] < start: 228 | continue 229 | ans = subtree.get_nodes(request, start, end, node_list) 230 | if ans is not None and request != 'all': 231 | return ans 232 | 233 | if self.span[0] == start or start < 0: 234 | if self.span[1] == end or end < 0: 235 | if request == 'lowest': 236 | return self 237 | elif request == 'all': 238 | node_list.insert(0, self) 239 | return node_list 240 | if request == 'all': 241 | return node_list 242 | else: 243 | return None 244 | 245 | def get_spanning_nodes(self, start, end, node_list=None): 246 | return_ans = False 247 | if node_list is None: 248 | return_ans = True 249 | node_list = [] 250 | 251 | if self.span[0] == start and self.span[1] <= end: 252 | node_list.append(self) 253 | start = self.span[1] 254 | else: 255 | for subtree in self.subtrees: 256 | if subtree.span[1] < start: 257 | continue 258 | start = subtree.get_spanning_nodes(start, end, node_list) 259 | if start == end: 260 | break 261 | 262 | if return_ans: 263 | if start == end: 264 | return node_list 265 | else: 266 | return None 267 | else: 268 | return start 269 | 270 | def tree_from_text(text, allow_empty_labels=False, allow_empty_words=False): 271 | '''Construct a PSTree from the provided string, which is assumed to represent 272 | a tree with nested round brackets. 
Nodes are labeled by the text between the 273 | open bracket and the next space (possibly an empty string). Words are the 274 | text after that space and before the close bracket.''' 275 | root = None 276 | cur = None 277 | pos = 0 278 | word = '' 279 | for char in text: 280 | # Consume random text up to the first '(' 281 | if cur is None: 282 | if char == '(': 283 | root = PSTree() 284 | cur = root 285 | continue 286 | 287 | if char == '(': 288 | word = word.strip() 289 | if cur.label is DEFAULT_LABEL: 290 | if len(word) == 0 and not allow_empty_labels: 291 | raise Exception("Empty label found\n%s" % text) 292 | cur.label = word 293 | word = '' 294 | if word != '': 295 | raise Exception("Stray '%s' while processing\n%s" % (word, text)) 296 | sub = PSTree() 297 | cur.subtrees.append(sub) 298 | sub.parent = cur 299 | cur = sub 300 | elif char == ')': 301 | word = word.strip() 302 | if word != '': 303 | if len(word) == 0 and not allow_empty_words: 304 | raise Exception("Empty word found\n%s" % text) 305 | cur.word = word 306 | word = '' 307 | cur.span = (pos, pos + 1) 308 | pos += 1 309 | else: 310 | cur.span = (cur.subtrees[0].span[0], cur.subtrees[-1].span[1]) 311 | cur = cur.parent 312 | elif char == ' ': 313 | if cur.label is DEFAULT_LABEL: 314 | if len(word) == 0 and not allow_empty_labels: 315 | raise Exception("Empty label found\n%s" % text) 316 | cur.label = word 317 | word = '' 318 | else: 319 | word += char 320 | else: 321 | word += char 322 | if cur is not None: 323 | raise Exception("Text did not include complete tree\n%s" % text) 324 | return root 325 | 326 | 327 | def clone_and_find(nodes): 328 | '''Clone the tree these nodes are in and finds the equivalent nodes in the 329 | new tree.''' 330 | return_list = True 331 | if type(nodes) != type([]): 332 | return_list = False 333 | nodes = [nodes] 334 | 335 | # Note the paths to the nodes 336 | paths = [] 337 | for node in nodes: 338 | paths.append([]) 339 | tree = node 340 | while tree.parent is not None: 341 | prev = tree 342 | tree = tree.parent 343 | paths[-1].append(tree.subtrees.index(prev)) 344 | 345 | # Duplicate and follow the path back to the equivalent node 346 | ntree = nodes[0].root().clone() 347 | ans = [] 348 | for path in paths: 349 | tree = ntree 350 | for index in path[::-1]: 351 | tree = tree.subtrees[index] 352 | ans.append(tree) 353 | if return_list: 354 | return ans 355 | else: 356 | return ans[0] 357 | 358 | 359 | if __name__ == '__main__': 360 | print "Running doctest" 361 | import doctest 362 | doctest.testmod() 363 | 364 | -------------------------------------------------------------------------------- /load_conll_2012/treebanks.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # vim: set ts=2 sw=2 noet: 4 | 5 | # Canasai's addition begin 6 | import re 7 | # Canasai's addition end 8 | from pstree import * 9 | 10 | # TODO: Handle malformed input with trees that have random stuff instead of symbols 11 | # For chinese I found: 12 | ### leaf nodes split across lines: 13 | ### (blah 14 | ### )) 15 | ### lone tags: 16 | ### CP (IP... 17 | 18 | # At the moment the generator can't handle blank line indicating a 19 | # failed parse. 
20 | 21 | ptb_tag_set = set(['S', 'SBAR', 'SBARQ', 'SINV', 'SQ', 'ADJP', 'ADVP', 'CONJP', 22 | 'FRAG', 'INTJ', 'LST', 'NAC', 'NP', 'NX', 'PP', 'PRN', 'PRT', 'QP', 'RRC', 23 | 'UCP', 'VP', 'WHADJP', 'WHADVP', 'WHNP', 'WHPP', 'X', 'NML']) 24 | 25 | word_to_word_mapping = { 26 | '{': '-LCB-', 27 | '}': '-RCB-' 28 | } 29 | word_to_POS_mapping = { 30 | '--': ':', 31 | '-': ':', 32 | ';': ':', 33 | ':': ':', 34 | '-LRB-': '-LRB-', 35 | '-RRB-': '-RRB-', 36 | '-LCB-': '-LRB-', 37 | '-RCB-': '-RRB-', 38 | '{': '-LRB-', 39 | '}': '-RRB-', 40 | } 41 | bugfix_word_to_POS = { 42 | 'Wa': 'NNP' 43 | } 44 | def ptb_cleaning(tree, in_place=True): 45 | '''Clean up some bugs/odd things in the PTB, and standardise punctuation.''' 46 | if not in_place: 47 | tree = tree.clone() 48 | for node in tree: 49 | # In a small number of cases multiple POS tags were assigned 50 | if '|' in node.label: 51 | if 'ADVP' in node.label: 52 | node.label = 'ADVP' 53 | else: 54 | node.label = node.label.split('|')[0] 55 | # Fix some issues with variation in output, and one error in the treebank 56 | # for a word with a punctuation POS 57 | # TODO: Look into the POS replacement leading to incorrect tagging for some 58 | # punctuation 59 | if node.word in word_to_word_mapping: 60 | node.word = word_to_word_mapping[node.word] 61 | if node.word in word_to_POS_mapping: 62 | node.label = word_to_POS_mapping[node.word] 63 | if node.word in bugfix_word_to_POS: 64 | node.label = bugfix_word_to_POS[node.word] 65 | return tree 66 | 67 | def remove_trivial_unaries(tree, in_place=True): 68 | '''Collapse A-over-A unary productions. 69 | 70 | >>> tree = tree_from_text("(ROOT (S (S (PP (PP (PP (IN By) (NP (CD 1997))))))))") 71 | >>> otree = remove_trivial_unaries(tree, False) 72 | >>> print otree 73 | (ROOT (S (PP (IN By) (NP (CD 1997))))) 74 | >>> print tree 75 | (ROOT (S (S (PP (PP (PP (IN By) (NP (CD 1997)))))))) 76 | >>> remove_trivial_unaries(tree) 77 | (ROOT (S (PP (IN By) (NP (CD 1997))))) 78 | ''' 79 | if in_place: 80 | if len(tree.subtrees) == 1 and tree.label == tree.subtrees[0].label: 81 | tree.subtrees = tree.subtrees[0].subtrees 82 | for subtree in tree.subtrees: 83 | subtree.parent = tree 84 | remove_trivial_unaries(tree, True) 85 | else: 86 | for subtree in tree.subtrees: 87 | remove_trivial_unaries(subtree, True) 88 | else: 89 | if len(tree.subtrees) == 1 and tree.label == tree.subtrees[0].label: 90 | return remove_trivial_unaries(tree.subtrees[0], False) 91 | subtrees = [remove_trivial_unaries(subtree, False) for subtree in tree.subtrees] 92 | tree = PSTree(tree.word, tree.label, tree.span, None, subtrees) 93 | for subtree in subtrees: 94 | subtree.parent = tree 95 | return tree 96 | 97 | def remove_nodes(tree, filter_func, in_place=True, preserve_subtrees=False, init_call=True): 98 | if filter_func(tree) and not preserve_subtrees: 99 | return None 100 | subtrees = [] 101 | for subtree in tree.subtrees: 102 | ans = remove_nodes(subtree, filter_func, in_place, preserve_subtrees, False) 103 | if ans is not None: 104 | if type(ans) == type([]): 105 | subtrees += ans 106 | else: 107 | subtrees.append(ans) 108 | if len(subtrees) == 0 and (not tree.is_terminal()): 109 | return None 110 | if filter_func(tree) and preserve_subtrees: 111 | return subtrees 112 | if in_place: 113 | tree.subtrees = subtrees 114 | for subtree in subtrees: 115 | subtree.parent = tree 116 | else: 117 | tree = PSTree(tree.word, tree.label, tree.span, None, subtrees) 118 | return tree 119 | 120 | def remove_traces(tree, in_place=True): 121 | '''Adjust the tree 
to remove traces. 122 | 123 | >>> tree = tree_from_text("(ROOT (S (PP (IN By) (NP (CD 1997))) (, ,) (NP (NP (ADJP (RB almost) (DT all)) (VBG remaining) (NNS uses)) (PP (IN of) (NP (JJ cancer-causing) (NN asbestos)))) (VP (MD will) (VP (VB be) (VP (VBN outlawed) (NP (-NONE- *-6))))) (. .)))") 124 | >>> remove_traces(tree, False) 125 | (ROOT (S (PP (IN By) (NP (CD 1997))) (, ,) (NP (NP (ADJP (RB almost) (DT all)) (VBG remaining) (NNS uses)) (PP (IN of) (NP (JJ cancer-causing) (NN asbestos)))) (VP (MD will) (VP (VB be) (VP (VBN outlawed)))) (. .))) 126 | ''' 127 | return remove_nodes(tree, PSTree.is_trace, in_place) 128 | 129 | def split_label_type_and_function(label): 130 | parts = label.split('=') 131 | if len(label) > 0 and label[0] != '-': 132 | cur = parts 133 | parts = [] 134 | for part in cur: 135 | parts += part.split('-') 136 | return parts 137 | 138 | def remove_function_tags(tree, in_place=True): 139 | '''Adjust the tree to remove function tags on labels. 140 | 141 | >>> tree = tree_from_text("(ROOT (S (NP-SBJ (NNP Ms.) (NNP Haag)) (VP (VBZ plays) (NP (NNP Elianti))) (. .)))") 142 | >>> remove_function_tags(tree, False) 143 | (ROOT (S (NP (NNP Ms.) (NNP Haag)) (VP (VBZ plays) (NP (NNP Elianti))) (. .))) 144 | 145 | # don't remove brackets 146 | >>> tree = tree_from_text("(ROOT (S (NP-SBJ (`` ``) (NP-TTL (NNP Funny) (NNP Business)) ('' '') (PRN (-LRB- -LRB-) (NP (NNP Soho)) (, ,) (NP (CD 228) (NNS pages)) (, ,) (NP ($ $) (CD 17.95)) (-RRB- -RRB-)) (PP (IN by) (NP (NNP Gary) (NNP Katzenstein)))) (VP (VBZ is) (NP-PRD (NP (NN anything)) (PP (RB but)))) (. .)))") 147 | >>> remove_function_tags(tree) 148 | (ROOT (S (NP (`` ``) (NP (NNP Funny) (NNP Business)) ('' '') (PRN (-LRB- -LRB-) (NP (NNP Soho)) (, ,) (NP (CD 228) (NNS pages)) (, ,) (NP ($ $) (CD 17.95)) (-RRB- -RRB-)) (PP (IN by) (NP (NNP Gary) (NNP Katzenstein)))) (VP (VBZ is) (NP (NP (NN anything)) (PP (RB but)))) (. .))) 149 | ''' 150 | label = split_label_type_and_function(tree.label)[0] 151 | if in_place: 152 | for subtree in tree.subtrees: 153 | remove_function_tags(subtree, True) 154 | tree.label = label 155 | else: 156 | subtrees = [remove_function_tags(subtree, False) for subtree in tree.subtrees] 157 | tree = PSTree(tree.word, label, tree.span, None, subtrees) 158 | for subtree in subtrees: 159 | subtree.parent = tree 160 | return tree 161 | 162 | # Applies rules to strip out the parts of the tree that are not used in the 163 | # standard evalb evaluation 164 | def apply_collins_rules(tree, in_place=True): 165 | '''Adjust the tree to remove parts not evaluated by the standard evalb 166 | config. 167 | 168 | # cutting punctuation and -X parts of labels 169 | >>> tree = tree_from_text("(ROOT (S (NP-SBJ (NNP Ms.) (NNP Haag) ) (VP (VBZ plays) (NP (NNP Elianti) )) (. .) ))") 170 | >>> apply_collins_rules(tree) 171 | (ROOT (S (NP (NNP Ms.) (NNP Haag)) (VP (VBZ plays) (NP (NNP Elianti))))) 172 | >>> print tree.word_yield() 173 | Ms. Haag plays Elianti 174 | 175 | # cutting nulls 176 | >>> tree = tree_from_text("(ROOT (S (PP-TMP (IN By) (NP (CD 1997))) (, ,) (NP-SBJ-6 (NP (ADJP (RB almost) (DT all)) (VBG remaining) (NNS uses)) (PP (IN of) (NP (JJ cancer-causing) (NN asbestos)))) (VP (MD will) (VP (VB be) (VP (VBN outlawed) (NP (-NONE- *-6))))) (. 
.)))") 177 | >>> apply_collins_rules(tree) 178 | (ROOT (S (PP (IN By) (NP (CD 1997))) (NP (NP (ADJP (RB almost) (DT all)) (VBG remaining) (NNS uses)) (PP (IN of) (NP (JJ cancer-causing) (NN asbestos)))) (VP (MD will) (VP (VB be) (VP (VBN outlawed)))))) 179 | 180 | # changing PRT to ADVP 181 | >>> tree = tree_from_text("(ROOT (S (NP-SBJ-41 (DT That) (NN fund)) (VP (VBD was) (VP (VBN put) (NP (-NONE- *-41)) (PRT (RP together)) (PP (IN by) (NP-LGS (NP (NNP Blackstone) (NNP Group)) (, ,) (NP (DT a) (NNP New) (NNP York) (NN investment) (NN bank)))))) (. .)))") 182 | >>> apply_collins_rules(tree) 183 | (ROOT (S (NP (DT That) (NN fund)) (VP (VBD was) (VP (VBN put) (ADVP (RP together)) (PP (IN by) (NP (NP (NNP Blackstone) (NNP Group)) (NP (DT a) (NNP New) (NNP York) (NN investment) (NN bank)))))))) 184 | 185 | # not removing brackets 186 | >>> tree = tree_from_text("(ROOT (S (NP-SBJ (`` ``) (NP-TTL (NNP Funny) (NNP Business)) ('' '') (PRN (-LRB- -LRB-) (NP (NNP Soho)) (, ,) (NP (CD 228) (NNS pages)) (, ,) (NP ($ $) (CD 17.95) (-NONE- *U*)) (-RRB- -RRB-)) (PP (IN by) (NP (NNP Gary) (NNP Katzenstein)))) (VP (VBZ is) (NP-PRD (NP (NN anything)) (PP (RB but) (NP (-NONE- *?*))))) (. .)))") 187 | >>> apply_collins_rules(tree) 188 | (ROOT (S (NP (NP (NNP Funny) (NNP Business)) (PRN (-LRB- -LRB-) (NP (NNP Soho)) (NP (CD 228) (NNS pages)) (NP ($ $) (CD 17.95)) (-RRB- -RRB-)) (PP (IN by) (NP (NNP Gary) (NNP Katzenstein)))) (VP (VBZ is) (NP (NP (NN anything)) (PP (RB but)))))) 189 | ''' 190 | tree = tree if in_place else tree.clone() 191 | remove_traces(tree, True) 192 | remove_function_tags(tree, True) 193 | ptb_cleaning(tree, True) 194 | 195 | # Remove Puncturation 196 | ### words_to_ignore = set(["'","`","''","``","--",":",";","-",",",".","...",".","?","!"]) 197 | labels_to_ignore = ["-NONE-",",",":","``","''","."] 198 | remove_nodes(tree, lambda(t): t.label in labels_to_ignore, True) 199 | 200 | # Set all PRTs to be ADVPs 201 | POS_to_convert = {'PRT': 'ADVP'} 202 | for node in tree: 203 | if node.label in POS_to_convert: 204 | node.label = POS_to_convert[node.label] 205 | 206 | tree.calculate_spans() 207 | return tree 208 | 209 | def homogenise_tree(tree, tag_set=ptb_tag_set): 210 | '''Change the top of the tree to be of a consistent form. 211 | 212 | >>> tree = tree_from_text("( (S (NP (NNP Example))))", True) 213 | >>> homogenise_tree(tree) 214 | (ROOT (S (NP (NNP Example)))) 215 | >>> tree = tree_from_text("( (ROOT (S (NP (NNP Example))) ) )", True) 216 | >>> homogenise_tree(tree) 217 | (ROOT (S (NP (NNP Example)))) 218 | >>> tree = tree_from_text("(S1 (S (NP (NNP Example))))") 219 | >>> homogenise_tree(tree) 220 | (ROOT (S (NP (NNP Example)))) 221 | ''' 222 | orig = tree 223 | tree = tree.root() 224 | if tree.label != 'ROOT': 225 | while split_label_type_and_function(tree.label)[0] not in tag_set: 226 | if len(tree.subtrees) > 1: 227 | break 228 | elif tree.is_terminal(): 229 | raise Exception("Tree has no labels in the tag set\n%s" % orig.__repr__()) 230 | tree = tree.subtrees[0] 231 | if split_label_type_and_function(tree.label)[0] not in tag_set: 232 | tree.label = 'ROOT' 233 | else: 234 | root = PSTree(None, 'ROOT', tree.span, None, []) 235 | root.subtrees.append(tree) 236 | tree.parent = root 237 | tree = root 238 | return tree 239 | 240 | def ptb_read_tree(source, return_empty=False, allow_empty_labels=False, allow_empty_words=False, blank_line_coverage=False): 241 | '''Read a single tree from the given PTB file. 
242 | 243 | The function reads a character at a time, stopping as soon as a tree can be 244 | constructed, so multiple trees on a sinlge line are manageable. 245 | 246 | >>> from StringIO import StringIO 247 | >>> file_text = """(ROOT (S 248 | ... (NP-SBJ (NNP Scotty) ) 249 | ... (VP (VBD did) (RB not) 250 | ... (VP (VB go) 251 | ... (ADVP (RB back) ) 252 | ... (PP (TO to) 253 | ... (NP (NN school) )))) 254 | ... (. .) ))""" 255 | >>> in_file = StringIO(file_text) 256 | >>> ptb_read_tree(in_file) 257 | (ROOT (S (NP-SBJ (NNP Scotty)) (VP (VBD did) (RB not) (VP (VB go) (ADVP (RB back)) (PP (TO to) (NP (NN school))))) (. .)))''' 258 | cur_text = '' 259 | depth = 0 260 | while True: 261 | char = source.read(1) 262 | if char == '': 263 | return None 264 | break 265 | if char == '\n' and cur_text == ' ' and blank_line_coverage: 266 | return "Empty" 267 | if char in '\n\t': 268 | char = ' ' 269 | cur_text += char 270 | if char == '(': 271 | depth += 1 272 | elif char == ')': 273 | depth -= 1 274 | if depth == 0: 275 | if '()' in cur_text: 276 | if return_empty: 277 | return "Empty" 278 | cur_text = '' 279 | continue 280 | if '(' in cur_text: 281 | break 282 | 283 | tree = tree_from_text(cur_text, allow_empty_labels, allow_empty_words) 284 | ptb_cleaning(tree) 285 | return tree 286 | 287 | def conll_read_tree(source, return_empty=False, allow_empty_labels=False, allow_empty_words=False, blank_line_coverage=False): 288 | '''Read a single tree from the given CoNLL Shared Task OntoNotes data file. 289 | 290 | >>> from StringIO import StringIO 291 | >>> file_text = """#begin document (nw/wsj/00/wsj_0020) 292 | ... nw/wsj/00/wsj_0020 0 0 They PRP (TOP_(S_(NP_*) - - - - * (ARG1*) * (0) 293 | ... nw/wsj/00/wsj_0020 0 1 will MD (VP_* - - - - * (ARGM-M OD*) * - 294 | ... nw/wsj/00/wsj_0020 0 2 remain VB (VP_* remain 01 1 - * ( V*) * - 295 | ... nw/wsj/00/wsj_0020 0 3 on IN (PP_* - - - - * (AR G3* * - 296 | ... nw/wsj/00/wsj_0020 0 4 a DT (NP_(NP_* - - - - * * (ARG2* - 297 | ... nw/wsj/00/wsj_0020 0 5 lower JJR (NML_* - - - - * * * - 298 | ... nw/wsj/00/wsj_0020 0 6 - HYPH * - - - - * * * - 299 | ... nw/wsj/00/wsj_0020 0 7 priority NN *) - - - - * * * - 300 | ... nw/wsj/00/wsj_0020 0 8 list NN *) - - 1 - * * *) - 301 | ... nw/wsj/00/wsj_0020 0 9 that WDT (SBAR_(WHNP_*) - - - - * * * - 302 | ... nw/wsj/00/wsj_0020 0 10 includes VBZ (S_(VP_* - - 1 - * * (V*) - 303 | ... nw/wsj/00/wsj_0020 0 11 17 CD (NP_* - - - - (CARDINAL) * (ARG1* (10 304 | ... nw/wsj/00/wsj_0020 0 12 other JJ * - - - - * * * - 305 | ... nw/wsj/00/wsj_0020 0 13 countries NNS *)))))))) - - 3 - * *) *) 10) 306 | ... nw/wsj/00/wsj_0020 0 14 . . *)) - - - - * * * - 307 | ... 308 | ... """ 309 | >>> in_file = StringIO(file_text) 310 | >>> tree = conll_read_tree(in_file) 311 | >>> print tree 312 | (TOP (S (NP (PRP They)) (VP (MD will) (VP (VB remain) (PP (IN on) (NP (NP (DT a) (NML (JJR lower) (HYPH -) (NN priority)) (NN list)) (SBAR (WHNP (WDT that)) (S (VP (VBZ includes) (NP (CD 17) (JJ other) (NNS countries))))))))) (. 
.)))''' 313 | cur_text = [] 314 | while True: 315 | line = source.readline() 316 | # Check if we are out of input 317 | if line == '': 318 | return None 319 | # strip whitespace and see if this is then end of the parse 320 | line = line.strip() 321 | if line == '': 322 | break 323 | cur_text.append(line) 324 | 325 | text = '' 326 | for line in cur_text: 327 | # Canasai's addition begin 328 | line = line.strip() 329 | # Canasai's addition end 330 | if len(line) == 0 or line[0] == '#': 331 | continue 332 | # Canasai comment out: line = line.split() 333 | # Canasai's addition begin 334 | line = re.split(r'\s+', line) 335 | # Canasai's addition end 336 | word = line[3] 337 | pos = line[4] 338 | tree = line[5] 339 | tree = tree.split('*') 340 | text += '%s(%s %s)%s' % (tree[0], pos, word, tree[1]) 341 | return tree_from_text(text) 342 | 343 | def generate_trees(source, tree_reader=ptb_read_tree, max_sents=-1, return_empty=False, allow_empty_labels=False, allow_empty_words=False): 344 | '''Read trees from the given file (opening the file if only a string is given). 345 | 346 | >>> from StringIO import StringIO 347 | >>> file_text = """(ROOT (S 348 | ... (NP-SBJ (NNP Scotty) ) 349 | ... (VP (VBD did) (RB not) 350 | ... (VP (VB go) 351 | ... (ADVP (RB back) ) 352 | ... (PP (TO to) 353 | ... (NP (NN school) )))) 354 | ... (. .) )) 355 | ... 356 | ... (ROOT (S 357 | ... (NP-SBJ (DT The) (NN bandit) ) 358 | ... (VP (VBZ laughs) 359 | ... (PP (IN in) 360 | ... (NP (PRP$ his) (NN face) ))) 361 | ... (. .) ))""" 362 | >>> in_file = StringIO(file_text) 363 | >>> for tree in generate_trees(in_file): 364 | ... print tree 365 | (ROOT (S (NP-SBJ (NNP Scotty)) (VP (VBD did) (RB not) (VP (VB go) (ADVP (RB back)) (PP (TO to) (NP (NN school))))) (. .))) 366 | (ROOT (S (NP-SBJ (DT The) (NN bandit)) (VP (VBZ laughs) (PP (IN in) (NP (PRP$ his) (NN face)))) (. 
.)))''' 367 | if type(source) == type(''): 368 | source = open(source) 369 | count = 0 370 | while True: 371 | tree = tree_reader(source, return_empty, allow_empty_labels, allow_empty_words) 372 | if tree == "Empty": 373 | yield None 374 | continue 375 | if tree is None: 376 | return 377 | yield tree 378 | count += 1 379 | if count >= max_sents > 0: 380 | return 381 | 382 | def read_trees(source, tree_reader=ptb_read_tree, max_sents=-1, return_empty=False): 383 | return [tree for tree in generate_trees(source, tree_reader, max_sents, return_empty)] 384 | 385 | if __name__ == '__main__': 386 | print "Running doctest" 387 | import doctest 388 | doctest.testmod() 389 | 390 | -------------------------------------------------------------------------------- /ontochinese.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import sys 4 | import codecs 5 | from collections import defaultdict 6 | 7 | import numpy as np 8 | import mafan 9 | 10 | sys.path.append("/home/danniel/Desktop/CONLL2012-intern") 11 | from load_conll import load_data 12 | from pstree import PSTree 13 | 14 | from rnn import Node 15 | 16 | dataset = "ontochinese" 17 | character_file = os.path.join(dataset, "character.txt") 18 | word_file = os.path.join(dataset, "word.txt") 19 | pos_file = os.path.join(dataset, "pos.txt") 20 | ne_file = os.path.join(dataset, "ne.txt") 21 | pretrained_word_file = os.path.join(dataset, "word.npy") 22 | pretrained_embedding_file = os.path.join(dataset, "embedding.npy") 23 | 24 | data_path_prefix = "/home/danniel/Desktop/CONLL2012-intern/conll-2012/v4/data" 25 | test_auto_data_path_prefix = "/home/danniel/Downloads/wu_conll_test/v9/data" 26 | data_path_suffix = "data/chinese/annotations" 27 | 28 | glove_file = "/home/danniel/Downloads/Glove_CNA_ASBC_300d.vec" 29 | 30 | def log(msg): 31 | sys.stdout.write(msg) 32 | sys.stdout.flush() 33 | return 34 | 35 | def read_list_file(file_path): 36 | log("Read %s..." 
% file_path) 37 | 38 | with codecs.open(file_path, "r", encoding="utf8") as f: 39 | line_list = f.read().splitlines() 40 | line_to_index = {line: index for index, line in enumerate(line_list)} 41 | 42 | log(" %d lines\n" % len(line_to_index)) 43 | return line_list, line_to_index 44 | 45 | def extract_vocabulary_and_alphabet(): 46 | log("extract_vocabulary_and_alphabet()...") 47 | 48 | character_set = set() 49 | word_set = set() 50 | for split in ["train", "development", "test"]: 51 | full_path = os.path.join(data_path_prefix, split, data_path_suffix) 52 | config = {"file_suffix": "gold_conll", "dir_prefix": full_path} 53 | raw_data = load_data(config) 54 | for document in raw_data: 55 | for part in raw_data[document]: 56 | for index, sentence in enumerate(raw_data[document][part]["text"]): 57 | for word in sentence: 58 | for character in word: 59 | character_set.add(character) 60 | word_set.add(word) 61 | 62 | with codecs.open(word_file, "w", encoding="utf8") as f: 63 | for word in sorted(word_set): 64 | f.write(word + '\n') 65 | 66 | with codecs.open(character_file, "w", encoding="utf8") as f: 67 | for character in sorted(character_set): 68 | f.write(character + '\n') 69 | 70 | log(" done\n") 71 | return 72 | 73 | def extract_glove_embeddings(): 74 | log("extract_glove_embeddings()...") 75 | 76 | _, word_to_index = read_list_file(word_file) 77 | word_list = [] 78 | embedding_list = [] 79 | with codecs.open(glove_file, "r", encoding="utf8") as f: 80 | for line in f: 81 | line = line.strip().split() 82 | word = mafan.simplify(line[0]) 83 | if word not in word_to_index: continue 84 | embedding = np.array([float(i) for i in line[1:]]) 85 | word_list.append(word) 86 | embedding_list.append(embedding) 87 | 88 | np.save(pretrained_word_file, word_list) 89 | np.save(pretrained_embedding_file, embedding_list) 90 | 91 | log(" %d pre-trained words\n" % len(word_list)) 92 | return 93 | 94 | def construct_node(node, tree, ner_raw_data, head_raw_data, text_raw_data, 95 | character_to_index, word_to_index, pos_to_index, 96 | pos_count, ne_count, pos_ne_count): 97 | pos = tree.label 98 | word = tree.word 99 | span = tree.span 100 | head = tree.head if hasattr(tree, "head") else head_raw_data[(span, pos)][1] 101 | ne = ner_raw_data[span] if span in ner_raw_data else "NONE" 102 | 103 | # Process pos info 104 | node.pos = pos 105 | node.pos_index = pos_to_index[pos] 106 | pos_count[pos] += 1 107 | 108 | # Process word info 109 | node.word_split = [character_to_index[character] for character in word] if word else [] 110 | node.word_index = word_to_index[word] if word else -1 111 | 112 | # Process head info 113 | node.head_split = [character_to_index[character] for character in head] 114 | node.head_index = word_to_index[head] 115 | 116 | # Process ne info 117 | node.ne = ne 118 | if not node.parent or node.parent.span!=span: 119 | ne_count[ne] += 1 120 | if ne != "NONE": 121 | pos_ne_count[pos] += 1 122 | 123 | # Process span info 124 | node.span = span 125 | 126 | # Binarize children 127 | if len(tree.subtrees) > 2: 128 | side_child_pos = tree.subtrees[-1].label 129 | side_child_span = tree.subtrees[-1].span 130 | side_child_head = head_raw_data[(side_child_span, side_child_pos)][1] 131 | if side_child_head != head: 132 | sub_subtrees = tree.subtrees[:-1] 133 | else: 134 | sub_subtrees = tree.subtrees[1:] 135 | new_span = (sub_subtrees[0].span[0], sub_subtrees[-1].span[1]) 136 | new_tree = PSTree(label=pos, span=new_span, subtrees=sub_subtrees) 137 | new_tree.head = head 138 | if side_child_head != head: 
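            # Non-head last child: it becomes the right branch and new_tree
            # (all of the other children, carrying the parent's label and head)
            # becomes the left branch. In the else branch below, the head lies
            # in the last child, so the first child is kept on the left and the
            # remaining children are merged into new_tree instead.
            # construct_node() recurses into new_tree, so every n-ary
            # constituent is unfolded into a binary spine along its head child.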
139 | tree.subtrees = [new_tree, tree.subtrees[-1]] 140 | else: 141 | tree.subtrees = [tree.subtrees[0], new_tree] 142 | 143 | # Process children 144 | nodes = 1 145 | for subtree in tree.subtrees: 146 | child = Node() 147 | node.add_child(child) 148 | child_nodes = construct_node(child, subtree, ner_raw_data, head_raw_data, text_raw_data, 149 | character_to_index, word_to_index, pos_to_index, 150 | pos_count, ne_count, pos_ne_count) 151 | nodes += child_nodes 152 | return nodes 153 | 154 | def get_tree_data(raw_data, character_to_index, word_to_index, pos_to_index): 155 | log("get_tree_data()...") 156 | """ Get tree structured data from CoNLL 2012 157 | 158 | Stores into Node data structure 159 | """ 160 | tree_list = [] 161 | ner_list = [] 162 | word_count = 0 163 | pos_count = defaultdict(lambda: 0) 164 | ne_count = defaultdict(lambda: 0) 165 | pos_ne_count = defaultdict(lambda: 0) 166 | 167 | for document in raw_data["auto"]: 168 | for part in raw_data["auto"][document]: 169 | 170 | ner_raw_data = defaultdict(lambda: {}) 171 | for k, v in raw_data["gold"][document][part]["ner"].iteritems(): 172 | ner_raw_data[k[0]][(k[1], k[2])] = v 173 | 174 | for index, parse in enumerate(raw_data["auto"][document][part]["parses"]): 175 | text_raw_data = raw_data["auto"][document][part]["text"][index] 176 | word_count += len(text_raw_data) 177 | 178 | if parse.subtrees[0].label == "NOPARSE": continue 179 | head_raw_data = raw_data["auto"][document][part]["heads"][index] 180 | 181 | root_node = Node() 182 | nodes = construct_node( 183 | root_node, parse, ner_raw_data[index], head_raw_data, text_raw_data, 184 | character_to_index, word_to_index, pos_to_index, 185 | pos_count, ne_count, pos_ne_count) 186 | root_node.nodes = nodes 187 | 188 | tree_list.append(root_node) 189 | ner_list.append(ner_raw_data[index]) 190 | 191 | log(" %d sentences\n" % len(tree_list)) 192 | return tree_list, ner_list, word_count, pos_count, ne_count, pos_ne_count 193 | 194 | def label_tree_data(node, pos_to_index, ne_to_index): 195 | node.y = ne_to_index[node.ne] 196 | # node.y = ne_to_index[":".join(node.ner)] 197 | 198 | for child in node.child_list: 199 | label_tree_data(child, pos_to_index, ne_to_index) 200 | return 201 | 202 | def read_dataset(data_split_list = ["train", "validate", "test"]): 203 | # Read all raw data 204 | annotation_method_list = ["gold", "auto"] 205 | raw_data = {} 206 | for split in data_split_list: 207 | raw_data[split] = {} 208 | for method in annotation_method_list: 209 | if split == "test" and method == "auto": 210 | full_path = os.path.join(test_auto_data_path_prefix, "test", data_path_suffix) 211 | else: 212 | if split == "validate": 213 | data_path_root = "development" 214 | else: 215 | data_path_root = split 216 | full_path = os.path.join(data_path_prefix, data_path_root, data_path_suffix) 217 | config = {"file_suffix": "%s_conll" % method, "dir_prefix": full_path} 218 | raw_data[split][method] = load_data(config) 219 | 220 | # Read lists of annotations 221 | character_list, character_to_index = read_list_file(character_file) 222 | word_list, word_to_index = read_list_file(word_file) 223 | pos_list, pos_to_index = read_list_file(pos_file) 224 | ne_list, ne_to_index = read_list_file(ne_file) 225 | 226 | # Build a tree structure for each sentence 227 | data = {} 228 | word_count = {} 229 | pos_count = {} 230 | ne_count = {} 231 | pos_ne_count = {} 232 | for split in data_split_list: 233 | (tree_list, ner_list, 234 | word_count[split], pos_count[split], ne_count[split], pos_ne_count[split]) = ( 
235 | get_tree_data(raw_data[split], character_to_index, word_to_index, pos_to_index)) 236 | sentences = len(tree_list) 237 | nodes = sum(pos_count[split].itervalues()) 238 | nes = sum(pos_ne_count[split].itervalues()) 239 | data[split] = [tree_list, ner_list] 240 | log("<%s>\n %d sentences; %d nodes; %d named entities\n" 241 | % (split, sentences, nodes, nes)) 242 | 243 | # Show POS distribution 244 | total_pos_count = defaultdict(lambda: 0) 245 | for split in data_split_list: 246 | for pos in pos_count[split]: 247 | total_pos_count[pos] += pos_count[split][pos] 248 | nodes = sum(total_pos_count.itervalues()) 249 | print "\nTotal %d nodes" % nodes 250 | print "-"*50 + "\n POS count ratio\n" + "-"*50 251 | for pos, count in sorted(total_pos_count.iteritems(), key=lambda x: x[1], reverse=True): 252 | print "%6s %7d %5.1f%%" % (pos, count, count*100./nodes) 253 | 254 | # Show number of tokens and NEs in each split 255 | reals = 0 256 | split_nes_dict = {} 257 | for split in data_split_list: 258 | if split == "test": continue 259 | split_nes_dict[split] = sum(len(ner) for ner in data[split][1]) 260 | reals += split_nes_dict[split] 261 | print "\nTotal %d named entities" % reals 262 | print "-"*50 + "\n split token NE\n" + "-"*50 263 | for split in data_split_list: 264 | if split == "test": continue 265 | print "%12s %7d %6d" % (split, word_count[split], split_nes_dict[split]) 266 | 267 | # Show NE distribution 268 | total_ne_count = defaultdict(lambda: 0) 269 | for split in data_split_list: 270 | if split == "test": continue 271 | for ne in ne_count[split]: 272 | if ne == "NONE": continue 273 | total_ne_count[ne] += ne_count[split][ne] 274 | nes = sum(total_ne_count.itervalues()) 275 | print "\nTotal %d spanned named entities" % nes 276 | print "-"*50 + "\n NE count ratio\n" + "-"*50 277 | for ne, count in sorted(total_ne_count.iteritems(), key=lambda x: x[1], reverse=True): 278 | print "%12s %6d %5.1f%%" % (ne, count, count*100./nes) 279 | 280 | # Show POS-NE distribution 281 | total_pos_ne_count = defaultdict(lambda: 0) 282 | for split in data_split_list: 283 | if split == "test": continue 284 | for pos in pos_ne_count[split]: 285 | total_pos_ne_count[pos] += pos_ne_count[split][pos] 286 | print "-"*50 + "\n POS NE total ratio\n" + "-"*50 287 | for pos, count in sorted(total_pos_ne_count.iteritems(), key=lambda x: x[1], reverse=True): 288 | total = total_pos_count[pos] 289 | print "%6s %6d %7d %5.1f%%" % (pos, count, total, count*100./total) 290 | 291 | # Compute the mapping to labels 292 | ne_to_index["NONE"] = len(ne_to_index) 293 | 294 | # Add label to nodes 295 | for split in data_split_list: 296 | for tree in data[split][0]: 297 | label_tree_data(tree, pos_to_index, ne_to_index) 298 | return (data, word_list, ne_list, 299 | len(character_to_index), len(pos_to_index), len(ne_to_index)) 300 | 301 | if __name__ == "__main__": 302 | #extract_vocabulary_and_alphabet() 303 | #extract_glove_embeddings() 304 | read_dataset() 305 | exit() 306 | 307 | 308 | 309 | 310 | 311 | -------------------------------------------------------------------------------- /ontochinese/ne.txt: -------------------------------------------------------------------------------- 1 | PERSON 2 | NORP 3 | FAC 4 | ORG 5 | GPE 6 | LOC 7 | PRODUCT 8 | EVENT 9 | WORK_OF_ART 10 | LAW 11 | LANGUAGE 12 | DATE 13 | TIME 14 | PERCENT 15 | MONEY 16 | QUANTITY 17 | ORDINAL 18 | CARDINAL 19 | -------------------------------------------------------------------------------- /ontochinese/pos.txt: 
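The label lists in this directory (ne.txt above, pos.txt below) are the closed label vocabularies for the Chinese model: every entity type and every constituent/POS label that can appear on a tree node must be listed, because construct_node() and label_tree_data() in ontochinese.py index into them directly. A minimal sketch of how read_dataset() consumes them, reusing that module's names:

```python
# Map each line of the list files to its row index.
ne_list, ne_to_index = read_list_file(ne_file)     # 18 OntoNotes entity types
pos_list, pos_to_index = read_list_file(pos_file)  # constituent and POS labels
# read_dataset() then appends an extra "NONE" label for nodes that are not
# named entities, before label_tree_data() assigns each node's target node.y.
ne_to_index["NONE"] = len(ne_to_index)
```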
-------------------------------------------------------------------------------- 1 | NP 2 | VP 3 | IP 4 | NN 5 | PU 6 | VV 7 | AD 8 | ADVP 9 | TOP 10 | NR 11 | PN 12 | QP 13 | CP 14 | PP 15 | P 16 | CD 17 | DNP 18 | DEG 19 | M 20 | CLP 21 | ADJP 22 | JJ 23 | DEC 24 | DT 25 | VC 26 | VA 27 | DP 28 | UCP 29 | NT 30 | LCP 31 | LC 32 | SP 33 | AS 34 | CC 35 | IJ 36 | FLR 37 | VE 38 | INTJ 39 | FRAG 40 | DFL 41 | VRD 42 | MSP 43 | CS 44 | OD 45 | DVP 46 | DEV 47 | VPT 48 | BA 49 | VNV 50 | PRN 51 | ETC 52 | SB 53 | INC 54 | VCD 55 | DER 56 | VSB 57 | LB 58 | LST 59 | VCP 60 | URL 61 | MBD 62 | FW 63 | OTH 64 | INF 65 | SKIP 66 | ON 67 | -------------------------------------------------------------------------------- /ontonotes.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import sys 4 | import codecs 5 | from collections import defaultdict 6 | 7 | import numpy as np 8 | 9 | sys.path.append("/home/danniel/Desktop/CONLL2012-intern") 10 | from load_conll import load_data 11 | from pstree import PSTree 12 | 13 | from rnn import Node 14 | 15 | dataset = "ontonotes" 16 | character_file = os.path.join(dataset, "character.txt") 17 | word_file = os.path.join(dataset, "word.txt") 18 | pos_file = os.path.join(dataset, "pos.txt") 19 | ne_file = os.path.join(dataset, "ne.txt") 20 | pretrained_word_file = os.path.join(dataset, "word.npy") 21 | pretrained_embedding_file = os.path.join(dataset, "embedding.npy") 22 | 23 | data_path_prefix = "/home/danniel/Desktop/CONLL2012-intern/conll-2012/v4/data" 24 | test_auto_data_path_prefix = "/home/danniel/Downloads/wu_conll_test/v9/data" 25 | data_path_suffix = "data/english/annotations" 26 | 27 | glove_file = "/home/danniel/Downloads/glove.840B.300d.txt" 28 | 29 | senna_path = "/home/danniel/Downloads/senna/hash" 30 | dbpedia_path = "/home/danniel/Desktop/dbpedia_lexicon" 31 | lexicon_meta_list = [ 32 | {"ne": "PERSON", "encoding": "iso8859-15", "clean": os.path.join(dataset, "senna_PER.txt"), "raw": os.path.join(senna_path, "ner.per.lst")}, 33 | {"ne": "ORG", "encoding": "iso8859-15", "clean": os.path.join(dataset, "senna_ORG.txt"), "raw": os.path.join(senna_path, "ner.org.lst")}, 34 | {"ne": "LOC", "encoding": "iso8859-15", "clean": os.path.join(dataset, "senna_LOC.txt"), "raw": os.path.join(senna_path, "ner.loc.lst")} 35 | #{"ne": "WORK_OF_ART", "encoding": "iso8859-15", "clean": os.path.join(dataset, "senna_WOR.txt"), "raw": os.path.join(senna_path, "ner.misc.lst")}, 36 | #{"ne": "PERSON", "encoding": "utf8", "clean": os.path.join(dataset, "dbpedia_PER.txt"), "raw": os.path.join(dbpedia_path, "dbpedia_person.txt")}, 37 | #{"ne": "ORG", "encoding": "utf8", "clean": os.path.join(dataset, "dbpedia_ORG.txt"), "raw": os.path.join(dbpedia_path, "dbpedia_organisation.txt")}, 38 | #{"ne": "LOC", "encoding": "utf8", "clean": os.path.join(dataset, "dbpedia_LOC.txt"), "raw": os.path.join(dbpedia_path, "dbpedia_place.txt")} 39 | #{"ne": "WORK_OF_ART", "encoding": "utf8", "clean": os.path.join(dataset, "dbpedia_WOR.txt"), "raw": os.path.join(dbpedia_path, "dbpedia_work.txt")} 40 | ] 41 | 42 | def log(msg): 43 | sys.stdout.write(msg) 44 | sys.stdout.flush() 45 | return 46 | 47 | def read_list_file(file_path, encoding="utf8"): 48 | log("Read %s..." 
% file_path) 49 | 50 | with codecs.open(file_path, "r", encoding=encoding) as f: 51 | line_list = f.read().splitlines() 52 | line_to_index = {line: index for index, line in enumerate(line_list)} 53 | 54 | log(" %d lines\n" % len(line_to_index)) 55 | return line_list, line_to_index 56 | 57 | def extract_vocabulary_and_alphabet(): 58 | log("extract_vocabulary_and_alphabet()...") 59 | 60 | character_set = set() 61 | word_set = set() 62 | for split in ["train", "development", "test"]: 63 | full_path = os.path.join(data_path_prefix, split, data_path_suffix) 64 | config = {"file_suffix": "gold_conll", "dir_prefix": full_path} 65 | raw_data = load_data(config) 66 | for document in raw_data: 67 | for part in raw_data[document]: 68 | for index, sentence in enumerate(raw_data[document][part]["text"]): 69 | for word in sentence: 70 | for character in word: 71 | character_set.add(character) 72 | word_set.add(word) 73 | 74 | with codecs.open(word_file, "w", encoding="utf8") as f: 75 | for word in sorted(word_set): 76 | f.write(word + '\n') 77 | 78 | with codecs.open(character_file, "w", encoding="utf8") as f: 79 | for character in sorted(character_set): 80 | f.write(character + '\n') 81 | 82 | log(" done\n") 83 | return 84 | 85 | def extract_glove_embeddings(): 86 | log("extract_glove_embeddings()...") 87 | 88 | _, word_to_index = read_list_file(word_file) 89 | word_list = [] 90 | embedding_list = [] 91 | with open(glove_file, "r") as f: 92 | for line in f: 93 | line = line.strip().split() 94 | word = line[0] 95 | if word not in word_to_index: continue 96 | embedding = np.array([float(i) for i in line[1:]]) 97 | word_list.append(word) 98 | embedding_list.append(embedding) 99 | 100 | np.save(pretrained_word_file, word_list) 101 | np.save(pretrained_embedding_file, embedding_list) 102 | 103 | log(" %d pre-trained words\n" % len(word_list)) 104 | return 105 | 106 | def traverse_tree(tree, ner_raw_data, head_raw_data, text_raw_data, lexicon_list, span_set): 107 | pos = tree.label 108 | span = tree.span 109 | head = tree.head if hasattr(tree, "head") else head_raw_data[(span, pos)][1] 110 | ne = ner_raw_data[span] if span in ner_raw_data else "NONE" 111 | constituent = " ".join(text_raw_data[span[0]:span[1]]).lower() 112 | 113 | span_set.add(span) 114 | for index, lexicon in enumerate(lexicon_list): 115 | if constituent in lexicon: 116 | lexicon[constituent][0] += 1 117 | if ne == lexicon_meta_list[index]["ne"]: 118 | lexicon[constituent][1] += 1 119 | 120 | # Binarize children 121 | if len(tree.subtrees) > 2: 122 | side_child_pos = tree.subtrees[-1].label 123 | side_child_span = tree.subtrees[-1].span 124 | side_child_head = head_raw_data[(side_child_span, side_child_pos)][1] 125 | if side_child_head != head: 126 | sub_subtrees = tree.subtrees[:-1] 127 | else: 128 | sub_subtrees = tree.subtrees[1:] 129 | new_span = (sub_subtrees[0].span[0], sub_subtrees[-1].span[1]) 130 | new_tree = PSTree(label=pos, span=new_span, subtrees=sub_subtrees) 131 | new_tree.head = head 132 | if side_child_head != head: 133 | tree.subtrees = [new_tree, tree.subtrees[-1]] 134 | else: 135 | tree.subtrees = [tree.subtrees[0], new_tree] 136 | 137 | # Process children 138 | for subtree in tree.subtrees: 139 | traverse_tree(subtree, ner_raw_data, head_raw_data, text_raw_data, lexicon_list, span_set) 140 | return 141 | 142 | def traverse_pyramid(ner_raw_data, text_raw_data, lexicon_list, span_set): 143 | max_dense_span = 3 144 | # Start from bigram, since all unigrams are already covered by parses 145 | for span_length in range(2, 
1+max_dense_span): 146 | for span_start in range(0, 1+len(text_raw_data)-span_length): 147 | span = (span_start, span_start+span_length) 148 | if span in span_set: continue 149 | ne = ner_raw_data[span] if span in ner_raw_data else "NONE" 150 | constituent = " ".join(text_raw_data[span[0]:span[1]]).lower() 151 | 152 | for index, lexicon in enumerate(lexicon_list): 153 | if constituent in lexicon: 154 | lexicon[constituent][0] += 1 155 | if ne == lexicon_meta_list[index]["ne"]: 156 | lexicon[constituent][1] += 1 157 | return 158 | 159 | def extract_clean_lexicon(): 160 | lexicon_list = [] 161 | 162 | print "\nReading raw lexicons..." 163 | for meta in lexicon_meta_list: 164 | lexicon_list.append(read_list_file(meta["raw"], encoding=meta["encoding"])[1]) 165 | print "-"*50 + "\n ne phrases shortest\n" + "-"*50 166 | for index, lexicon in enumerate(lexicon_list): 167 | for phrase in lexicon: 168 | lexicon[phrase] = [0.,0.] 169 | shortest_phrase = min(lexicon.iterkeys(), key=lambda phrase: len(phrase)) 170 | print "%12s %8d %s" % (lexicon_meta_list[index]["ne"], len(lexicon), shortest_phrase) 171 | 172 | print "Reading training data..." 173 | data_split_list = ["train", "validate"] 174 | annotation_method_list = ["gold", "auto"] 175 | raw_data = {} 176 | for split in data_split_list: 177 | raw_data[split] = {} 178 | for method in annotation_method_list: 179 | if split == "validate": 180 | data_path_root = "development" 181 | else: 182 | data_path_root = split 183 | full_path = os.path.join(data_path_prefix, data_path_root, data_path_suffix) 184 | config = {"file_suffix": "%s_conll" % method, "dir_prefix": full_path} 185 | raw_data[split][method] = load_data(config) 186 | 187 | log("\nCleaning lexicon by training data...") 188 | for split in data_split_list: 189 | for document in raw_data[split]["auto"]: 190 | for part in raw_data[split]["auto"][document]: 191 | 192 | ner_raw_data = defaultdict(lambda: {}) 193 | for k, v in raw_data[split]["gold"][document][part]["ner"].iteritems(): 194 | ner_raw_data[k[0]][(k[1], k[2])] = v 195 | 196 | for index, parse in enumerate(raw_data[split]["auto"][document][part]["parses"]): 197 | text_raw_data = raw_data[split]["auto"][document][part]["text"][index] 198 | 199 | if parse.subtrees[0].label == "NOPARSE": continue 200 | head_raw_data = raw_data[split]["auto"][document][part]["heads"][index] 201 | 202 | span_set = set() 203 | traverse_tree(parse, ner_raw_data[index], head_raw_data, text_raw_data, 204 | lexicon_list, span_set) 205 | traverse_pyramid(ner_raw_data[index], text_raw_data, lexicon_list, span_set) 206 | log(" done\n") 207 | 208 | print "-"*50 + "\n ne phrases shortest\n" + "-"*50 209 | for index, lexicon in enumerate(lexicon_list): 210 | for phrase, count in lexicon.items(): 211 | if count[0]>0 and count[1]/count[0]<0.1: 212 | del lexicon[phrase] 213 | shortest_phrase = min(lexicon.iterkeys(), key=lambda phrase: len(phrase)) 214 | print "%12s %8d %s" % (lexicon_meta_list[index]["ne"], len(lexicon), shortest_phrase) 215 | 216 | for index, lexicon in enumerate(lexicon_list): 217 | meta = lexicon_meta_list[index] 218 | with codecs.open(meta["clean"], "w", encoding=meta["encoding"]) as f: 219 | for phrase in sorted(lexicon.iterkeys()): 220 | f.write("%s\n" % phrase) 221 | return 222 | 223 | def construct_node(node, tree, ner_raw_data, head_raw_data, text_raw_data, 224 | character_to_index, word_to_index, pos_to_index, lexicon_list, 225 | pos_count, ne_count, pos_ne_count, lexicon_hits, span_to_node, under_ne): 226 | pos = tree.label 227 | word = 
tree.word 228 | span = tree.span 229 | head = tree.head if hasattr(tree, "head") else head_raw_data[(span, pos)][1] 230 | ne = ner_raw_data[span] if span in ner_raw_data else "NONE" 231 | constituent = " ".join(text_raw_data[span[0]:span[1]]).lower() 232 | 233 | # Process pos info 234 | node.pos_index = pos_to_index[pos] 235 | pos_count[pos] += 1 236 | node.pos = pos #YOLO 237 | 238 | # Process word info 239 | node.word_split = [character_to_index[character] for character in word] if word else [] 240 | node.word_index = word_to_index[word] if word else -1 241 | node.word = word if word else "" # YOLO 242 | 243 | # Process head info 244 | node.head_split = [character_to_index[character] for character in head] 245 | node.head_index = word_to_index[head] 246 | node.head = head # YOLO 247 | 248 | # Process ne info 249 | node.under_ne = under_ne 250 | node.ne = ne 251 | if ne != "NONE": 252 | under_ne = True 253 | if not node.parent or node.parent.span!=span: 254 | ne_count[ne] += 1 255 | pos_ne_count[pos] += 1 256 | """ 257 | if hasattr(tree, "head"): 258 | print " ".join(text_raw_data) 259 | print " ".join(text_raw_data[span[0]:span[1]]) 260 | print ne 261 | print node.parent.head 262 | raw_input() 263 | """ 264 | # Process span info 265 | node.span = span 266 | node.span_length = span[1] - span[0] 267 | span_to_node[span] = node 268 | 269 | # Process lexicon info 270 | node.lexicon_hit = [0] * len(lexicon_list) 271 | hits = 0 272 | for index, lexicon in enumerate(lexicon_list): 273 | if constituent in lexicon: 274 | lexicon[constituent] += 1 275 | node.lexicon_hit[index] = 1 276 | hits = 1 277 | lexicon_hits[0] += hits 278 | 279 | # Binarize children 280 | if len(tree.subtrees) > 2: 281 | side_child_pos = tree.subtrees[-1].label 282 | side_child_span = tree.subtrees[-1].span 283 | side_child_head = head_raw_data[(side_child_span, side_child_pos)][1] 284 | if side_child_head != head: 285 | sub_subtrees = tree.subtrees[:-1] 286 | else: 287 | sub_subtrees = tree.subtrees[1:] 288 | new_span = (sub_subtrees[0].span[0], sub_subtrees[-1].span[1]) 289 | new_tree = PSTree(label=pos, span=new_span, subtrees=sub_subtrees) 290 | new_tree.head = head 291 | if side_child_head != head: 292 | tree.subtrees = [new_tree, tree.subtrees[-1]] 293 | else: 294 | tree.subtrees = [tree.subtrees[0], new_tree] 295 | 296 | # Process children 297 | nodes = 1 298 | for subtree in tree.subtrees: 299 | child = Node() 300 | node.add_child(child) 301 | child_nodes = construct_node(child, subtree, ner_raw_data, head_raw_data, text_raw_data, 302 | character_to_index, word_to_index, pos_to_index, lexicon_list, 303 | pos_count, ne_count, pos_ne_count, lexicon_hits, span_to_node, under_ne) 304 | nodes += child_nodes 305 | return nodes 306 | 307 | def create_dense_nodes(ner_raw_data, text_raw_data, pos_to_index, lexicon_list, 308 | pos_count, ne_count, pos_ne_count, lexicon_hits, span_to_node): 309 | node_list = [] 310 | max_dense_span = 3 311 | # Start from bigram, since all unigrams are already covered by parses 312 | for span_length in range(2, 1+max_dense_span): 313 | for span_start in range(0, 1+len(text_raw_data)-span_length): 314 | span = (span_start, span_start+span_length) 315 | if span in span_to_node: continue 316 | pos = "NONE" 317 | ne = ner_raw_data[span] if span in ner_raw_data else "NONE" 318 | constituent = " ".join(text_raw_data[span[0]:span[1]]).lower() 319 | 320 | # span, child 321 | # TODO: sibling 322 | node = Node(family=1) 323 | node_list.append(node) 324 | node.span = span 325 | node.span_length = 
span_length 326 | span_to_node[span] = node 327 | node.child_list = [span_to_node[(span[0],span[1]-1)], span_to_node[(span[0]+1,span[1])]] 328 | 329 | # word, head, pos 330 | node.pos_index = pos_to_index[pos] 331 | pos_count[pos] += 1 332 | node.word_split = [] 333 | node.word_index = -1 334 | node.head_split = [] 335 | node.head_index = -1 336 | 337 | # ne 338 | node.ne = ne 339 | if ne != "NONE": 340 | ne_count[ne] += 1 341 | pos_ne_count[pos] += 1 342 | 343 | # lexicon 344 | node.lexicon_hit = [0] * len(lexicon_list) 345 | hits = 0 346 | for index, lexicon in enumerate(lexicon_list): 347 | if constituent in lexicon: 348 | lexicon[constituent] += 1 349 | node.lexicon_hit[index] = 1 350 | hits = 1 351 | lexicon_hits[0] += hits 352 | 353 | return node_list 354 | 355 | def get_tree_data(raw_data, character_to_index, word_to_index, pos_to_index, lexicon_list): 356 | log("get_tree_data()...") 357 | """ Get tree structured data from CoNLL 2012 358 | 359 | Stores into Node data structure 360 | """ 361 | tree_pyramid_list = [] 362 | ner_list = [] 363 | word_count = 0 364 | pos_count = defaultdict(lambda: 0) 365 | ne_count = defaultdict(lambda: 0) 366 | pos_ne_count = defaultdict(lambda: 0) 367 | lexicon_hits = [0] 368 | 369 | for document in raw_data["auto"]: 370 | for part in raw_data["auto"][document]: 371 | 372 | ner_raw_data = defaultdict(lambda: {}) 373 | for k, v in raw_data["gold"][document][part]["ner"].iteritems(): 374 | ner_raw_data[k[0]][(k[1], k[2])] = v 375 | 376 | for index, parse in enumerate(raw_data["auto"][document][part]["parses"]): 377 | text_raw_data = raw_data["auto"][document][part]["text"][index] 378 | word_count += len(text_raw_data) 379 | 380 | if parse.subtrees[0].label == "NOPARSE": continue 381 | head_raw_data = raw_data["auto"][document][part]["heads"][index] 382 | 383 | root_node = Node() 384 | span_to_node = {} 385 | nodes = construct_node( 386 | root_node, parse, ner_raw_data[index], head_raw_data, text_raw_data, 387 | character_to_index, word_to_index, pos_to_index, lexicon_list, 388 | pos_count, ne_count, pos_ne_count, lexicon_hits, span_to_node, False) 389 | root_node.nodes = nodes 390 | root_node.text_raw_data = text_raw_data #YOLO 391 | 392 | additional_node_list = [] 393 | """ 394 | additional_node_list = create_dense_nodes( 395 | ner_raw_data[index], text_raw_data, 396 | pos_to_index, lexicon_list, 397 | pos_count, ne_count, pos_ne_count, lexicon_hits, span_to_node) 398 | """ 399 | tree_pyramid_list.append((root_node, additional_node_list)) 400 | ner_list.append(ner_raw_data[index]) 401 | 402 | log(" %d sentences\n" % len(tree_pyramid_list)) 403 | return (tree_pyramid_list, ner_list, word_count, pos_count, ne_count, pos_ne_count, 404 | lexicon_hits[0]) 405 | 406 | def label_tree_data(node, pos_to_index, ne_to_index): 407 | node.y = ne_to_index[node.ne] 408 | # node.y = ne_to_index[":".join(node.ner)] 409 | 410 | for child in node.child_list: 411 | label_tree_data(child, pos_to_index, ne_to_index) 412 | return 413 | 414 | def read_dataset(data_split_list = ["train", "validate", "test"]): 415 | # Read all raw data 416 | annotation_method_list = ["gold", "auto"] 417 | raw_data = {} 418 | for split in data_split_list: 419 | raw_data[split] = {} 420 | for method in annotation_method_list: 421 | if split == "test" and method == "auto": 422 | full_path = os.path.join(test_auto_data_path_prefix, "test", data_path_suffix) 423 | else: 424 | if split == "validate": 425 | data_path_root = "development" 426 | else: 427 | data_path_root = split 428 | full_path = 
os.path.join(data_path_prefix, data_path_root, data_path_suffix) 429 | config = {"file_suffix": "%s_conll" % method, "dir_prefix": full_path} 430 | raw_data[split][method] = load_data(config) 431 | 432 | # Read lists of annotations 433 | character_list, character_to_index = read_list_file(character_file) 434 | word_list, word_to_index = read_list_file(word_file) 435 | pos_list, pos_to_index = read_list_file(pos_file) 436 | ne_list, ne_to_index = read_list_file(ne_file) 437 | 438 | pos_to_index["NONE"] = len(pos_to_index) 439 | 440 | # Read lexicon 441 | lexicon_list = [] 442 | for meta in lexicon_meta_list: 443 | lexicon_list.append(read_list_file(meta["raw"], encoding=meta["encoding"])[1]) 444 | #lexicon_list.append(read_list_file(meta["clean"], encoding=meta["encoding"])[1]) 445 | 446 | for lexicon in lexicon_list: 447 | for phrase in lexicon: 448 | lexicon[phrase] = 0 449 | 450 | # Build a tree structure for each sentence 451 | data = {} 452 | word_count = {} 453 | pos_count = {} 454 | ne_count = {} 455 | pos_ne_count = {} 456 | lexicon_hits = {} 457 | for split in data_split_list: 458 | (tree_pyramid_list, ner_list, 459 | word_count[split], pos_count[split], ne_count[split], pos_ne_count[split], 460 | lexicon_hits[split]) = get_tree_data(raw_data[split], 461 | character_to_index, word_to_index, pos_to_index, lexicon_list) 462 | #data[split] = [tree_list, ner_list] 463 | data[split] = {"tree_pyramid_list": tree_pyramid_list, "ner_list": ner_list} 464 | 465 | for index, lexicon in enumerate(lexicon_list): 466 | with codecs.open("tmp_%d.txt" % index, "w", encoding="utf8") as f: 467 | for phrase, count in sorted(lexicon.iteritems(), key=lambda x: (-x[1], x[0])): 468 | if count == 0: break 469 | f.write("%9d %s\n" % (count, phrase)) 470 | 471 | # Show statistics of each data split 472 | print "-" * 80 473 | print "%10s%10s%9s%9s%7s%12s%13s" % ("split", "sentence", "token", "node", "NE", "spanned_NE", 474 | "lexicon_hit") 475 | print "-" * 80 476 | for split in data_split_list: 477 | print "%10s%10d%9d%9d%7d%12d%13d" % (split, 478 | len(data[split]["tree_pyramid_list"]), 479 | word_count[split], 480 | sum(pos_count[split].itervalues()), 481 | sum(len(ner) for ner in data[split]["ner_list"]), 482 | sum(ne_count[split].itervalues()), 483 | lexicon_hits[split]) 484 | 485 | # Show POS distribution 486 | total_pos_count = defaultdict(lambda: 0) 487 | for split in data_split_list: 488 | for pos in pos_count[split]: 489 | total_pos_count[pos] += pos_count[split][pos] 490 | nodes = sum(total_pos_count.itervalues()) 491 | print "\nTotal %d nodes" % nodes 492 | print "-"*80 + "\n POS count ratio\n" + "-"*80 493 | for pos, count in sorted(total_pos_count.iteritems(), key=lambda x: x[1], reverse=True)[:10]: 494 | print "%6s %7d %5.1f%%" % (pos, count, count*100./nodes) 495 | 496 | # Show NE distribution in [train, validate] 497 | total_ne_count = defaultdict(lambda: 0) 498 | for split in data_split_list: 499 | if split == "test": continue 500 | for ne in ne_count[split]: 501 | total_ne_count[ne] += ne_count[split][ne] 502 | nes = sum(total_ne_count.itervalues()) 503 | print "\nTotal %d spanned named entities in [train, validate]" % nes 504 | print "-"*80 + "\n NE count ratio\n" + "-"*80 505 | for ne, count in sorted(total_ne_count.iteritems(), key=lambda x: x[1], reverse=True): 506 | print "%12s %6d %5.1f%%" % (ne, count, count*100./nes) 507 | 508 | # Show POS-NE distribution in [train, validate] 509 | total_pos_ne_count = defaultdict(lambda: 0) 510 | for split in data_split_list: 511 | if split == 
"test": continue 512 | for pos in pos_ne_count[split]: 513 | total_pos_ne_count[pos] += pos_ne_count[split][pos] 514 | print "-"*80 + "\n POS NE total ratio\n" + "-"*80 515 | for pos, count in sorted(total_pos_ne_count.iteritems(), key=lambda x: x[1], reverse=True)[:10]: 516 | total = total_pos_count[pos] 517 | print "%6s %6d %7d %5.1f%%" % (pos, count, total, count*100./total) 518 | 519 | # Compute the mapping to labels 520 | ne_to_index["NONE"] = len(ne_to_index) 521 | 522 | # Add label to nodes 523 | for split in data_split_list: 524 | for tree, pyramid in data[split]["tree_pyramid_list"]: 525 | label_tree_data(tree, pos_to_index, ne_to_index) 526 | for node in pyramid: 527 | node.y = ne_to_index[node.ne] 528 | 529 | return (data, word_list, ne_list, 530 | len(character_to_index), len(pos_to_index), len(ne_to_index), len(lexicon_list)) 531 | 532 | if __name__ == "__main__": 533 | #extract_vocabulary_and_alphabet() 534 | #extract_glove_embeddings() 535 | #extract_clean_lexicon() 536 | read_dataset() 537 | exit() 538 | 539 | 540 | 541 | 542 | 543 | -------------------------------------------------------------------------------- /ontonotes/ne.txt: -------------------------------------------------------------------------------- 1 | PERSON 2 | NORP 3 | FAC 4 | ORG 5 | GPE 6 | LOC 7 | PRODUCT 8 | EVENT 9 | WORK_OF_ART 10 | LAW 11 | LANGUAGE 12 | DATE 13 | TIME 14 | PERCENT 15 | MONEY 16 | QUANTITY 17 | ORDINAL 18 | CARDINAL 19 | -------------------------------------------------------------------------------- /ontonotes/pos.txt: -------------------------------------------------------------------------------- 1 | S 2 | SBAR 3 | SBARQ 4 | SINV 5 | SQ 6 | ADJP 7 | ADVP 8 | CONJP 9 | FRAG 10 | INTJ 11 | LST 12 | NAC 13 | NP 14 | NX 15 | PP 16 | PRN 17 | PRT 18 | QP 19 | RRC 20 | UCP 21 | VP 22 | WHADJP 23 | WHADVP 24 | WHNP 25 | WHPP 26 | X 27 | CC 28 | CD 29 | DT 30 | EX 31 | FW 32 | IN 33 | JJ 34 | JJR 35 | JJS 36 | LS 37 | MD 38 | NN 39 | NNS 40 | NNP 41 | NNPS 42 | PDT 43 | POS 44 | PRP 45 | PRP$ 46 | RB 47 | RBR 48 | RBS 49 | RP 50 | SYM 51 | TO 52 | UH 53 | VB 54 | VBD 55 | VBG 56 | VBN 57 | VBP 58 | VBZ 59 | WDT 60 | WP 61 | WP$ 62 | WRB 63 | TOP 64 | HYPH 65 | $ 66 | -LRB- 67 | -RRB- 68 | '' 69 | . 
70 | XX 71 | : 72 | AFX 73 | NFP 74 | ADD 75 | META 76 | EMBED 77 | NML 78 | , 79 | `` 80 | -------------------------------------------------------------------------------- /rnn.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | class Node(object): 7 | def __init__(self, family=0): 8 | self.child_list = [] 9 | self.parent = None 10 | self.left = None 11 | self.right = None 12 | self.family = family 13 | return 14 | 15 | def add_child(self, child): 16 | if self.child_list: 17 | sibling = self.child_list[-1] 18 | sibling.right = child 19 | child.left = sibling 20 | self.child_list.append(child) 21 | child.parent = self 22 | return 23 | 24 | def show_tree(self, depth=0, branch=[]): 25 | indent = " " * depth 26 | for i in branch: 27 | indent = indent[:i] + "|" + indent[i+1:] 28 | print "%s|--%s %s" % (indent, self.pos, self.head) 29 | 30 | branch2 = [i for i in branch] 31 | if self.right: 32 | branch2.append(depth*4) 33 | elif not self.child_list: 34 | print indent 35 | 36 | for child in self.child_list: 37 | child.show_tree(depth+1, branch2) 38 | return 39 | 40 | class Affine_Layer(object): 41 | def __init__(self, name, input_dimension, output_dimension): 42 | self.name = name 43 | with tf.variable_scope(self.name, initializer=tf.contrib.layers.xavier_initializer()): 44 | self.W = tf.get_variable("W", [input_dimension, output_dimension]) 45 | self.b = tf.get_variable("b", [1, output_dimension]) 46 | return 47 | 48 | def transform(self, input_tensor): 49 | O = tf.matmul(input_tensor, self.W) + self.b 50 | return O 51 | 52 | def l2_loss(self): 53 | l2_loss = tf.nn.l2_loss(self.W) + tf.nn.l2_loss(self.b) 54 | return l2_loss 55 | 56 | class Child_Sum_LSTM_Layer(object): 57 | def __init__(self, name, raw_features, hidden_features): 58 | self.name = name 59 | with tf.variable_scope(self.name, initializer=tf.contrib.layers.xavier_initializer()): 60 | self.W = {} 61 | self.W["h"] = tf.get_variable("Wh", [raw_features+hidden_features, hidden_features]) 62 | self.W["i"] = tf.get_variable("Wi", [raw_features+hidden_features, hidden_features]) 63 | self.W["o"] = tf.get_variable("Wo", [raw_features+hidden_features, hidden_features]) 64 | self.W["f"] = tf.get_variable("Wf", [raw_features+hidden_features, hidden_features]) 65 | self.b = {} 66 | self.b["h"] = tf.get_variable("bh", [1, hidden_features]) 67 | self.b["i"] = tf.get_variable("bi", [1, hidden_features]) 68 | self.b["o"] = tf.get_variable("bo", [1, hidden_features]) 69 | self.b["f"] = tf.get_variable("bf", [1, hidden_features]) 70 | return 71 | 72 | def transform(self, raw, cell, out): 73 | """ 74 | raw : samples * raw_features 75 | cell: samples * degree * hidden_features 76 | out : samples * degree * hidden_features 77 | """ 78 | samples = tf.shape(raw)[0] 79 | raw_features = tf.shape(raw)[1] 80 | degree = tf.shape(cell)[1] 81 | hidden_features = tf.shape(cell)[2] 82 | 83 | x_concat = tf.reshape(raw, [samples, 1, raw_features]) 84 | x_concat = tf.tile(x_concat, [1, degree, 1]) 85 | x_concat = tf.concat([x_concat, out], 2) 86 | x_concat = tf.reshape(x_concat, [samples*degree, raw_features+hidden_features]) 87 | 88 | x_sum = tf.reduce_sum(out, 1) 89 | x_sum = tf.concat([raw, x_sum], 1) 90 | 91 | h = tf.matmul(x_sum, self.W["h"]) + self.b["h"] 92 | # h = tf.nn.relu(h) 93 | h = tf.tanh(h) 94 | 95 | i = tf.matmul(x_sum, self.W["i"]) + self.b["i"] 96 | i = tf.nn.sigmoid(i) 97 | 98 | o = tf.matmul(x_sum, self.W["o"]) + 
self.b["o"] 99 | o = tf.nn.sigmoid(o) 100 | 101 | f = tf.matmul(x_concat, self.W["f"]) + self.b["f"] 102 | f = tf.nn.sigmoid(f) 103 | f = tf.reshape(f, [samples, degree, hidden_features]) 104 | 105 | cell = tf.reduce_sum(f * cell, 1) 106 | cell = i * h + cell 107 | # out = o * tf.nn.relu(cell) 108 | out = o * tf.tanh(cell) 109 | return cell, out 110 | """ 111 | O_cell = tf.reshape(hh_cell, [self.samples, 1, self.hidden_dimension]) 112 | O_out = tf.reshape(hh_out, [self.samples, 1, self.hidden_dimension]) 113 | O = tf.concat([O_cell, O_out], 1) 114 | return O 115 | """ 116 | def l2_loss(self): 117 | l2_loss = tf.zeros(1) 118 | for i in self.W.itervalues(): 119 | l2_loss += tf.nn.l2_loss(i) 120 | for i in self.b.itervalues(): 121 | l2_loss += tf.nn.l2_loss(i) 122 | return l2_loss 123 | 124 | class Config(object): 125 | """ Store hyper parameters for tree models 126 | """ 127 | 128 | def __init__(self): 129 | self.name = "YOLO" 130 | 131 | self.vocabulary_size = 5 132 | self.word_to_word_embeddings = 300 133 | 134 | self.use_character_to_word_embedding = False 135 | self.alphabet_size = 5 136 | self.character_embeddings = 25 137 | self.word_length = 20 138 | self.max_conv_window = 3 139 | self.kernels = 40 140 | 141 | self.lexicons = 4 142 | 143 | self.pos_dimension = 5 144 | self.hidden_dimension = 350 145 | self.output_dimension = 2 146 | 147 | self.degree = 2 148 | self.poses = 3 149 | self.words = 4 150 | self.neighbors = 4 151 | 152 | self.hidden_layers = 3 153 | 154 | self.learning_rate = 1e-5 155 | self.epsilon = 1e-2 156 | self.keep_rate_P = 0.65 157 | self.keep_rate_X = 0.65 158 | self.keep_rate_H = 0.65 159 | return 160 | 161 | class RNN(object): 162 | """ A special Bidrectional Recursive Neural Network 163 | 164 | From an input tree, it classifies each node and identifies positive spans and their labels. 165 | 166 | Instantiating an object of this class only defines a Tensorflow computation graph 167 | under the name scope config.name. Weights of a model instance reside in a Tensorflow session. 168 | """ 169 | 170 | def __init__(self, config): 171 | self.create_hyper_parameter(config) 172 | self.create_input() 173 | self.create_word_embedding_layer() 174 | self.create_hidden_layer() 175 | self.create_output() 176 | self.create_update_op() 177 | return 178 | 179 | def create_hyper_parameter(self, config): 180 | """ Add attributes of cofig to self 181 | """ 182 | for parameter in dir(config): 183 | if parameter[0] == "_": continue 184 | setattr(self, parameter, getattr(config, parameter)) 185 | 186 | def create_input(self): 187 | """ Construct the input layer and embedding dictionaries 188 | 189 | If L is a tensor, wild indices will cause tf.gather() to raise error. 190 | Since L is a variable, gathering with some index of x being -1 will return zeroes, 191 | but will still raise error in apply_gradient. 
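        The workaround used below is to prepend a row of zeros to L (L_hat) and
        to gather with x+1, so a word index of -1 (padding / no word) selects
        the all-zero embedding row instead of indexing into L directly.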
192 | """ 193 | # Create placeholders 194 | self.e = tf.placeholder(tf.float32, [None, None]) 195 | self.y = tf.placeholder( tf.int32, [None, None]) 196 | self.f = tf.placeholder( tf.int32, [None, None]) 197 | self.T = tf.placeholder( tf.int32, [None, None, self.neighbors+self.degree]) 198 | self.p = tf.placeholder( tf.int32, [None, None, self.poses]) 199 | self.x = tf.placeholder( tf.int32, [None, None, self.words]) 200 | self.w = tf.placeholder( tf.int32, [None, None, self.words, self.word_length]) 201 | self.lex = tf.placeholder(tf.float32, [None, None, self.lexicons]) 202 | self.l = tf.placeholder(tf.float32, [None]) 203 | self.krP = tf.placeholder(tf.float32) 204 | self.krX = tf.placeholder(tf.float32) 205 | self.krH = tf.placeholder(tf.float32) 206 | 207 | self.nodes = tf.shape(self.T)[0] 208 | self.samples = tf.shape(self.T)[1] 209 | 210 | # Create embedding dictionaries 211 | # We use one-hot character embeddings so no dictionary is needed 212 | with tf.variable_scope(self.name, initializer=tf.random_normal_initializer(stddev=0.1)): 213 | self.L = tf.get_variable("L", 214 | [self.vocabulary_size, self.word_to_word_embeddings]) 215 | # self.C = tf.get_variable("C", 216 | # [2+self.alphabet_size, self.character_embeddings]) 217 | self.L_hat = tf.concat(axis=0, values=[tf.zeros([1, self.word_to_word_embeddings]), self.L]) 218 | # self.C_hat = tf.concat(0, [tf.zeros([1, self.character_embeddings]), self.C]) 219 | 220 | # Compute indices of neighbors 221 | offset = tf.reshape(tf.range(self.samples), [1, self.samples, 1]) 222 | offset = tf.tile(offset, [self.nodes, 1, self.neighbors+self.degree]) 223 | self.T_hat = offset + (1+self.T) * self.samples 224 | 225 | # Compute pos features 226 | P = tf.one_hot(self.p, self.pos_dimension, on_value=10.) 227 | P = tf.reshape(P, [self.nodes, self.samples, self.poses*self.pos_dimension]) 228 | self.P = tf.nn.dropout(P, self.krP) 229 | return 230 | 231 | def create_convolution_layers(self): 232 | """ Create a unit which use the character string of a word to generate its embedding 233 | 234 | Special characters: -1: start, -2: end, -3: padding 235 | We use one-hot character embeddings so no dictionary is needed. 236 | """ 237 | self.K = [None] 238 | self.character_embeddings = self.alphabet_size 239 | with tf.variable_scope(self.name, initializer=tf.contrib.layers.xavier_initializer()): 240 | for window in xrange(1, self.max_conv_window+1): 241 | self.K.append(tf.get_variable("K%d" % window, 242 | [window, self.character_embeddings, 1, self.kernels*window])) 243 | 244 | def cnn(w): 245 | W = tf.one_hot(w+2, self.alphabet_size, on_value=1.) 246 | # W = tf.gather(self.C_hat, w+3) 247 | W = tf.reshape(W, [-1, self.word_length, self.character_embeddings, 1]) 248 | stride = [1, 1, self.character_embeddings, 1] 249 | 250 | W_hat = [] 251 | for window in xrange(1, self.max_conv_window+1): 252 | W_window = tf.nn.conv2d(W, self.K[window], stride, "VALID") 253 | W_window = tf.reduce_max(W_window, axis=[1, 2]) 254 | W_hat.append(W_window) 255 | 256 | W_hat = tf.concat(axis=1, values=W_hat) 257 | return tf.nn.relu(W_hat) 258 | 259 | self.f_x_cnn = cnn 260 | return 261 | 262 | def create_highway_layers(self): 263 | """ Create a unit to transform the embedding of a word from CNN 264 | 265 | A highway layer is a linear combination of a fully connected layer and an identity layer. 
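        Concretely, each layer below computes gate = sigmoid(matmul(x, W_gate) + b_gate)
        and returns gate * relu(matmul(x, W_mlp) + b_mlp) + (1 - gate) * x; the gate
        bias is initialized around -2, so the unit starts close to the identity mapping.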
266 | """ 267 | layers = 1 268 | self.W_x_mlp = [] 269 | self.W_x_gate = [] 270 | self.b_x_mlp = [] 271 | self.b_x_gate = [] 272 | for layer in xrange(layers): 273 | with tf.variable_scope(self.name, initializer=tf.contrib.layers.xavier_initializer()): 274 | self.W_x_mlp.append(tf.get_variable("W_x_mlp_%d" % layer, 275 | [self.character_to_word_embeddings, 276 | self.character_to_word_embeddings])) 277 | self.W_x_gate.append(tf.get_variable("W_x_gate_%d" % layer, 278 | [self.character_to_word_embeddings, 279 | self.character_to_word_embeddings])) 280 | self.b_x_mlp.append(tf.get_variable("b_x_mlp_%d" % layer, 281 | [1, self.character_to_word_embeddings])) 282 | with tf.variable_scope(self.name, 283 | initializer=tf.random_normal_initializer(mean=-2, stddev=0.1)): 284 | self.b_x_gate.append(tf.get_variable("b_x_gate_%d" % layer, 285 | [1, self.character_to_word_embeddings])) 286 | 287 | def highway(x): 288 | data = x 289 | for layer in xrange(layers): 290 | mlp = tf.nn.relu(tf.matmul(data, self.W_x_mlp[layer]) + self.b_x_mlp[layer]) 291 | gate = tf.sigmoid(tf.matmul(data, self.W_x_gate[layer]) + self.b_x_gate[layer]) 292 | data = mlp*gate + data*(1-gate) 293 | return data 294 | 295 | self.f_x_highway = highway 296 | return 297 | 298 | def create_word_embedding_layer(self): 299 | """ Create a layer to compute word embeddings for all words 300 | """ 301 | self.word_dimension = self.word_to_word_embeddings 302 | X = tf.gather(self.L_hat, self.x+1) 303 | 304 | if self.use_character_to_word_embedding: 305 | conv_windows = (1+self.max_conv_window) * self.max_conv_window / 2 306 | self.character_to_word_embeddings = conv_windows * self.kernels 307 | self.word_dimension += self.character_to_word_embeddings 308 | 309 | self.create_convolution_layers() 310 | self.create_highway_layers() 311 | 312 | w = tf.reshape(self.w, [self.nodes*self.samples*self.words, self.word_length]) 313 | W = self.f_x_highway(self.f_x_cnn(w)) 314 | X = tf.reshape(X, [self.nodes*self.samples*self.words, self.word_to_word_embeddings]) 315 | X = tf.concat(axis=1, values=[X, W]) 316 | 317 | X = tf.reshape(X, [self.nodes, self.samples, self.words*self.word_dimension]) 318 | self.X = tf.nn.dropout(X, self.krX) 319 | 320 | # Mean embedding of leaf words 321 | m = self.X[:,:,:self.word_dimension] 322 | self.m = tf.reduce_sum(m, axis=0) / tf.reshape(self.l, [self.samples, 1]) 323 | return 324 | 325 | def get_hidden_unit(self, name, degree): 326 | """ Create a unit to compute the hidden features of one direction of a node 327 | """ 328 | self.raw_features = (self.pos_dimension * self.poses 329 | + self.word_dimension * (1+self.words) 330 | + self.lexicons) 331 | 332 | self.layer_dict[name] = [] 333 | for i in xrange(self.hidden_layers): 334 | if i == 0: 335 | input_dimension = degree*self.hidden_dimension + self.raw_features 336 | else: 337 | input_dimension = degree*self.hidden_dimension + self.hidden_dimension 338 | with tf.variable_scope(self.name): 339 | with tf.variable_scope(name): 340 | hidden_layer = Affine_Layer("H%d"%i, input_dimension, self.hidden_dimension) 341 | self.layer_dict[name].append(hidden_layer) 342 | 343 | def hidden_unit(x, c): 344 | h = x 345 | ret = [] 346 | for i in xrange(self.hidden_layers): 347 | h = tf.concat([h, c[:,i,:]], 1) 348 | h = self.layer_dict[name][i].transform(h) 349 | h = tf.nn.relu(h) 350 | ret.append(h) 351 | return tf.stack(ret, 1) 352 | return hidden_unit 353 | 354 | def get_hidden_unit_backup(self, name, degree): 355 | """ Create a unit to compute the hidden features of one direction of 
a node 356 | """ 357 | self.raw_features = (self.pos_dimension * self.poses 358 | + self.word_dimension * (1+self.words) 359 | + self.lexicons) 360 | 361 | self.W[name] = {} 362 | self.b[name] = {} 363 | with tf.variable_scope(self.name, initializer=tf.contrib.layers.xavier_initializer()): 364 | with tf.variable_scope(name): 365 | for i in xrange(self.hidden_layers): 366 | if i == 0: 367 | input_dimension = self.hidden_dimension * degree + self.raw_features 368 | else: 369 | input_dimension = self.hidden_dimension * (degree + 1) 370 | self.W[name]["h%d"%i] = tf.get_variable("W_h%d"%i, 371 | [input_dimension, self.hidden_dimension]) 372 | self.b[name]["h%d"%i] = tf.get_variable("b_h%d"%i, [1, self.hidden_dimension]) 373 | 374 | def hidden_unit(x, c): 375 | h = x 376 | ret = [] 377 | for i in xrange(self.hidden_layers): 378 | h = tf.concat([h, c[:,i,:]], 1) 379 | h = tf.matmul(h, self.W[name]["h%d"%i]) + self.b[name]["h%d"%i] 380 | h = tf.nn.relu(h) 381 | ret.append(h) 382 | return tf.stack(ret, 1) 383 | return hidden_unit 384 | 385 | def get_child_sum_lstm_unit(self, name): 386 | self.raw_features = (self.pos_dimension * self.poses 387 | + self.word_dimension * (1+self.words) 388 | + self.lexicons) 389 | 390 | self.layer_dict[name] = [] 391 | for i in xrange(self.hidden_layers): 392 | if i == 0: 393 | input_dimension = self.raw_features 394 | else: 395 | input_dimension = self.hidden_dimension 396 | with tf.variable_scope(self.name): 397 | with tf.variable_scope(name): 398 | hidden_layer = Child_Sum_LSTM_Layer("H%d"%i, input_dimension, self.hidden_dimension) 399 | self.layer_dict[name].append(hidden_layer) 400 | 401 | def hidden_unit(x, cell, out): 402 | """ 403 | x: samples * raw_features 404 | cell, out: samples * degree * layers * hidden_features 405 | ret_cell, ret_out: samples * hidden_features 406 | """ 407 | ret_out = x 408 | ret_cell_list = [] 409 | ret_out_list = [] 410 | for i in xrange(self.hidden_layers): 411 | ret_cell, ret_out = self.layer_dict[name][i].transform( 412 | ret_out, cell[:,:,i,:], out[:,:,i,:]) 413 | ret_cell_list.append(ret_cell) 414 | ret_out_list.append(ret_out) 415 | return tf.stack(ret_cell_list, 1), tf.stack(ret_out_list, 1) 416 | return hidden_unit 417 | 418 | def get_child_sum_lstm_unit_backup(self, name, degree): 419 | self.raw_features = (self.pos_dimension * self.poses 420 | + self.word_dimension * (1+self.words) 421 | + self.lexicons) 422 | self.input_dimension = self.raw_features + self.hidden_dimension 423 | self.W[name] = {} 424 | self.b[name] = {} 425 | with tf.variable_scope(self.name, initializer=tf.contrib.layers.xavier_initializer()): 426 | with tf.variable_scope(name): 427 | self.W[name]["hh"] = tf.get_variable("W_hh", 428 | [self.input_dimension, self.hidden_dimension]) 429 | self.W[name]["hi"] = tf.get_variable("W_hi", 430 | [self.input_dimension, self.hidden_dimension]) 431 | self.W[name]["ho"] = tf.get_variable("W_ho", 432 | [self.input_dimension, self.hidden_dimension]) 433 | self.W[name]["hf"] = tf.get_variable("W_hf", 434 | [self.input_dimension, self.hidden_dimension]) 435 | self.b[name]["hh"] = tf.get_variable("b_hh", [1, self.hidden_dimension]) 436 | self.b[name]["hi"] = tf.get_variable("b_hi", [1, self.hidden_dimension]) 437 | self.b[name]["ho"] = tf.get_variable("b_ho", [1, self.hidden_dimension]) 438 | self.b[name]["hf"] = tf.get_variable("b_hf", [1, self.hidden_dimension]) 439 | 440 | def hidden_unit(x, c_cell, c_out): 441 | """ 442 | x: samples * raw_features 443 | c: samples * degree * hidden_dimension 444 | """ 445 | c_out_sum 
= tf.reduce_sum(c_out, axis=1) 446 | x2 = tf.concat([x, c_out_sum], 1) 447 | 448 | x3 = tf.reshape(x, [self.samples, 1, self.raw_features]) 449 | x3 = tf.tile(x3, [1, degree, 1]) 450 | x3 = tf.concat([x3, c_out], 2) 451 | x3 = tf.reshape(x3, [self.samples*degree, self.input_dimension]) 452 | 453 | hh = tf.matmul(x2, self.W[name]["hh"]) + self.b[name]["hh"] 454 | hh = tf.nn.relu(hh) 455 | 456 | hi = tf.matmul(x2, self.W[name]["hi"]) + self.b[name]["hi"] 457 | hi = tf.nn.sigmoid(hi) 458 | 459 | ho = tf.matmul(x2, self.W[name]["ho"]) + self.b[name]["ho"] 460 | ho = tf.nn.sigmoid(ho) 461 | 462 | hf = tf.matmul(x3, self.W[name]["hf"]) + self.b[name]["hf"] 463 | hf = tf.nn.sigmoid(hf) 464 | hf = tf.reshape(hf, [self.samples, degree, self.hidden_dimension]) 465 | c_cell_sum = tf.reduce_sum(hf*c_cell, axis=1) 466 | 467 | hh_cell = hi*hh + c_cell_sum 468 | hh_out = ho * tf.nn.relu(hh) 469 | 470 | hh_cell = tf.reshape(hh_cell, [self.samples, 1, self.hidden_dimension]) 471 | hh_out = tf.reshape(hh_out, [self.samples, 1, self.hidden_dimension]) 472 | hh = tf.concat([hh_cell, hh_out], 1) 473 | return hh 474 | return hidden_unit 475 | 476 | def create_hidden_layer(self): 477 | """ Create a layer to compute hidden features for all nodes 478 | """ 479 | self.layer_dict = {} 480 | # self.f_h_bottom = self.get_hidden_unit("hidden_bottom", self.degree) 481 | self.f_h_bottom = self.get_hidden_unit("hidden_bottom", 1) 482 | self.f_h_top = self.get_hidden_unit("hidden_top", 1) 483 | # self.f_h_bottom = self.get_child_sum_lstm_unit("hidden_bottom") 484 | # self.f_h_bottom = self.get_child_sum_lstm_unit("hidden_top") 485 | 486 | # 2: one for LSTM memory cell, one for output 487 | # H = tf.zeros([(1+self.nodes) * self.samples, self.hidden_layers, 2, self.hidden_dimension]) 488 | H = tf.zeros([(1+self.nodes) * self.samples, self.hidden_layers, self.hidden_dimension]) 489 | 490 | # Bottom-up 491 | def bottom_condition(index, H): 492 | return index <= self.nodes-1 493 | def bottom_body(index, H): 494 | p = self.P[index,:,:] 495 | x = self.X[index,:,:] 496 | lex = self.lex[index,:,:] 497 | 498 | t = self.T_hat[index,:,self.neighbors:self.neighbors+self.degree] 499 | # c = tf.gather(H, t) 500 | c = tf.reduce_sum(tf.gather(H, t), axis=1) 501 | 502 | raw = tf.concat([self.m, p, x, lex], 1) 503 | h = self.f_h_bottom(raw, c) 504 | # h1, h2 = self.f_h_bottom(raw, c[:,:,:,0,:], c[:,:,:,1,:]) 505 | # h = tf.stack([h1, h2], 2) 506 | 507 | h = tf.nn.dropout(h, self.krH) 508 | 509 | h_upper = tf.zeros([ (1+index)*self.samples, self.hidden_layers, self.hidden_dimension]) 510 | h_lower = tf.zeros([(self.nodes-1-index)*self.samples, self.hidden_layers, self.hidden_dimension]) 511 | # h_upper = tf.zeros([(1+index)*self.samples, self.hidden_layers, 2, self.hidden_dimension]) 512 | # h_lower = tf.zeros([(self.nodes-1-index)*self.samples, self.hidden_layers, 2, self.hidden_dimension]) 513 | H_hat = H+tf.concat([h_upper, h, h_lower], 0) 514 | return index+1, H_hat 515 | _, H_bottom = tf.while_loop(bottom_condition, bottom_body, [tf.constant(0), H], parallel_iterations=1) 516 | 517 | # Top-down 518 | def top_condition(index, H): 519 | return index >= 0 520 | def top_body(index, H): 521 | p = self.P[index,:,:] 522 | x = self.X[index,:,:] 523 | lex = self.lex[index,:,:] 524 | 525 | t = self.T_hat[index,:,3] 526 | # t = self.T_hat[index,:,3:4] 527 | c = tf.gather(H, t) 528 | 529 | raw = tf.concat([self.m, p, x, lex], 1) 530 | h = self.f_h_top(raw, c) 531 | # h1, h2 = self.f_h_bottom(raw, c[:,:,:,0,:], c[:,:,:,1,:]) 532 | # h = tf.stack([h1, 
h2], 2) 533 | 534 | h = tf.nn.dropout(h, self.krH) 535 | 536 | h_upper = tf.zeros([ (1+index)*self.samples, self.hidden_layers, self.hidden_dimension]) 537 | h_lower = tf.zeros([(self.nodes-1-index)*self.samples, self.hidden_layers, self.hidden_dimension]) 538 | # h_upper = tf.zeros([(1+index)*self.samples, self.hidden_layers, 2, self.hidden_dimension]) 539 | # h_lower = tf.zeros([(self.nodes-1-index)*self.samples, self.hidden_layers, 2, self.hidden_dimension]) 540 | H_hat = H+tf.concat([h_upper, h, h_lower], 0) 541 | return index-1, H_hat 542 | _, H_top = tf.while_loop(top_condition, top_body, [self.nodes-1, H], parallel_iterations=1) 543 | #_, H_top = tf.while_loop(top_condition, top_body, [self.nodes-1, H_bottom]) 544 | 545 | # self.H = H_bottom[:,self.hidden_layers-1,:] 546 | self.H = H_bottom[:,self.hidden_layers-1,:] + H_top[:,self.hidden_layers-1,:] 547 | # self.H = H_bottom[:,self.hidden_layers-1,1,:] + H_top[:,self.hidden_layers-1,1,:] 548 | return 549 | 550 | def get_output_unit(self, name): 551 | """ Create a unit to compute the class scores of a node 552 | """ 553 | with tf.variable_scope(self.name): 554 | with tf.variable_scope(name): 555 | self.layer_dict[name] = Affine_Layer("O", self.hidden_dimension*3, self.output_dimension) 556 | 557 | def output_unit(H): 558 | H = tf.gather(H, self.T_hat[:,:,:3]) 559 | H = tf.reshape(H, [self.nodes * self.samples, self.hidden_dimension * 3]) 560 | O = self.layer_dict[name].transform(H) 561 | return O 562 | return output_unit 563 | 564 | def create_output(self): 565 | """ Construct the output layer 566 | """ 567 | self.f_o = self.get_output_unit("output") 568 | 569 | self.O = self.f_o(self.H) 570 | Y_hat = tf.nn.softmax(self.O) 571 | self.y_hat = tf.reshape(tf.argmax(Y_hat, 1), [self.nodes, self.samples]) 572 | 573 | e = tf.reshape(self.e, [self.nodes * self.samples]) 574 | y = tf.reshape(self.y, [self.nodes * self.samples]) 575 | Y = tf.one_hot(y, self.output_dimension, on_value=1.) 576 | self.loss = tf.reduce_sum(e * tf.nn.softmax_cross_entropy_with_logits(logits=self.O, labels=Y)) 577 | return 578 | 579 | def create_update_op(self): 580 | """ Create the computation of back-propagation 581 | """ 582 | optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate, epsilon=self.epsilon) 583 | self.update_op = optimizer.minimize(self.loss) 584 | return 585 | 586 | def get_padded_word(self, word): 587 | """ Preprocessing: Form a uniform-length string from a raw word string 588 | 589 | Mainly to enable batch CNN. 590 | A word W is cut and transformed into start, W_cut, end, padding. 591 | Special characters: -1: start, -2: end, -3: padding. 
592 | """ 593 | word_cut = [-1] + word[:self.word_length-2] + [-2] 594 | padding = [-3] * (self.word_length - len(word_cut)) 595 | return word_cut + padding 596 | 597 | def get_formatted_input(self, tree, pyramid): 598 | """ Preprocessing: Extract data structures from an input tree 599 | 600 | tree: the root node of a tree 601 | pyramid: a bottom-up list of additional nodes outside the tree 602 | """ 603 | # Get top-down node list of the tree 604 | node_list = [tree] 605 | index = -1 606 | while index+1 < len(node_list): 607 | index += 1 608 | node_list.extend(node_list[index].child_list) 609 | """ 610 | # Construct linear chains 611 | for node in node_list[1:]: 612 | if not node.left: 613 | node.left = node.parent.left 614 | if not node.right: 615 | node.right = node.parent.right 616 | """ 617 | # Merge two lists of nodes according to bottom-up dependency; index all nodes 618 | node_list = node_list[::-1] + pyramid 619 | for index, node in enumerate(node_list): 620 | node.index = index 621 | 622 | # Extract data from layers bottom-up 623 | N = [] 624 | e = [] 625 | y = [] 626 | f = [] 627 | T = [] 628 | p = [] 629 | x = [] 630 | w = [] 631 | lex = [] 632 | l = 0 633 | for node in node_list: 634 | N.append(node) 635 | e.append(1) 636 | #e.append(0.1 if node.under_ne else 1) 637 | #e.append(.5 if node.y==self.output_dimension-1 else 1) 638 | y.append(node.y) 639 | f.append(node.family) 640 | 641 | child_index_list = [-1] * self.degree 642 | for i, child in enumerate(node.child_list): 643 | child_index_list[i] = child.index 644 | T.append([node.index, 645 | node.left.index if node.left else -1, 646 | node.right.index if node.right else -1, 647 | node.parent.index if node.parent else -1] 648 | + child_index_list) 649 | 650 | p.append([node.pos_index, 651 | node.left.pos_index if node.left else -1, 652 | node.right.pos_index if node.right else -1]) 653 | 654 | x.append([node.word_index, 655 | node.head_index, 656 | node.left.head_index if node.left else -1, 657 | node.right.head_index if node.right else -1]) 658 | 659 | w.append([self.get_padded_word(node.word_split), 660 | self.get_padded_word(node.head_split), 661 | self.get_padded_word(node.left.head_split if node.left else []), 662 | self.get_padded_word(node.right.head_split if node.right else [])]) 663 | 664 | lex.append(node.lexicon_hit) 665 | 666 | if node.word_index != -1: l += 1 667 | 668 | N = np.array(N) 669 | e = np.array( e, dtype=np.float32) 670 | y = np.array( y, dtype=np.int32) 671 | f = np.array( f, dtype=np.int32) 672 | T = np.array( T, dtype=np.int32) 673 | p = np.array( p, dtype=np.int32) 674 | x = np.array( x, dtype=np.int32) 675 | w = np.array( w, dtype=np.int32) 676 | lex = np.array(lex, dtype=np.float32) 677 | return N, e, y, f, T, p, x, w, lex, l, tree.index 678 | 679 | def get_batch_input(self, tree_pyramid_list): 680 | """ Preprocessing: Get batched data structures for the input layer from input trees 681 | """ 682 | input_list = [] 683 | for tree, pyramid in tree_pyramid_list: 684 | input_list.append(self.get_formatted_input(tree, pyramid)) 685 | 686 | samples = len(input_list) 687 | nodes = max([i[1].shape[0] for i in input_list]) 688 | N = np.zeros([nodes, samples ], dtype=np.object) 689 | e = np.zeros([nodes, samples ], dtype=np.float32) 690 | y = -1 * np.ones( [nodes, samples ], dtype=np.int32) 691 | f = np.zeros([nodes, samples ], dtype=np.int32) 692 | T = -1 * np.ones( [nodes, samples, self.neighbors+self.degree ], dtype=np.int32) 693 | p = -1 * np.ones( [nodes, samples, self.poses ], dtype=np.int32) 694 | x 
= -1 * np.ones( [nodes, samples, self.words ], dtype=np.int32) 695 | w = -3 * np.ones( [nodes, samples, self.words, self.word_length], dtype=np.int32) 696 | lex = np.zeros([nodes, samples, self.lexicons ], dtype=np.float32) 697 | l = np.zeros( samples , dtype=np.float32) 698 | r = np.zeros( samples , dtype=np.int32) 699 | 700 | for sample, sample_input in enumerate(input_list): 701 | n = sample_input[0].shape[0] 702 | ( N[:n, sample ], 703 | e[:n, sample ], 704 | y[:n, sample ], 705 | f[:n, sample ], 706 | T[:n, sample, : ], 707 | p[:n, sample, : ], 708 | x[:n, sample, : ], 709 | w[:n, sample, :, :], 710 | lex[:n, sample, : ], 711 | l[ sample ], 712 | r[ sample ]) = sample_input 713 | return N, e, y, f, T, p, x, w, lex, l, r 714 | 715 | def train(self, tree_pyramid_list): 716 | """ Update parameters from a batch of trees with labeled nodes 717 | """ 718 | _, e, y, f, T, p, x, w, lex, l, _ = self.get_batch_input(tree_pyramid_list) 719 | 720 | loss, _ = self.sess.run([self.loss, self.update_op], 721 | feed_dict={self.e:e, self.y:y, self.f:f, self.T:T, 722 | self.p:p, self.x:x, self.w:w, self.lex:lex, self.l:l, 723 | self.krP:self.keep_rate_P, self.krX:self.keep_rate_X, self.krH:self.keep_rate_H}) 724 | return loss 725 | 726 | def predict(self, tree_pyramid_list): 727 | """ Predict positive spans and their labels from a batch of trees 728 | 729 | Spans that are contained by other positive spans are ignored. 730 | """ 731 | N, e, _, f, T, p, x, w, lex, l, r = self.get_batch_input(tree_pyramid_list) 732 | 733 | y_hat = self.sess.run(self.y_hat, 734 | feed_dict={self.f:f, self.T:T, 735 | self.p:p, self.x:x, self.w:w, self.lex:lex, self.l:l, 736 | self.krP:1.0, self.krX:1.0, self.krH:1.0}) 737 | 738 | def parse_output(node_index, sample_index, span_y, uncovered_token_set): 739 | node = N[node_index][sample_index] 740 | label = y_hat[node_index][sample_index] 741 | 742 | if label != self.output_dimension-1: 743 | span_y[node.span] = label 744 | for token_index in xrange(*node.span): 745 | uncovered_token_set.remove(token_index) 746 | return 747 | 748 | for child in node.child_list: 749 | parse_output(child.index, sample_index, span_y, uncovered_token_set) 750 | return 751 | 752 | tree_span_y = [] 753 | for sample_index in xrange(T.shape[1]): 754 | span_y = {} 755 | uncovered_token_set = set(xrange(int(l[sample_index]))) 756 | 757 | # Get span to positive y prediction of trees in top-down orders 758 | parse_output(r[sample_index], sample_index, span_y, uncovered_token_set) 759 | 760 | # Get span to positive y prediction of pyramids in top-down orders 761 | for i in xrange(T.shape[0]-1, -1, -1): 762 | if e[i][sample_index] > 0: 763 | pyramid_top_index = i 764 | break 765 | for node_index in xrange(pyramid_top_index, r[sample_index], -1): 766 | node = N[node_index][sample_index] 767 | label = y_hat[node_index][sample_index] 768 | if label != self.output_dimension-1: 769 | conflict = False 770 | for token_index in xrange(*node.span): 771 | if token_index not in uncovered_token_set: 772 | conflict = True 773 | break 774 | if conflict: continue 775 | span_y[node.span] = label 776 | for token_index in xrange(*node.span): 777 | uncovered_token_set.remove(token_index) 778 | 779 | tree_span_y.append(span_y) 780 | 781 | return tree_span_y 782 | 783 | def main(): 784 | config = Config() 785 | model = RNN(config) 786 | with tf.Session() as sess: 787 | sess.run(tf.global_variables_initializer()) 788 | L = sess.run(model.L) 789 | print L.shape 790 | for v in tf.trainable_variables(): 791 | print v 792 | return 
793 | 794 | if __name__ == "__main__": 795 | main() 796 | exit() 797 | 798 | 799 | --------------------------------------------------------------------------------
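The highway layers built by `create_highway_layers` apply, per layer, `mlp = relu(x·W_mlp + b_mlp)`, `gate = sigmoid(x·W_gate + b_gate)`, and output `mlp*gate + x*(1-gate)`; the gate bias is initialized around -2, so each layer starts close to an identity map over the character-CNN word features. A minimal NumPy sketch of this transform (the function name, sizes, and toy initialization here are illustrative, not taken from `rnn.py`):

```python
import numpy as np

def highway_layer(x, W_mlp, b_mlp, W_gate, b_gate):
    """One highway step: the gate decides how much of the transformed input to keep."""
    mlp = np.maximum(0.0, np.dot(x, W_mlp) + b_mlp)              # ReLU branch
    gate = 1.0 / (1.0 + np.exp(-(np.dot(x, W_gate) + b_gate)))   # sigmoid gate
    return mlp * gate + x * (1.0 - gate)                         # carry the rest of x unchanged

d = 4                                  # toy feature size, standing in for character_to_word_embeddings
rng = np.random.default_rng(0)
x = rng.normal(size=(2, d))            # two toy character-CNN word vectors
W_m, W_g = 0.1 * rng.normal(size=(d, d)), 0.1 * rng.normal(size=(d, d))
b_m, b_g = np.zeros((1, d)), np.full((1, d), -2.0)   # gate bias near -2, so the layer is mostly pass-through
print(highway_layer(x, W_m, b_m, W_g, b_g))
```

Similarly, `get_padded_word` turns a word, given as a list of character indices, into a fixed-length sequence bracketed by start (-1) and end (-2) markers and right-padded with -3, so the batched character CNN sees uniform-length inputs. A standalone sketch of the same scheme (the helper name and the `word_length` value are illustrative):

```python
def pad_word(char_indices, word_length=8):
    """Truncate to fit, bracket with start/end markers, then right-pad."""
    cut = [-1] + char_indices[:word_length - 2] + [-2]
    return cut + [-3] * (word_length - len(cut))

print(pad_word([7, 3, 9]))          # [-1, 7, 3, 9, -2, -3, -3, -3]
print(pad_word(list(range(30))))    # truncated: [-1, 0, 1, 2, 3, 4, 5, -2]
```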