├── README.md
├── corpus.py
├── data_structure.py
├── evaluate.py
├── input.txt
├── model_rvnn_lstm_20
├── model_user.py
├── output.txt
├── parser.py
├── rvnn.py
└── word_to_ix

/README.md:
--------------------------------------------------------------------------------
1 | # discourse-parser
2 | Demo mode usage:
3 | python parser.py -demo -i <input_file> -o <output_file>
4 |
5 |
6 | input_file
7 | Chinese raw text (UTF-8, simplified)
8 |
9 |
10 | output_file
11 | json format:
12 | {
13 | 'EDUs': [(edu1), (edu2), ..., (edun)],
14 | 'tree': {'args': [(subtree), ...], 'sense': (sense), 'center': (center)},
15 | 'relations': [{'arg1': (arg1), 'arg2': (arg2), 'sense': (sense), 'center': (center)}, {...}, {...}]
16 | }
17 |
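Example (a minimal sketch; input.txt and output.txt are the files shipped in this repo, and it assumes the output file holds the JSON object described above):

    python parser.py -demo -i input.txt -o output.txt

The result can then be inspected from Python 2:

    import json
    with open('output.txt') as f:
        result = json.load(f)
    print result['EDUs']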
-------------------------------------------------------------------------------- /corpus.py: --------------------------------------------------------------------------------
1 | # coding=UTF-8
2 | '''
3 | define the Corpus class, which contains the information of a paragraph: text, segment indices, relations and the merge structure
4 |
5 | merge structure: words merge into segments, segments merge into EDUs, EDUs merge into bigger discourse units
6 |
7 | '''
8 | #to get arguments from the shell
9 | #executed on first import; stores the command-line arguments in the variable args
10 | import args
11 | # the Reporter class is used for logging, statistics, debugging and measuring process time; an instance reporter is initialized on first import
12 | import report
13 |
14 | #these punctuation marks divide the text into segments
15 | _ENDs = (u'?', u'”', u'…', u'—', u'、', u'。', u'」', u'!', u',', u':', u';', u'?')
16 |
17 |
18 | word_to_ix = {}
19 | oov = 0
20 |
21 | #to handle the '——' and '。」' special cases
22 | def is_punc_in_text(text, idx):
23 |     if text[idx] in _ENDs:
24 |         if idx+1 < len(text):
25 |             if text[idx+1] in _ENDs:
26 |                 if text[idx] == u'—' or text[idx+1] != u'—':
27 |                     return False
28 |             if text[idx] == u'—' and idx > 1:
29 |                 if text[idx-1] != u'—':
30 |                     return False
31 |         return True
32 |     else:
33 |         return False
34 |
35 |
36 |
37 | def build_word_to_ix(corpus_list):
38 |     global oov
39 |     for corpus in corpus_list:
40 |         for word in corpus.text:
41 |             if word not in word_to_ix:
42 |                 word_to_ix[word] = len(word_to_ix)
43 |     #index used for out-of-vocabulary words
44 |     oov = len(word_to_ix)
45 |     word_to_ix['oov'] = len(word_to_ix) #for oov at test time
46 |
47 | class Corpus():
48 |     def __init__(self):
49 |         #encoding = unicode, easy to index, needs to be decoded back to utf-8 for printing
50 |         self.text = ''
51 |         #each element is a [start_idx, end_idx] list; the text is divided into segments by the given punctuation marks
52 |         self.segments_span = []
53 |         #each element is a [start_idx, end_idx] list; some segments merge into an EDU
54 |         self.edus_span = []
55 |         #top-down, pre order
56 |         self.relations = []
57 |         #filename+'-'+paragraph_count ex: 001.xml-1
58 |         self.id = ''
59 |
60 |     #get a Corpus object from xml rows. Span indices start from 0, not 1, unlike the xml format
61 |     #in xml rows, the relations are listed top-down, in pre order. This style remains the same
62 |     def xml_rows_to_corpus(self, xml_rows):
63 |         #each row is a relation
64 |
65 |         #get the text from the first row
66 |         #encoding = unicode, easy to index, needs to be decoded back to utf-8 for printing
67 |         r = xml_rows[0]
68 |         self.text = r.get('Sentence').replace(u'|', '')
69 |
70 |         #get segment indices, using the global _ENDs tuple to split the text into segments
71 |         self.segments_span = self.text_to_segments_span(self.text)
72 |
73 |         #get the relation object list from the xml rows
74 |         #the xml rows list the relations in top-down, left-first style
75 |         self.relations = self.xml_rows_to_ralations(xml_rows)
76 |         #need to ensure the pre order of the relations
77 |         #some paragraphs from the xml are indeed not in pre order!
78 |         self.relations = self.relations_to_pre_order(self.relations)
79 |
80 |         #get the EDU spans by just checking the boundary indices in the relations
81 |         self.edus_span = self.find_edus_span_from_relations(self.relations)
82 |
83 |     #check whether the edu spans and segment spans are coordinated
84 |     def span_certification(self):
85 |         seg_idx = 0
86 |         passed = True
87 |         ends = []
88 |         for e_span in self.edus_span:
89 |             start = e_span[0]
90 |             end = e_span[-1]
91 |             while start != self.segments_span[seg_idx][0]:
92 |                 ends.append(self.segments_span[seg_idx][-1])
93 |                 seg_idx += 1
94 |                 if seg_idx == len(self.segments_span):
95 |                     print 'certification not passed:', self.id, start
96 |                     passed = False
97 |                     break
98 |
99 |             while end != self.segments_span[seg_idx][-1]:
100 |                 ends.append(self.segments_span[seg_idx][-1])
101 |                 seg_idx += 1
102 |
103 |                 if seg_idx == len(self.segments_span):
104 |                     print 'certification not passed:', self.id, end
105 |                     print ends
106 |                     passed = False
107 |                     break
108 |             if not passed:
109 |                 break
110 |         return
111 |
112 |     #get segment indices, using the global _ENDs tuple to split the text into segments
113 |     def text_to_segments_span(self, text):
114 |         start = 0
115 |         end = 0
116 |         segments_span = []
117 |
118 |         for idx in range(len(text)):
119 |             word = text[idx]
120 |             if word in _ENDs:
121 |                 end = idx
122 |                 segments_span.append([start,end])
123 |                 start = end+1
124 |
125 |         #some corpora have no punctuation at the end!!
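        #e.g. u'天气很好,我们去公园' ends without punctuation: the loop above only
        #appends [0, 4], so the tail segment [5, 9] still has to be appended below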
126 |         if start != len(text):
127 |             segments_span.append([start,len(text)-1])
128 |
129 |         return segments_span
130 |
131 |     #get the relation object list from the xml rows
132 |     #the xml rows list the relations in top-down, pre order
133 |     def xml_rows_to_ralations(self, xml_rows):
134 |         relations = []
135 |
136 |         for r in xml_rows:
137 |             #encoding = unicode, easy to index, needs to be decoded back to utf-8 for printing
138 |             sense = r.get('RelationType')
139 |             center = r.get('Center')
140 |
141 |             #get the span
142 |             span = []
143 |             #sentence_position ex: u'27…51|52…63|64…78|79…89'
144 |             sentence_position = r.get('SentencePosition')
145 |             positions = sentence_position.split(u'|')
146 |             for p in positions:
147 |                 ps = p.split(u'…')
148 |                 #span indices start from 0, not 1, unlike the xml format
149 |                 span.append( [ int(ps[0])-1 ,int(ps[-1])-1 ] )
150 |
151 |
152 |             relation = Relation(span, sense, center)
153 |             relations.append(relation)
154 |
155 |         return relations
156 |
157 |     #get the EDU spans by just checking the boundary indices in the relations
158 |     def find_edus_span_from_relations(self, relations):
159 |         edus_span = []
160 |
161 |         #this set collects every boundary index
162 |         boundary_idx_set = set()
163 |         for relation in relations:
164 |             #span example: [ [0, 21], [22, 76], [77, 129] ]
165 |             for span_unit in relation.span:
166 |                 boundary_idx_set.add(span_unit[0])
167 |                 boundary_idx_set.add(span_unit[1])
168 |
169 |         #sorted from left to right
170 |         boundary_idx = sorted(boundary_idx_set)
171 |
172 |         for i in range(len(boundary_idx)-1):
173 |             #every 2 consecutive indices form a [start, end] pair
174 |             if i%2 == 0:
175 |                 edus_span.append([ boundary_idx[i], boundary_idx[i+1] ])
176 |
177 |         return edus_span
178 |
179 |     def relations_to_pre_order(self, relations):
180 |         # python3 removed the cmp argument of sorted(), so convert the cmp function to a key
181 |         from functools import cmp_to_key
182 |         #cmp_for_pre_order: compare function for sorting into pre order
183 |         key_for_pre_order_from_relation = cmp_to_key(self.cmp_for_pre_order_from_relation)
184 |
185 |         relations = sorted(relations, key=key_for_pre_order_from_relation, reverse=True)
186 |
187 |         return relations
188 |
189 |     def to_binary_structure(self):
190 |         relations = self.relations
191 |         new_added_relations = []
192 |         to_be_removed_relations = []
193 |         for relation in relations:
194 |             if len(relation.span) > 2:
195 |                 #the original multi-child relation will be deleted later
196 |                 to_be_removed_relations.append(relation)
197 |                 #ex: [0,3]
198 |                 left_span_unit = relation.span[0]
199 |                 for idx in range(1,len(relation.span)):
200 |                     right_span_unit = relation.span[idx]
201 |                     new_span = [left_span_unit, right_span_unit]
202 |                     new_relation = Relation(new_span, relation.sense, relation.center)
203 |                     new_added_relations.append(new_relation)
204 |                     left_span_unit = [ new_span[0][0], new_span[1][1] ]
205 |
206 |         for r in to_be_removed_relations:
207 |             relations.remove(r)
208 |         for r in new_added_relations:
209 |             relations.append(r)
210 |
211 |         self.relations = self.relations_to_pre_order(relations)
212 |         return
213 |
214 |
215 |     #compare function for sorting into pre order
216 |     def cmp_for_pre_order_from_relation(self, rel_1, rel_2):
217 |         #compare the start boundary first
218 |         if rel_1.span[0][0] < rel_2.span[0][0]:
219 |             return 1
220 |         #if same, compare the end boundary
221 |         elif rel_1.span[0][0] == rel_2.span[0][0]:
222 |             if rel_1.span[-1][-1] > rel_2.span[-1][-1]:
223 |                 return 1
224 |         return -1
225 |
226 |
227 |
228 | class Relation():
229 |     def 
__init__(self, span, sense, center): 230 | #The span index is start from 0, not 1, different with the xml format 231 | self.span = span 232 | #encoding = unicode, easy to index, need to decode back to utf-8 for printing 233 | self.sense = sense 234 | # 1, 2, or 3 235 | self.center = center 236 | return 237 | #end -------------------------------------------------------------------------------- /data_structure.py: -------------------------------------------------------------------------------- 1 | # coding=UTF-8 2 | # for parsing xml file 3 | import xml.etree.ElementTree as ET 4 | # for dealing with file system 5 | import os 6 | #for argv 7 | import sys 8 | # for save/load dictionary 9 | import pickle 10 | # for copy objects 11 | import copy 12 | from collections import defaultdict 13 | import random 14 | 15 | #whether to only consider segment and discourse unit level relations only when training 16 | seg_du_flag = True 17 | 18 | #need special judge for '——', '。」' condition 19 | _ENDs = (u'?', u'”', u'…', u'—', u'、', u'。', u'」', u'!', u',', u':', u';', u'?') 20 | #word_to_index dictionary 21 | word_to_ix = {} 22 | #for out of vocabulary dictionary index 23 | oov = 0 24 | #file path the dictionary dump 25 | WORD_TO_IX_DUMP_PATH = './word_to_ix' 26 | #center value from xml format to Corpus format 27 | XML_TO_CORPUS_CERTER_DICT = {'1':'former', '2':'latter', '3':'equal'} 28 | 29 | PRE_EDU_SENSE = 'PRE_EDU_SENSE' 30 | EDU_SENSE = 'EDU_SENSE' 31 | NON_CENTER = 'NON_CENTER' 32 | PSEUDO_SENSE = 'PSEUDO_SENSE' 33 | PSEUDO_CENTER = 'PSEUDO_CENTER' 34 | NON_TYPE = u'' 35 | 36 | LABEL_TO_SENSE = {0:'Coordination', 1:'Causality', 2:'Transition', 3:'Explanation'} 37 | LABEL_TO_CENTER = {0:'Front', 1:'Latter', 2:'Equal'} 38 | 39 | SENSE_TO_LABEL = {u'并列关系': 0, u'顺承关系': 0, u'递进关系':0, u'选择关系': 0, u'对比关系':0, u'因果关系': 1, u'推断关系': 1, u'假设关系': 1, u'目的关系': 1, u'条件关系': 1, u'背景关系': 1, u'转折关系': 2, u'让步关系': 2, u'解说关系': 3, u'总分关系': 3, u'例证关系': 3, u'评价关系': 3, EDU_SENSE: 4, PRE_EDU_SENSE: 5, PSEUDO_SENSE: 6} 40 | SENSE_DIM = len(set(SENSE_TO_LABEL.values())) 41 | DU_SENSE_LABEL = tuple( SENSE_TO_LABEL[x] for x in (u'并列关系', u'顺承关系', u'递进关系', u'选择关系', u'对比关系', u'因果关系', u'推断关系', u'假设关系', u'目的关系', u'条件关系', u'背景关系', u'转折关系', u'让步关系', u'解说关系', u'总分关系', u'例证关系', u'评价关系', EDU_SENSE) ) 42 | 43 | 44 | CENTER_TO_LABEL = {'former':0, 'latter':1, 'equal':2, NON_CENTER: 3, PSEUDO_CENTER: 3} 45 | CENTER_DIM = len(set(CENTER_TO_LABEL.values())) 46 | DU_CENTER_LABEL = tuple(CENTER_TO_LABEL[x] for x in ('former', 'latter', 'equal')) 47 | COORD_SENSE_LABEL = SENSE_TO_LABEL[u'并列关系'] 48 | COORD_CENTER_LABEL = CENTER_TO_LABEL['equal'] 49 | 50 | 51 | #for lstm-crf 52 | #extract from corpus, lack ON, IJ when comparing to pos tag guideline, and NR-SHORT, NN-SHORT, NT-SHORT does not appear in the guideline 53 | POS_TO_LABEL34 = {u'SP': 0, u'BA': 1, u'FW': 2, u'DER': 3, u'DEV': 4, u'MSP': 5, u'ETC': 6, u'JJ': 7, u'DT': 8, u'DEC': 9, u'DEG': 10, u'LB': 11, u'LC': 12, u'NN': 13, u'PU': 14, u'NR': 15, u'PN': 16, u'VA': 17, u'VC': 18, u'AD': 19, u'CC': 20, u'VE': 21, u'M': 22, u'CD': 23, u'P': 24, u'AS': 25, u'NR-SHORT': 26, u'VV': 27, u'CS': 28, u'NT': 29, u'OD': 30, u'NN-SHORT': 31, u'SB': 32, u'NT-SHORT': 33} 54 | 55 | #V:0, N:1, LC:2, PN:3, D:4, M:5, AD:6, PP:7, C:8, P:9, O:10 56 | POS_TO_LABEL11 = {u'SP': 9, u'BA': 10, u'FW': 10, u'DER': 9, u'DEV': 9, u'MSP': 9, u'ETC': 9, u'JJ': 10, u'DT': 4, u'DEC': 9, u'DEG': 9, u'LB': 10, u'LC': 2, u'NN': 1, u'PU': 10, u'NR': 1, u'PN': 3, u'VA': 0, u'VC': 0, u'AD': 6, u'CC': 8, u'VE': 0, u'M': 5, u'CD': 4, 
u'P': 7, u'AS': 9, u'NR-SHORT': 1, u'VV': 0, u'CS': 8, u'NT': 1, u'OD': 4, u'NN-SHORT': 1, u'SB': 10, u'NT-SHORT': 1} 57 | 58 | #Verb:0, Noun:1, Conjunction:2, Punctuation:3, other:4 59 | POS_TO_LABEL5 = {u'SP': 4, u'BA': 4, u'FW': 4, u'DER': 4, u'DEV': 4, u'MSP': 4, u'ETC': 4, u'JJ': 4, u'DT': 4, u'DEC': 4, u'DEG': 4, u'LB': 4, u'LC': 4, u'NN': 1, u'PU': 3, u'NR': 1, u'PN': 1, u'VA': 0, u'VC': 0, u'AD': 4, u'CC': 2, u'VE': 0, u'M': 4, u'CD': 4, u'P': 4, u'AS': 4, u'NR-SHORT': 1, u'VV': 0, u'CS': 2, u'NT': 1, u'OD': 4, u'NN-SHORT': 1, u'SB': 4, u'NT-SHORT': 1} 60 | 61 | POS_TO_LABEL1 = {u'SP': 0, u'BA': 0, u'FW': 0, u'DER': 0, u'DEV': 0, u'MSP': 0, u'ETC': 0, u'JJ': 0, u'DT': 0, u'DEC': 0, u'DEG': 0, u'LB': 0, u'LC': 0, u'NN': 0, u'PU': 0, u'NR': 0, u'PN': 0, u'VA': 0, u'VC': 0, u'AD': 0, u'CC': 0, u'VE': 0, u'M': 0, u'CD': 0, u'P': 0, u'AS': 0, u'NR-SHORT': 0, u'VV': 0, u'CS': 0, u'NT': 0, u'OD': 0, u'NN-SHORT': 0, u'SB': 0, u'NT-SHORT': 0} 62 | 63 | WORD_SIMPLE = u'S' 64 | WORD_BEGIN = u'B' 65 | WORD_MIDDLE = u'M' 66 | WORD_END = u'E' 67 | 68 | WORD_TAG_TO_LABEL = {WORD_SIMPLE:0, WORD_BEGIN:1, WORD_MIDDLE:2, WORD_END:3} 69 | 70 | POS_TO_LABEL = POS_TO_LABEL1 71 | 72 | pos_label_n = max( POS_TO_LABEL.values() )+1 73 | word_tag_label_n = max( WORD_TAG_TO_LABEL.values() )+1 74 | SEQ_TAG_TO_LABEL = {} 75 | #each element is a tuple (word_tag, pos) in which pos is one of pos tag of the label value 76 | #use defaultdict for the condition that the lstm-crf model returns a START or STOP tag 77 | LABEL_TO_WORD_INF = defaultdict(tuple) 78 | for word_tagk, word_tagv in WORD_TAG_TO_LABEL.iteritems(): 79 | for posk, posv in POS_TO_LABEL.iteritems(): 80 | label = word_tagv*pos_label_n+posv 81 | SEQ_TAG_TO_LABEL[word_tagk+u'-'+posk] = label 82 | LABEL_TO_WORD_INF[label] = (word_tagk, posk) 83 | 84 | #print SEQ_TAG_TO_LABEL 85 | #print LABEL_TO_WORD_INF 86 | 87 | STRUCT_LABEL_TRUE = 1 88 | STRUCT_LABEL_FALSE = 0 89 | STRUCT_LABEL_DIM = 2 90 | 91 | 92 | 93 | 94 | ''' 95 | define the Corpus class which contains information of a paragraph: text, segments index, relations and the merge structure 96 | 97 | merge structure: words merge to segments, segments merge to EDU, EDUS merge to bigger discourse unit 98 | 99 | ''' 100 | class Corpus(): 101 | def __init__(self): 102 | #encoding = unicode, easy to index, need to decode back to utf-8 for printing 103 | self.text = '' 104 | #each element is a Word object, some characters merge to words 105 | self.words = [] 106 | #each element is a [start_idx, end_idx] list, a text is devided to segments by given punctuations 107 | self.segment_spans = [] 108 | #each element is a [start_idx, end_idx] list, some segments merge to a EDU 109 | self.edu_spans = [] 110 | #relations of discouse unit level, pre order 111 | self.du_relations = [] 112 | #relations of segment level, pre order 113 | self.seg_relations = [] 114 | #relations of character level, pre order 115 | self.w_relations = [] 116 | #filename+'-'+paragraph_count ex: 001.xml-1, count from 1 117 | self.id = '' 118 | #check whether the edu span and segment span are coodinated 119 | def span_certification(self): 120 | seg_idx = 0 121 | passed = True 122 | ends = [] 123 | for e_span in self.edu_spans: 124 | start = e_span[0] 125 | end = e_span[-1] 126 | while start != self.segment_spans[seg_idx][0]: 127 | ends.append(self.segment_spans[seg_idx][-1]) 128 | seg_idx += 1 129 | if seg_idx == len(self.segment_spans): 130 | print 'certification not passed(start):', self.id, start, self.edu_spans 131 | passed = False 132 | 
break
133 |
134 |             while end != self.segment_spans[seg_idx][-1]:
135 |                 ends.append(self.segment_spans[seg_idx][-1])
136 |                 seg_idx += 1
137 |
138 |                 if seg_idx == len(self.segment_spans):
139 |                     print 'certification not passed(end):', self.id, end, self.edu_spans
140 |                     print ends
141 |                     passed = False
142 |                     break
143 |             if not passed:
144 |                 break
145 |         return
146 |
147 | class Word():
148 |     def __init__(self, span, pos):
149 |         self.span = span
150 |         self.pos = pos
151 |
152 | def word_to_labels(word):
153 |     char_n = word.span[-1] - word.span[0] + 1
154 |     labels = []
155 |     word_tags = []
156 |     if char_n == 1:
157 |         word_tags.append(WORD_SIMPLE)
158 |     else:
159 |         word_tags.append(WORD_BEGIN)
160 |         for i in range(char_n-2):
161 |             word_tags.append(WORD_MIDDLE)
162 |         word_tags.append(WORD_END)
163 |     for word_tag in word_tags:
164 |         key = word_tag+u'-'+word.pos
165 |         labels.append(SEQ_TAG_TO_LABEL[key])
166 |     return labels
167 |
168 |
169 | def labels_to_words_in_test_instance(labels, instance):
170 |     words = []
171 |     start_idx = instance.segment_spans[0][0]
172 |     end_idx = instance.segment_spans[-1][-1]
173 |     char_idx = start_idx
174 |     word_start_idx = char_idx
175 |     for label in labels:
176 |         tag_tuple = LABEL_TO_WORD_INF[label]
177 |         word_tag = tag_tuple[0]
178 |         pos = tag_tuple[1]
179 |         if word_tag == WORD_BEGIN:
180 |             word_start_idx = char_idx
181 |         elif word_tag == WORD_MIDDLE:
182 |             pass
183 |         elif word_tag == WORD_END:
184 |             span = [word_start_idx, char_idx]
185 |             word = Word(span, pos)
186 |             words.append(word)
187 |         elif word_tag == WORD_SIMPLE:
188 |             span = [char_idx, char_idx]
189 |             words.append(Word(span, pos))
190 |         else:
191 |             print 'word tag exception'
192 |             print_test_instance(instance)
193 |         char_idx += 1
194 |
195 |     instance.words = words
196 |
197 |
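#illustrative sketch (not part of the original module): what word_to_labels
#returns for a three-character word; with POS_TO_LABEL = POS_TO_LABEL1 every
#pos maps to 0, so the labels reduce to the word-tag labels [B, M, E] = [1, 2, 3]
def _demo_word_to_labels():
    word = Word([3, 5], u'NN')
    return word_to_labels(word)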
198 | class Relation():
199 |     def __init__(self, spans, sense, center, type):
200 |         #span indices start from 0, not 1, unlike the xml format, ex: [[0,5],[6,10],[11,13]]
201 |         self.spans = spans
202 |         #encoding = unicode, easy to index, needs to be decoded back to utf-8 for printing
203 |         self.sense = sense
204 |         #unicode, 'former', 'latter', 'equal', different from the xml format
205 |         self.center = center
206 |         #connective type, utf-8, explicit or implicit
207 |         self.type = type
208 |         return
209 |
210 | '''used to store the syntactic parsing information of a segment'''
211 | class ParseSeg():
212 |     def __init__(self, text, relations, words, id):
213 |         #encoding = unicode, easy to index, needs to be decoded back to utf-8 for printing
214 |         self.text = text
215 |         #pre order
216 |         self.relations = relations
217 |         #filename+'-'+segment_count ex: chtb_0001.nw.new-1, count from 1
218 |         self.id = id
219 |         #each element is a Word object; some characters merge into words
220 |         self.words = words
221 |
222 | #instance for the nn model
223 | class Instance(object):
224 |     def __init__(self):
225 |         #words initially; they merge with each other during the construction process
226 |         self.fragments = []
227 |         #word spans initially, merging with each other during the construction process
228 |         self.fragment_spans = []
229 |         #post order
230 |         self.i_relations = []
231 |         #each element is a Word object list for a segment; some characters merge into words
232 |         self.words_list = []
233 |
234 | #training instance for the nn model
235 | class TrainInstance(Instance):
236 |     def __init__(self):
237 |         super(TrainInstance, self).__init__()
238 |         #STRUCT_LABEL_TRUE or STRUCT_LABEL_FALSE
239 |         self.label = ''
240 |         #True or False, whether the instance is above segment level in the corpus
241 |         self.segment_flag = ''
242 |         #for processing segments with the lstm if needed
243 |         self.segment_spans = []
244 |         #source corpus id + '-train' + '-[produced order number]' + '-[instance number after sampling]'
245 |         #ex: 001.xml-1-train-5-2
246 |         self.id = ''
247 |
248 | #test instance for the nn model
249 | class TestInstance(Instance):
250 |     def __init__(self):
251 |         super(TestInstance, self).__init__()
252 |         self.segment_spans = []
253 |         self.edu_spans = []
254 |         self.du_i_relations = []
255 |         self.seg_i_relations = []
256 |         #list of lists: word level relations for each segment
257 |         self.w_i_relations_list = []
258 |         #ex: [u',', u'。'], the end punctuation of each segment
259 |         self.puncs = []
260 |         #source corpus id + '-test' + '-[produced order number]' + '-[instance number after sampling]'
261 |         #ex: 001.xml-1-test-5-2
262 |         self.id = ''
263 |
264 | #relation inside an Instance
265 | class I_Relation():
266 |     def __init__(self, spans, sense, center, type):
267 |         # ex: [[0, 46], [47, 76]]
268 |         self.spans = spans
269 |         #Variable int
270 |         self.sense = sense
271 |         #Variable int
272 |         self.center = center
273 |         #connective type, utf-8, explicit or implicit
274 |         self.type = type
275 |
276 | #read xml files from the given directory and build a Corpus list
277 | def xml_dir_to_corpus_list(xml_dir):
278 |     corpus_list = []
279 |
280 |     #for parsing an xml file, get the parsed tree root
281 |     #tutorial: http://t.cn/RAskmKa
282 |     utf8_parser = ET.XMLParser(encoding='utf-8')
283 |
284 |     for filename in os.listdir(xml_dir):
285 |         if os.path.splitext(filename)[-1] == '.xml':
286 |             file_path = os.path.join(xml_dir,filename)
287 |             #get the parsing tree of the xml file
288 |             tree = ET.parse(file_path)
289 |             root = tree.getroot()
290 |             filename = os.path.basename(file_path)
291 |             #for the corpus id
292 |             paragraph_count=1
293 |             for p in root.iter('P'):
294 |                 rows = p.findall('R')
295 |                 # some paragraphs may be empty!
296 |                 if rows == []:
297 |                     continue
298 |                 corp = Corpus()
299 |                 r = rows[0]
300 |                 corp.text = r.get('Sentence').replace(u'|', '')
301 |
302 |                 corp.segment_spans = text_to_segment_spans(corp.text)
303 |                 #get the relation object list from the xml rows
304 |                 #the xml rows list the relations in top-down, left-first style
305 |                 corp.du_relations = xml_rows_to_ralations(rows)
306 |                 #need to ensure the pre order of the relations
307 |                 #some paragraphs from the xml are indeed not in pre order!
308 |                 corp.du_relations = relations_to_pre_order(corp.du_relations)
309 |                 #get the EDU spans by just checking the boundary indices in the relations
310 |                 corp.edu_spans = relations_to_edu_spans(corp.du_relations)
311 |
312 |                 corp.id = filename+'-'+str(paragraph_count)
313 |                 paragraph_count += 1
314 |                 #print_corpus(corp)
315 |                 #check whether the edu spans and segment spans are coordinated
316 |                 corp.span_certification()
317 |
318 |                 corpus_list.append(corp)
319 |     return corpus_list
320 |
321 | # a parseseg_dict is indexed by file id number.
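# e.g. parseseg_dict[1] holds the ParseSeg objects of chtb_0001.nw in segment
# order, so one file can be walked like this (a sketch):
#     for parseseg in parseseg_dict[1]:
#         print_parseseg(parseseg)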
322 | def parse_dir_to_parseseg_dict(parse_dir): 323 | #for syntactic tree 324 | from nltk.tree import Tree 325 | utf8_parser = ET.XMLParser(encoding='utf-8') 326 | #for each file in the directory, use file_id to index 327 | parseseg_dict = {} 328 | for filename in os.listdir(parse_dir): 329 | #make sure the files are what we want 330 | #print filename 331 | if filename[-3:] == '.nw': 332 | #for parseseg id 333 | file_id = int(filename.split('.')[0].split('_')[-1]) 334 | parseseg_dict[file_id] = [] 335 | file_path = os.path.join(parse_dir,filename) 336 | et_tree = ET.parse( file_path) 337 | root = et_tree.getroot() 338 | #for id, count from 1 339 | segment_count = 1 340 | for sentence in root.iter('S'): 341 | p_tree = Tree.fromstring(sentence.text) 342 | 343 | #return relations, word_count, text 344 | #make relations of a parsing tree in postfixed order 345 | relations, _, words, text = tree_to_relations_and_words(p_tree, 0) 346 | spans = text_to_segment_spans(text) 347 | #split relations to several part according to the spans 348 | splitted_relations = split_relations_by_spans(relations, spans) 349 | word_idx = 0 350 | span_idx = 0 351 | for relations in splitted_relations: 352 | word_start_idx = word_idx 353 | while words[word_idx].span[-1] <= spans[span_idx][-1]: 354 | word_idx += 1 355 | if word_idx >= len(words): 356 | break 357 | seg_words = words[word_start_idx:word_idx] 358 | t = text[ spans[span_idx][0] : spans[span_idx][-1]+1 ] 359 | rs = relations_to_pre_order(relations) 360 | id = filename+'-'+str(segment_count) 361 | segment_count += 1 362 | span_idx += 1 363 | parseseg = ParseSeg(t, rs, seg_words, id) 364 | parseseg_dict[file_id].append(parseseg) 365 | return parseseg_dict 366 | 367 | def text_to_test_instance(text): 368 | instance = TestInstance() 369 | instance.segment_spans = text_to_segment_spans(text) 370 | instance.fragments = [text[i] for i in range(len(text))] 371 | instance.fragment_spans = [[i,i] for i in range(len(text))] 372 | for span in instance.segment_spans: 373 | instance.puncs.append( text[span[-1]] ) 374 | return instance 375 | 376 | #the instance.id of output instances is the same, ex: ex: 001.xml-1-train-5-5 377 | def corpus_to_train_instance_list(corpus): 378 | # for finding pairs, we only store the span, not the whole Relation object 379 | #after find all neighboring pairs, we recover the relation structure and make instances 380 | tr_instances = [] 381 | #use a copy of the corpus to convert to binary structure, not the original object 382 | corp = copy.deepcopy(corpus) 383 | 384 | relations = corp.du_relations 385 | relations.extend(corp.seg_relations) 386 | relations.extend(corp.w_relations) 387 | 388 | #make relations binary preorder 389 | relations = relations_to_binary_preorder(relations) 390 | #the key of the span_to_relation_structure_dict is a span, ex: [ [0,5],[6,10] ] 391 | #the value is the pre order relations when see the key span as root 392 | spans_to_relation_structure_dict = get_spans_to_relation_structure_dict(relations) 393 | #spans information of words level under each segments 394 | 395 | 396 | w_spans_list_list = [] 397 | for i in range(len(corp.segment_spans)): 398 | w_spans_list_list.append([]) 399 | #spans information above segment level 400 | seg_du_spans_list = [] 401 | #for w_spans_list_list, make use of pre order 402 | seg_idx = 0 403 | 404 | for relation in relations: 405 | while relation.spans[0][0] > corp.segment_spans[seg_idx][-1]: 406 | seg_idx += 1 407 | 408 | # condition that the span is above segment level 409 | if 
relation.spans[0][0] == corp.segment_spans[seg_idx][0] and relation.spans[-1][-1] >= corp.segment_spans[seg_idx][-1]:
410 |             seg_du_spans_list.append(relation.spans)
411 |         # condition that the span is below segment level
412 |         if relation.spans[-1][-1] <= corp.segment_spans[seg_idx][-1]:
413 |             w_spans_list_list[seg_idx].append(relation.spans)
414 |     #add each word span
415 |     for seg_idx in range(len(corp.segment_spans)):
416 |         w_spans_list_list[seg_idx] = [ [[i,i]] for i in range(corp.segment_spans[seg_idx][0], corp.segment_spans[seg_idx][1]+1) ] + w_spans_list_list[seg_idx]
417 |
418 |     #we separate seg_du_spans_list and w_spans_list_list to avoid pseudo relation construction across the word level and the segment/du level
419 |     #we need seg_du_spans_list to be the first element of spans_list_list in order to use segment_flag
420 |     spans_list_list = [seg_du_spans_list]
421 |     if not seg_du_flag:
422 |         spans_list_list.extend(w_spans_list_list)
423 |     #for instance.id, count from 1
424 |     instance_count = 1
425 |     segment_flag = True
426 |     for spans_list in spans_list_list:
427 |         spans_list = spans_list_to_pre_order(spans_list)
428 |         #print spans_list
429 |         for i in range(len(spans_list)):
430 |             now_spans = spans_list[i]
431 |             #thanks to the pre order, we only need to consider the spans after now_spans
432 |             for j in range(i, len(spans_list)):
433 |                 # if the span is immediately right-adjacent to now_spans, we can build an instance
434 |                 if spans_list[j][0][0] == now_spans[-1][-1]+1:
435 |                     recovered_relations = recover_relations_from_spans_pair(now_spans, spans_list[j], spans_to_relation_structure_dict)
436 |                     tr_instance = get_train_instance_from_corpus_and_relations(corp, recovered_relations)
437 |                     #to post order, for constructing the structure when training
438 |                     tr_instance.i_relations = relations_to_post_order(tr_instance.i_relations)
439 |                     tr_instance.segment_flag = segment_flag
440 |                     #set id
441 |                     tr_instance.id = corp.id+'-train'+'-'+str(instance_count)+'-'+str(instance_count)
442 |                     tr_instances.append(tr_instance)
443 |                     instance_count += 1
444 |         segment_flag = False
445 |     return tr_instances
446 |
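#illustrative sketch (not part of the original module): pre-order sorting of
#spans lists puts outer spans first, e.g.
#    [[[5, 9]], [[0, 4]], [[0, 9]]] -> [[[0, 9]], [[0, 4]], [[5, 9]]]
def _demo_spans_list_to_pre_order():
    return spans_list_to_pre_order([[[5, 9]], [[0, 4]], [[0, 9]]])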
447 | def corpus_to_test_instance(corpus, binary=True):
448 |     corp = copy.deepcopy(corpus)
449 |     te_instance = TestInstance()
450 |
451 |     #make the relations binary pre-order and put them into the instance
452 |     if binary:
453 |         corp.du_relations = relations_to_binary_preorder(corp.du_relations)
454 |         corp.seg_relations = relations_to_binary_preorder(corp.seg_relations)
455 |         corp.w_relations = relations_to_binary_preorder(corp.w_relations)
456 |
457 |     for relation in corp.du_relations:
458 |         i_relation = I_Relation(relation.spans, SENSE_TO_LABEL[relation.sense], CENTER_TO_LABEL[relation.center], relation.type)
459 |         te_instance.du_i_relations.append(i_relation)
460 |     for relation in corp.seg_relations:
461 |         i_relation = I_Relation(relation.spans, SENSE_TO_LABEL[relation.sense], CENTER_TO_LABEL[relation.center], relation.type)
462 |         te_instance.seg_i_relations.append(i_relation)
463 |     #separate the w_i_relations of each segment
464 |     segment_count = 0
465 |     te_instance.w_i_relations_list.append([])
466 |     for relation in corp.w_relations:
467 |         i_relation = I_Relation(relation.spans, SENSE_TO_LABEL[relation.sense], CENTER_TO_LABEL[relation.center], relation.type)
468 |         if i_relation.spans[0][0] > corp.segment_spans[segment_count][-1]:
469 |             segment_count += 1
470 |             te_instance.w_i_relations_list.append([])
471 |         te_instance.w_i_relations_list[-1].append(i_relation)
472 |
473 |
474 |     te_instance.segment_spans = corp.segment_spans
475 |     te_instance.edu_spans = corp.edu_spans
476 |     te_instance.fragments = [corp.text[i] for i in range(len(corp.text))]
477 |     te_instance.fragment_spans = [[i,i] for i in range(len(corp.text))]
478 |     for span in corp.segment_spans:
479 |         te_instance.puncs.append( corp.text[span[-1]] )
480 |     get_words_list_in_instance_from_corpus(te_instance, corp)
481 |     te_instance.id = corp.id + '-test' + '-1-1'
482 |
483 |     return te_instance
484 |
485 | #the key of spans_to_relation_structure_dict is a spans list, ex: [ [0,5],[6,10] ]
486 | #the value is the pre-order relation list obtained when the key spans are seen as root
487 | def get_spans_to_relation_structure_dict(relations):
488 |
489 |     #defaults to [] when an edu span is used as a key later
490 |     spans_to_relation_structure_dict = defaultdict(list)
491 |
492 |
493 |     #build spans_to_relation_structure_dict
494 |     for i in range(len(relations)):
495 |         now_spans = relations[i].spans
496 |         #thanks to the pre order, we only need to consider the spans at 'now' and after now_spans
497 |         for j in range(i, len(relations)):
498 |             # if the span falls outside now_spans
499 |             if relations[j].spans[0][0] < now_spans[0][0] or relations[j].spans[-1][-1] > now_spans[-1][-1]: break
500 |             #use str() to make the list hashable
501 |             spans_to_relation_structure_dict[str(now_spans)].append( relations[j] )
502 |
503 |     return spans_to_relation_structure_dict
504 |
505 | def recover_relations_from_spans_pair(spans_left, spans_right, spans_to_relation_structure_dict):
506 |
507 |     #get the start boundary and end boundary of spans_left and spans_right
508 |     span_unit_left = [ spans_left[0][0], spans_left[-1][-1] ]
509 |     span_unit_right = [ spans_right[0][0], spans_right[-1][-1] ]
510 |
511 |     root_span = [ span_unit_left, span_unit_right ]
512 |
513 |     relation_list = spans_to_relation_structure_dict[str(root_span)]
514 |
515 |     # if the root span does not exist in the dict, the relation is not part of the original structure. We construct a pseudo root Relation whose sense and center are PSEUDO_SENSE and PSEUDO_CENTER
516 |     if relation_list == []:
517 |         relation_list = spans_to_relation_structure_dict[str(spans_left)] + spans_to_relation_structure_dict[str(spans_right)]
518 |         root_relation = Relation(root_span, PSEUDO_SENSE, PSEUDO_CENTER, NON_TYPE)
519 |         relation_list = [root_relation] + relation_list
520 |
521 |     return relation_list
522 |
523 | def get_train_instance_from_corpus_and_relations(corpus, relations):
524 |
525 |     tr_instance = TrainInstance()
526 |
527 |     i_relations = []
528 |     for relation in relations:
529 |         i_relation = I_Relation(relation.spans, SENSE_TO_LABEL[relation.sense], CENTER_TO_LABEL[relation.center], relation.type)
530 |         i_relations.append(i_relation)
531 |     tr_instance.i_relations = i_relations
532 |
533 |     if i_relations[0].sense == SENSE_TO_LABEL[PSEUDO_SENSE]:
534 |         tr_instance.label = STRUCT_LABEL_FALSE
535 |     else:
536 |         tr_instance.label = STRUCT_LABEL_TRUE
537 |
538 |     start_idx = i_relations[0].spans[0][0]
539 |     end_idx = i_relations[0].spans[-1][-1]
540 |
541 |     tr_instance.fragment_spans = [[idx,idx] for idx in range(start_idx,end_idx+1)]
542 |     tr_instance.fragments = [corpus.text[idx] for idx in range(start_idx,end_idx+1)]
543 |     #for the segment spans
544 |     tr_instance.segment_spans = []
545 |     for span in corpus.segment_spans:
546 |         if span[0] < start_idx:
547 |             continue
548 |         if span[-1] > end_idx:
549 |             break
550 |         tr_instance.segment_spans.append(span)
551 |     #for the words
552 |     get_words_list_in_instance_from_corpus(tr_instance, corpus)
553 |
554 |
555 |
556 |     return tr_instance
557 |
558 | def 
test_instance_and_text_to_dict(test_instance,text): 559 | dict = relations_and_text_to_dict(test_instance.du_i_relations, text) 560 | return dict 561 | 562 | def relations_and_text_to_dict(relations,text): 563 | dict = {} 564 | now_relation = relations[0] 565 | dict['sense'] = LABEL_TO_SENSE[now_relation.sense] 566 | dict['center'] = LABEL_TO_CENTER[now_relation.center] 567 | dict['args'] = [] 568 | arg_count = 0 569 | rel_idx = 1 570 | start_idx = 1 571 | for now_span in now_relation.spans: 572 | while rel_idx < len(relations) and relations[rel_idx].spans[0][0] <= now_span[-1]: 573 | rel_idx += 1 574 | #detect a leaf node 575 | if start_idx == rel_idx: 576 | dict['args'].append(text[now_span[0]:now_span[-1]+1].encode('utf-8')) 577 | continue 578 | else: 579 | dict['args'].append(relations_and_text_to_dict(relations[start_idx:rel_idx],text)) 580 | start_idx = rel_idx 581 | return dict 582 | 583 | 584 | 585 | def get_words_list_in_instance_from_corpus(instance, corpus): 586 | instance.words_list = [] 587 | word_i = 0 588 | for span in instance.segment_spans: 589 | words = [] 590 | word = corpus.words[word_i] 591 | seg_start_idx = span[0] 592 | seg_end_idx = span[-1] 593 | #we need to consider the cross boundary word like '「熬」到 ' 594 | while word.span[0] <= seg_end_idx: 595 | if word.span[0] < seg_start_idx: 596 | if word.span[-1] >= seg_start_idx: 597 | #print corpus.id 598 | #print_word(word, corpus.text) 599 | new_word = Word([seg_start_idx, word.span[-1]],word.pos) 600 | #print_word(new_word, corpus.text) 601 | words.append(new_word) 602 | word_i += 1 603 | elif word.span[-1] > seg_end_idx: 604 | if word.span[0] <= seg_end_idx: 605 | #print corpus.id 606 | #print_word(word, corpus.text) 607 | new_word = Word([word.span[0], seg_end_idx],word.pos) 608 | #print_word(new_word, corpus.text) 609 | words.append(new_word) 610 | #we don't do word_i += 1 because the same word might be used in latter segment 611 | break 612 | else: 613 | words.append(word) 614 | word_i += 1 615 | 616 | if word_i >= len(corpus.words): 617 | break 618 | word = corpus.words[word_i] 619 | 620 | instance.words_list.append(words) 621 | if word_i >= len(corpus.words): 622 | break 623 | 624 | # add segment level relation 625 | def add_corpus_seg_relations(corpus): 626 | seg_relations = [] 627 | seg_i = 0 628 | for edu_span in corpus.edu_spans: 629 | #initial blank left node 630 | left_span = [] 631 | #while still in the edu 632 | while corpus.segment_spans[seg_i][-1] <= edu_span[-1]: 633 | if left_span == []: 634 | left_span = corpus.segment_spans[seg_i] 635 | else: 636 | right_span = corpus.segment_spans[seg_i] 637 | # a edu is finished 638 | if right_span[-1] == edu_span[-1]: 639 | sense = EDU_SENSE 640 | else: 641 | sense = PRE_EDU_SENSE 642 | relation = Relation([left_span,right_span], sense, NON_CENTER, NON_TYPE) 643 | seg_relations.append(relation) 644 | # make a new left node span 645 | left_span = [ left_span[0], right_span[-1] ] 646 | # if it is the last segment 647 | if seg_i + 1 == len(corpus.segment_spans): 648 | break 649 | 650 | seg_i += 1 651 | corpus.seg_relations = seg_relations 652 | return corpus 653 | 654 | #merge the parseseg relations of each segment to a Corpus in the same file id number 655 | def merge_parseseg_dict_to_corpus_list(corpus_list, parseseg_dict): 656 | # start from a non existing file id 657 | file_id = 0 658 | parseseg_idx = 0 659 | not_found = 0 660 | 661 | for corpus in corpus_list: 662 | #print corpus.id 663 | now_file_id = int(corpus.id.split('.')[0].split('-')[-1]) 664 | if 
now_file_id != file_id: 665 | file_id = now_file_id 666 | parseseg_idx = 0 667 | parseseg_list = parseseg_dict[file_id] 668 | for span_idx in range(len(corpus.segment_spans)): 669 | seg_spans = corpus.segment_spans[span_idx] 670 | seg_text = corpus.text[seg_spans[0] : seg_spans[-1]+1] 671 | parseseg = parseseg_list[parseseg_idx] 672 | while seg_text != parseseg.text: 673 | #print seg_text.encode('utf-8') 674 | #print parseseg.text.encode('utf-8') 675 | #print '\n' 676 | parseseg_idx += 1 677 | #assuming we can find a match before out of boundary 678 | if parseseg_idx >= len(parseseg_list): 679 | print corpus.id 680 | print seg_text.encode('utf-8') 681 | for pseg in parseseg_list: 682 | print pseg.text.encode('utf-8') 683 | not_found += 1 684 | parseseg_idx = 0 685 | break 686 | parseseg = parseseg_list[parseseg_idx] 687 | 688 | #modify parseg span to corpus segment span 689 | #print_parseseg(parseseg) 690 | modify_parseg_span_to_segment_spans(parseseg, seg_spans) 691 | corpus.words.extend(parseseg.words) 692 | corpus.w_relations.extend(parseseg.relations) 693 | #break 694 | 695 | #for set the EDU_SENSE 696 | edu_idx = 0 697 | #there may be many relations suit one EDU since it's not binary 698 | for relation in corpus.w_relations: 699 | if relation.spans[0][0] > corpus.edu_spans[edu_idx][-1]: 700 | edu_idx += 1 701 | if relation.spans[0][0] == corpus.edu_spans[edu_idx][0] and relation.spans[-1][-1] == corpus.edu_spans[edu_idx][-1]: 702 | relation.sense = EDU_SENSE 703 | if edu_idx == len(corpus.edu_spans): 704 | break 705 | print 'not_found count:', not_found 706 | return corpus_list 707 | 708 | def relations_binary_to_multi_preorder(relations): 709 | #reorganize to pre-order 710 | relations = relations_to_pre_order(relations) 711 | new_relations = [] 712 | r_idx = 0 713 | #since relations is preorder, we can find left child(which is not EDU) by checking next relation 714 | if len(relations) > 0: 715 | new_relations.append(relations[0]) 716 | while r_idx < len(relations)-1: 717 | r = new_relations.pop() 718 | next_r = relations[r_idx+1] 719 | #the relation now is a coordination and equal relation 720 | if r.sense == COORD_SENSE_LABEL and r.center == COORD_CENTER_LABEL: 721 | print 'find' 722 | print_i_relation(r) 723 | print_i_relation(next_r) 724 | #the next relation is the left child of now relation a coordination and equal relation 725 | if next_r.spans[0][0] == r.spans[0][0] and next_r.sense == COORD_SENSE_LABEL and next_r.center == COORD_CENTER_LABEL: 726 | r.spans = next_r.spans+r.spans[1:] 727 | r_idx += 1 728 | new_relations.append(r) 729 | continue 730 | new_relations.append(r) 731 | new_relations.append(next_r) 732 | r_idx += 1 733 | 734 | return new_relations 735 | 736 | def relations_to_binary_preorder(relations): 737 | new_added_relations = [] 738 | to_be_removed_relations = [] 739 | for relation in relations: 740 | if len(relation.spans) > 2: 741 | #to delete the original multi-children relation in the future 742 | to_be_removed_relations.append(relation) 743 | #ex: [0,3] 744 | left_span_unit = relation.spans[0] 745 | for idx in range(1,len(relation.spans)): 746 | right_span_unit = relation.spans[idx] 747 | new_spans = [left_span_unit, right_span_unit] 748 | new_relation = Relation(new_spans, relation.sense, relation.center, relation.type) 749 | new_added_relations.append(new_relation) 750 | left_span_unit = [ new_spans[0][0], new_spans[1][1] ] 751 | elif len(relation.spans) == 1: 752 | to_be_removed_relations.append(relation) 753 | for r in to_be_removed_relations: 754 | 
relations.remove(r)
755 |     for r in new_added_relations:
756 |         relations.append(r)
757 |
758 |     #reorganize into pre order
759 |     relations = relations_to_pre_order(relations)
760 |
761 |     return relations
762 |
763 |
764 | def build_word_to_ix_from_corpus_list(corpus_list):
765 |     global oov
766 |     for corpus in corpus_list:
767 |         for word in corpus.text:
768 |             if word not in word_to_ix:
769 |                 word_to_ix[word] = len(word_to_ix)
770 |     #index used for out-of-vocabulary words
771 |     oov = len(word_to_ix)
772 |     word_to_ix['oov'] = len(word_to_ix) #for oov at test time
773 |
774 | #dump the dictionary to a file
775 | def save_word_to_ix():
776 |     global word_to_ix
777 |     with open(WORD_TO_IX_DUMP_PATH, "wb") as myFile:
778 |         pickle.dump(word_to_ix, myFile)
779 |
780 | def load_word_to_ix():
781 |     global word_to_ix
782 |     with open(WORD_TO_IX_DUMP_PATH, "rb") as myFile:
783 |         word_to_ix = pickle.load(myFile)
784 |
785 | #to handle the '——' and '。」' special cases
786 | def is_punc_in_text(text, idx):
787 |     if text[idx] in _ENDs:
788 |         if idx+1 < len(text):
789 |             if text[idx+1] in _ENDs:
790 |                 if text[idx] == u'—' or text[idx+1] != u'—':
791 |                     return False
792 |             if text[idx] == u'—' and idx > 1:
793 |                 if text[idx-1] != u'—':
794 |                     return False
795 |         return True
796 |     else:
797 |         return False
798 |
799 |
800 | #get the relation object list from the xml rows
801 | def xml_rows_to_ralations(xml_rows):
802 |     relations = []
803 |
804 |     for r in xml_rows:
805 |         sense = r.get('RelationType')
806 |         center = r.get('Center')
807 |         #get the span
808 |         spans = []
809 |         #sentence_position ex: u'27…51|52…63|64…78|79…89'
810 |         sentence_position = r.get('SentencePosition')
811 |         positions = sentence_position.split(u'|')
812 |         for p in positions:
813 |             ps = p.split(u'…')
814 |             #span indices start from 0, not 1, unlike the xml format
815 |             spans.append( [ int(ps[0])-1 ,int(ps[-1])-1 ] )
816 |         relation = Relation(spans, sense, XML_TO_CORPUS_CERTER_DICT[center], r.get('ConnectiveType'))
817 |         relations.append(relation)
818 |
819 |     return relations
820 |
821 | #get the EDU spans by just checking the boundary indices in the relations
822 | def relations_to_edu_spans(relations):
823 |     edu_spans = []
824 |     #this set collects every boundary index
825 |     boundary_idx_set = set()
826 |     for relation in relations:
827 |         #span example: [ [0, 21], [22, 76], [77, 129] ]
828 |         for span_unit in relation.spans:
829 |             boundary_idx_set.add(span_unit[0])
830 |             boundary_idx_set.add(span_unit[1])
831 |
832 |     #sorted from left to right
833 |     boundary_idx = sorted(boundary_idx_set)
834 |
835 |     for i in range(len(boundary_idx)-1):
836 |         #every 2 consecutive indices form a [start, end] pair
837 |         if i%2 == 0:
838 |             edu_spans.append([ boundary_idx[i], boundary_idx[i+1] ])
839 |
840 |     return edu_spans
841 |
842 | #get segment spans; punctuation marks split the text into segments
843 | def text_to_segment_spans(text):
844 |     start = 0
845 |     end = 0
846 |     segment_spans = []
847 |
848 |     for idx in range(len(text)):
849 |         if is_punc_in_text(text, idx):
850 |             end = idx
851 |             segment_spans.append([start,end])
852 |             start = end+1
853 |
854 |     #some corpora have no punctuation at the end!!
855 |     if start != len(text):
856 |         segment_spans.append([start,len(text)-1])
857 |
858 |     return segment_spans
859 |
860 |
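#illustrative sketch (not part of the original module): a two-segment example
def _demo_text_to_segment_spans():
    #u'天气很好,我们去公园。' -> [[0, 4], [5, 10]]
    return text_to_segment_spans(u'天气很好,我们去公园。')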
861 | def modify_parseg_span_to_segment_spans(parseseg, seg_spans):
862 |     shift_len = seg_spans[0] - parseseg.relations[0].spans[0][0]
863 |     #print 'shift_len:', shift_len
864 |     for relation in parseseg.relations:
865 |         new_spans = []
866 |         for sp in relation.spans:
867 |             new_spans.append( [sp[0]+shift_len, sp[-1]+shift_len] )
868 |         relation.spans = new_spans
869 |     for word in parseseg.words:
870 |         new_span = [ word.span[0]+shift_len, word.span[-1]+shift_len ]
871 |         word.span = new_span
872 |
873 |     return
874 |
875 |
876 | #split the relations into several parts according to the spans
877 | def split_relations_by_spans(relations, spans):
878 |     splitted_relations = []
879 |     for i in range(len(spans)):
880 |         splitted_relations.append([])
881 |     #the simplest condition
882 |     if len(spans) == 1:
883 |         splitted_relations[0].extend(relations)
884 |         return splitted_relations
885 |     else:
886 |         span_idx = 0
887 |         for relation in relations:
888 |             r_spans = relation.spans
889 |             now_span = spans[span_idx]
890 |             # make now_span catch up with the beginning of r_spans
891 |             while r_spans[0][0] > now_span[-1]:
892 |                 span_idx += 1
893 |                 now_span = spans[span_idx]
894 |             # if now_span fully covers r_spans
895 |             if r_spans[-1][-1] <= now_span[-1] and r_spans[0][0] >= now_span[0]:
896 |                 splitted_relations[span_idx].append(relation)
897 |             # multiple spans are needed to cover r_spans
898 |             else:
899 |                 #print 'hard condition'
900 |                 span_idx = 0
901 |                 now_span = spans[span_idx]
902 |                 # make now_span catch up with the beginning of r_spans
903 |                 while r_spans[0][0] > now_span[-1]:
904 |                     span_idx += 1
905 |                     now_span = spans[span_idx]
906 |                 #the parts of the current relation's spans that fall inside now_span
907 |                 sub_spans = []
908 |                 for r_s in r_spans:
909 |                     #make a copy to operate on
910 |                     r_span = r_s[:]
911 |                     #while r_span has not been used up
912 |                     while r_span != []:
913 |                         #print sub_spans, r_span
914 |                         #if now_span has not caught up with r_span
915 |                         if r_span[0] > now_span[-1]:
916 |                             now_relation = Relation(sub_spans, PRE_EDU_SENSE, NON_CENTER, NON_TYPE)
917 |                             splitted_relations[span_idx].append(now_relation)
918 |                             span_idx += 1
919 |                             now_span = spans[span_idx]
920 |                             sub_spans = []
921 |                         else:
922 |                             #the overlapping part
923 |                             sub_span = [ max(r_span[0],now_span[0]), min(r_span[-1], now_span[-1]) ]
924 |                             sub_spans.append(sub_span)
925 |                             #cut r_span
926 |                             if now_span[-1] >= r_span[-1]:
927 |                                 r_span = []
928 |                             else:
929 |                                 r_span = [now_span[-1]+1, r_span[-1]]
930 |                 #for the condition that the last now_span overlaps the tail of r_spans
931 |                 if sub_spans != []:
932 |                     now_relation = Relation(sub_spans, PRE_EDU_SENSE, NON_CENTER, NON_TYPE)
933 |                     splitted_relations[span_idx].append(now_relation)
934 |     return splitted_relations
935 |
936 | #returns (relations, char_count, words, text)
937 | #makes the relations of a parsing tree in post order
938 | def tree_to_relations_and_words(p_tree, char_count):
939 |     relations = []
940 |     words = []
941 |     text = ''
942 |     #check whether we have reached a leaf
943 |     if type(p_tree) == type(u''):
944 |         chars = p_tree
945 |         #print chars.encode('utf-8')
946 |         n_char = len(chars)
947 |         spans = []
948 |         for i in range(n_char):
949 |             span = [char_count,char_count]
950 |             char_count += 1
951 |             spans.append(span)
952 |         relation = Relation(spans, PRE_EDU_SENSE, NON_CENTER, NON_TYPE)
953 |         relations.append(relation)
954 |         text = chars
955 |         return relations, char_count, words, text
956 |     else:
957 |         spans = []
958 |
#for '-NONE-' condition in the parsing tree 959 | if p_tree.label() == '-NONE-': 960 | return None, char_count, words, text 961 | for child_idx in range(len(p_tree)): 962 | #if we are about to reach the leaf, extract word information in advance 963 | if type(p_tree[child_idx]) == type(u''): 964 | pos = p_tree.label() 965 | n_char = len(p_tree[child_idx]) 966 | w_span = [char_count, char_count+n_char-1] 967 | word = Word(w_span, pos) 968 | #print_word(word, text) 969 | words.append(word) 970 | sub_relations, char_count, sub_words, sub_text = tree_to_relations_and_words(p_tree[child_idx], char_count) 971 | if sub_relations == None: 972 | continue 973 | span = [ sub_relations[-1].spans[0][0], sub_relations[-1].spans[-1][-1] ] 974 | spans.append(span) 975 | relations.extend(sub_relations) 976 | words.extend(sub_words) 977 | text += sub_text 978 | if spans == []: 979 | return None, char_count, words, text 980 | relation = Relation(spans, PRE_EDU_SENSE, NON_CENTER, NON_TYPE) 981 | relations.append(relation) 982 | return relations, char_count, words, text 983 | 984 | def spans_list_to_pre_order(spans_list): 985 | # this is for python3 removing the cmp in the sorted function, to transfer a cmp to key 986 | from functools import cmp_to_key 987 | #cmp_for_pre_order_from_span: compare function for sorting for pre order 988 | key_for_pre_order_from_spans = cmp_to_key(cmp_for_pre_order_from_spans) 989 | 990 | spans_list = sorted(spans_list, key=key_for_pre_order_from_spans, reverse=True) 991 | 992 | return spans_list 993 | 994 | #compare function for sorting for pre order 995 | def cmp_for_pre_order_from_spans(spans_1, spans_2): 996 | #compare the start boundary first 997 | if spans_1[0][0] < spans_2[0][0]: 998 | return 1 999 | #if same, compare the end boundary 1000 | elif spans_1[0][0] == spans_2[0][0]: 1001 | if spans_1[-1][-1] > spans_2[-1][-1]: 1002 | return 1 1003 | return -1 1004 | 1005 | def relations_to_pre_order(relations): 1006 | # this is for python3 removing the cmp in the sorted function, to transfer a cmp to key 1007 | from functools import cmp_to_key 1008 | #cmp_for_pre_order: compare function for sorting for pre order 1009 | key_for_pre_order_from_relation = cmp_to_key(cmp_for_pre_order_from_relation) 1010 | 1011 | relations = sorted(relations, key=key_for_pre_order_from_relation, reverse=True) 1012 | 1013 | return relations 1014 | 1015 | #compare function for sorting for pre order 1016 | def cmp_for_pre_order_from_relation(rel_1, rel_2): 1017 | #compare the start boundary first 1018 | if rel_1.spans[0][0] < rel_2.spans[0][0]: 1019 | return 1 1020 | #if same, compare the end boundary 1021 | elif rel_1.spans[0][0] == rel_2.spans[0][0]: 1022 | if rel_1.spans[-1][-1] > rel_2.spans[-1][-1]: 1023 | return 1 1024 | return -1 1025 | 1026 | def relations_to_post_order(relations): 1027 | # this is for python3 removing the cmp in the sorted function, to transfer a cmp to key 1028 | from functools import cmp_to_key 1029 | #cmp_for_post_order_from_i_relation: compare function for sorting for post order 1030 | key_for_post_order_from_relation = cmp_to_key(cmp_for_post_order_from_relation) 1031 | 1032 | relations = sorted(relations, key=key_for_post_order_from_relation, reverse=True) 1033 | 1034 | return relations 1035 | 1036 | #compare function for sorting for post order 1037 | def cmp_for_post_order_from_relation(r_1, r_2): 1038 | #compare the end boundary first 1039 | if r_1.spans[-1][-1] < r_2.spans[-1][-1]: 1040 | return 1 1041 | #if same, compare the start boundary 1042 | elif 
r_1.spans[-1][-1] == r_2.spans[-1][-1]: 1043 | if r_1.spans[0][0] > r_2.spans[0][0]: 1044 | return 1 1045 | return -1 1046 | 1047 | def print_EDUs_from_corpus(corpus): 1048 | for span in corpus.edu_spans: 1049 | print corpus.text[span[0]:span[-1]+1].encode('utf-8') 1050 | 1051 | def print_corpus(corpus): 1052 | print 'id:', corpus.id 1053 | print 'text:', corpus.text.encode('utf-8') 1054 | for word in corpus.words: 1055 | print_word(word, corpus.text) 1056 | print 'segments_span:', corpus.segment_spans 1057 | print 'edus_span:', corpus.edu_spans 1058 | print 'du_relations:' 1059 | for relation in corpus.du_relations: 1060 | print_relation(relation) 1061 | print 'seg_relations:' 1062 | for relation in corpus.seg_relations: 1063 | print_relation(relation) 1064 | print 'w_relations:' 1065 | for relation in corpus.w_relations: 1066 | print_relation(relation) 1067 | 1068 | def print_word(word, text): 1069 | if text != '': 1070 | print 'word:', text[word.span[0]: word.span[-1]+1].encode('utf-8'), 1071 | print 'span:',word.span, 1072 | print 'pos:', word.pos.encode('utf-8'), 1073 | 1074 | def print_relation(relation): 1075 | print relation.spans, relation.sense.encode('utf-8'), relation.center.encode('utf-8'),relation.type.encode('utf-8'), 1076 | return 1077 | 1078 | def print_i_relation(relation): 1079 | #if relation.sense != 5: 1080 | print relation.spans, relation.sense, relation.center, 1081 | return 1082 | 1083 | def print_parseseg(parseseg): 1084 | print parseseg.id 1085 | print parseseg.text.encode('utf-8') 1086 | for word in parseseg.words: 1087 | #text = '' due to index is not modified 1088 | print_word(word, '') 1089 | for relation in parseseg.relations: 1090 | print_relation(relation) 1091 | 1092 | def print_train_instance(instance): 1093 | for words in instance.words_list: 1094 | for word in words: 1095 | print_word(word, ''), 1096 | print '' 1097 | print 'fragments:' 1098 | if type(instance.fragments[0]) == type(u''): 1099 | for fragment in instance.fragments: 1100 | print fragment.encode('utf8'),',', 1101 | else: 1102 | for fragment in instance.fragments: 1103 | print fragment,',', 1104 | print '' 1105 | print 'fragment_spans: ', instance.fragment_spans 1106 | #ex:[[0, 46], [47, 61], [62, 76]] 1107 | print 'segment_spans:', instance.segment_spans 1108 | print 'i_relations:' 1109 | for i_relation in instance.i_relations: 1110 | print_i_relation(i_relation) 1111 | print '' 1112 | #True or False 1113 | print'label:', instance.label 1114 | print'id:', instance.id 1115 | 1116 | def print_test_instance(instance): 1117 | for words in instance.words_list: 1118 | for word in words: 1119 | print_word(word, ''), 1120 | print 'fragments:' 1121 | if type(instance.fragments[0]) == type(u''): 1122 | for fragment in instance.fragments: 1123 | print fragment.encode('utf8'),',', 1124 | else: 1125 | for fragment in instance.fragments: 1126 | print fragment,',', 1127 | print '' 1128 | print 'fragment_spans: ', instance.fragment_spans 1129 | print 'segment_spans:', instance.segment_spans 1130 | print 'edu_spans:', instance.edu_spans 1131 | print 'puncs:' 1132 | for punc in instance.puncs: 1133 | print punc.encode('utf-8'),' ', 1134 | #ex:[[0, 46], [47, 61], [62, 76]] 1135 | print 'du_i_relations:' 1136 | for i_relation in instance.du_i_relations: 1137 | print_i_relation(i_relation) 1138 | print 'seg_i_relations:' 1139 | for i_relation in instance.seg_i_relations: 1140 | print_i_relation(i_relation) 1141 | print 'w_i_relations_list:' 1142 | for i_relation_list in instance.w_i_relations_list: 1143 | for 
i_relation in i_relation_list:
1144 |             print_i_relation(i_relation)
1145 |     print 'i_relations:'
1146 |     for i_relation in instance.i_relations:
1147 |         print_i_relation(i_relation)
1148 |     print ''
1149 |     print 'id:', instance.id
1150 |
1151 |
1152 | def print_word_to_ix():
1153 |     for k, v in word_to_ix.iteritems():
1154 |         print k.encode('utf-8'),':',v,'\t',
1155 |
1156 |
-------------------------------------------------------------------------------- /evaluate.py: --------------------------------------------------------------------------------
1 | import data_structure
2 |
3 | from sklearn.metrics import precision_recall_fscore_support
4 | from collections import defaultdict
5 |
6 | class Evaluater():
7 |     def __init__(self):
8 |         #for du level parseval
9 |         self.du_parse_eval_data_dict = {'merge': {}, 'sense': {}, 'center': {}, 'join': {}}
10 |         #for word level parseval
11 |         self.w_parse_eval_data_dict = {'merge': {}, 'sense': {}, 'center': {}, 'join': {}}
12 |         #for char tags
13 |         self.char_tags_eval_data_dict = {'word_seg': {}, 'pos': {}}
14 |         for key in self.du_parse_eval_data_dict:
15 |             self.du_parse_eval_data_dict[key] = {'tp':0, 'gold_n':0, 'pred_n':0}
16 |             self.w_parse_eval_data_dict[key] = {'tp':0, 'gold_n':0, 'pred_n':0}
17 |         for key in self.char_tags_eval_data_dict:
18 |             self.char_tags_eval_data_dict[key] = {'tp':0, 'gold_n':0, 'pred_n':0}
19 |         #for edu
20 |         self.gold_edu_tag_list = []
21 |         self.pred_edu_tag_list = []
22 |
23 |         #order: sense1, sense2, center, child_num, binary, explicit
24 |         self.relation_distribution_eval_dict = defaultdict(lambda: defaultdict(int))
25 |         self.edu_punc_distrbution = defaultdict(lambda: defaultdict(int))
26 |
27 |     def collect_eval_data(self, gold_te_instance, pred_te_instance):
28 |         #self.get_tag_eval_data(gold_te_instance, pred_te_instance)
29 |         self.get_edu_eval_data(gold_te_instance, pred_te_instance)
30 |         self.get_du_parse_eval_data(gold_te_instance, pred_te_instance)
31 |
32 |     def get_tag_eval_data(self, gold_te_instance, pred_te_instance):
33 |         self.char_tags_eval_data_dict['word_seg']['gold_n'] += len(gold_te_instance.words)
34 |         self.char_tags_eval_data_dict['word_seg']['pred_n'] += len(pred_te_instance.words)
35 |         self.char_tags_eval_data_dict['pos']['gold_n'] += len(gold_te_instance.words)
36 |         self.char_tags_eval_data_dict['pos']['pred_n'] += len(pred_te_instance.words)
37 |         for gold_word in gold_te_instance.words:
38 |             for pred_word in pred_te_instance.words:
39 |                 if gold_word.span == pred_word.span:
40 |                     self.char_tags_eval_data_dict['word_seg']['tp'] += 1
41 |                     #we need to use POS_TO_LABEL since different pos tags may belong to the same class
42 |                     if data_structure.POS_TO_LABEL[gold_word.pos] == data_structure.POS_TO_LABEL[pred_word.pos]:
43 |                         self.char_tags_eval_data_dict['pos']['tp'] += 1
44 |
45 |     def get_edu_eval_data(self, gold_te_instance, pred_te_instance):
46 |
47 |         gold_edu_boundary_list = [span[-1] for span in gold_te_instance.edu_spans]
48 |         pred_edu_spans = data_structure.relations_to_edu_spans(pred_te_instance.du_i_relations)
49 |         pred_edu_boundary_list = [span[-1] for span in pred_edu_spans]
50 |
51 |
52 |         gold_edu_tag_list = []
53 |         pred_edu_tag_list = []
54 |         punc_list = gold_te_instance.puncs
55 |
56 |
57 |         seg_boundary_list = [ span[-1] for span in gold_te_instance.segment_spans ]
58 |         for boundary in seg_boundary_list:
59 |             if boundary in gold_edu_boundary_list:
60 |                 gold_edu_tag_list.append(1)
61 |             else:
62 |                 gold_edu_tag_list.append(0)
63 |             if boundary in pred_edu_boundary_list:
64 |                 pred_edu_tag_list.append(1)
65 |             else:
66 |
pred_edu_tag_list.append(0) 67 | 68 | 69 | for tag, punc in zip(gold_edu_tag_list,punc_list): 70 | print tag, punc.encode('utf-8'), 71 | print '' 72 | for tag, punc in zip(pred_edu_tag_list,punc_list): 73 | print tag, punc.encode('utf-8'), 74 | print '\n\n' 75 | #print gold_edu_tag_list 76 | #print pred_edu_tag_list 77 | 78 | self.gold_edu_tag_list.extend(gold_edu_tag_list) 79 | self.pred_edu_tag_list.extend(pred_edu_tag_list) 80 | return 81 | 82 | def get_du_parse_eval_data(self, gold_te_instance, pred_te_instance): 83 | 84 | for key in self.du_parse_eval_data_dict: 85 | # use max() in case of only one edu 86 | self.du_parse_eval_data_dict[key]['gold_n'] += max(0,len(gold_te_instance.du_i_relations)-1) 87 | self.du_parse_eval_data_dict[key]['pred_n'] += max(0,len(pred_te_instance.du_i_relations)-1) 88 | 89 | for pr in pred_te_instance.du_i_relations: 90 | self.relation_distribution_eval_dict[str(pr.sense)+'_all']['pred_n'] += 1 91 | self.relation_distribution_eval_dict[pr.type+'_all']['pred_n'] += 1 92 | for gr in gold_te_instance.du_i_relations: 93 | self.relation_distribution_eval_dict[str(gr.sense)+'_all']['gold_n'] += 1 94 | self.relation_distribution_eval_dict[gr.type+'_all']['gold_n'] += 1 95 | 96 | for pr in pred_te_instance.du_i_relations: 97 | for gr in gold_te_instance.du_i_relations: 98 | if gr.spans[0][0] == pr.spans[0][0] and gr.spans[-1][-1] == pr.spans[-1][-1]: 99 | self.relation_distribution_eval_dict[str(pr.sense)]['pred_n'] += 1 100 | self.relation_distribution_eval_dict[pr.type]['pred_n'] += 1 101 | self.relation_distribution_eval_dict[str(gr.sense)]['gold_n'] += 1 102 | self.relation_distribution_eval_dict[gr.type]['gold_n'] += 1 103 | if gr.sense == pr.sense: 104 | self.relation_distribution_eval_dict[str(gr.sense)+'_all']['tp'] += 1 105 | self.relation_distribution_eval_dict[gr.type+'_all']['tp'] += 1 106 | self.relation_distribution_eval_dict[str(gr.sense)]['tp'] += 1 107 | self.relation_distribution_eval_dict[gr.type]['tp'] += 1 108 | # exclude the root node 109 | if gr.spans[0][0] != 0 or gr.spans[-1][-1] != gold_te_instance.segment_spans[-1][-1]: 110 | self.du_parse_eval_data_dict['merge']['tp'] += 1 111 | if gr.sense == pr.sense: 112 | self.du_parse_eval_data_dict['sense']['tp'] += 1 113 | if gr.center == pr.center: 114 | self.du_parse_eval_data_dict['center']['tp'] += 1 115 | if gr.sense == pr.sense and gr.center == pr.center: 116 | self.du_parse_eval_data_dict['join']['tp'] += 1 117 | 118 | def show_eval_result(self): 119 | 120 | edu_result = precision_recall_fscore_support(self.gold_edu_tag_list, self.pred_edu_tag_list, average='binary') 121 | print 'edu_result:', edu_result 122 | for key, dict in self.du_parse_eval_data_dict.iteritems(): 123 | print key,': ', 124 | for k, v in dict.iteritems(): 125 | print k, ': ', v, 126 | if dict['gold_n'] != 0: 127 | print 'recall: ', float( dict['tp'] )/dict['gold_n'], 128 | else: 129 | print 'recall: 0', 130 | if dict['pred_n'] != 0: 131 | print 'precision: ', float( dict['tp'] )/dict['pred_n'], 132 | else: 133 | print 'precision: 0', 134 | if dict['gold_n'] + dict['pred_n'] != 0: 135 | print 'f1: ', 2*float( dict['tp'] )/( dict['gold_n'] + dict['pred_n'] ) 136 | else: 137 | print 'f1: 0' 138 | 139 | for key, dict in self.relation_distribution_eval_dict.iteritems(): 140 | if type(key) == type(1): 141 | print key,': ', 142 | else: 143 | print key.encode('utf-8'),': ', 144 | for k, v in dict.iteritems(): 145 | print k, ': ', v, 146 | if dict['gold_n'] != 0: 147 | print 'recall: ', float( dict['tp'] )/dict['gold_n'], 148 | 
else: 149 | print 'recall: 0', 150 | if dict['pred_n'] != 0: 151 | print 'precision: ', float( dict['tp'] )/dict['pred_n'], 152 | else: 153 | print 'precision: 0', 154 | if dict['gold_n'] + dict['pred_n'] != 0: 155 | print 'f1: ', 2*float( dict['tp'] )/( dict['gold_n'] + dict['pred_n'] ) 156 | else: 157 | print 'f1: 0' 158 | 159 | def show_single_eval_result(self, gold_instance, pred_instance): 160 | 161 | edu_result = precision_recall_fscore_support(self.gold_edu_tag_list, self.pred_edu_tag_list, average='binary') 162 | f1_dict = {} 163 | for key, dict in self.du_parse_eval_data_dict.iteritems(): 164 | print key,': ', 165 | if dict['gold_n'] + dict['pred_n'] != 0: 166 | f1_dict[key] = 2*float( dict['tp'] )/( dict['gold_n'] + dict['pred_n'] ) 167 | else: 168 | f1_dict[key] = 0 169 | 170 | if f1_dict['join'] > -1.0 and f1_dict['join'] < 1.1: 171 | print edu_result 172 | for k, v in f1_dict.iteritems(): 173 | print k, v 174 | data_structure.print_test_instance(gold_instance) 175 | data_structure.print_test_instance(pred_instance) 176 | 177 | 178 | 179 | def show_relation_distribution_from_corpus_list(self, corpus_list): 180 | #order: sense1, sense2, center, child_num, binary, explicit 181 | relation_distribution_dict = defaultdict(int) 182 | example_count_dict = defaultdict(int) 183 | example_dict = defaultdict(list) 184 | 185 | for corpus in corpus_list: 186 | for relation in corpus.du_relations: 187 | key_s1 = str(data_structure.SENSE_TO_LABEL[relation.sense]) 188 | key_s2 = relation.sense 189 | key_c = str(relation.center) 190 | key_child_num = str(len(relation.spans)) 191 | if len(relation.spans) == 2: 192 | key_binary = 'binary' 193 | else: 194 | key_binary = 'multi' 195 | key_type = relation.type 196 | key_list = [key_s1,key_s2,key_c,key_child_num,key_binary, key_type] 197 | key_combinations = [''] 198 | for key in key_list: 199 | new_combinations = [] 200 | for key_combination in key_combinations: 201 | new_combinations.append(key_combination+'_'+key) 202 | new_combinations.append(key_combination+'_'+'N') 203 | key_combinations = new_combinations[:] 204 | for key_cb in key_combinations: 205 | relation_distribution_dict[key_cb] += 1 206 | if len(relation.spans) >= 7: 207 | print len(relation.spans) 208 | for span in relation.spans: 209 | print corpus.text[span[0]:span[-1]+1].encode('utf-8') 210 | 211 | key1 = key_s1+'_'+key_type 212 | key2 = key_s1+'_'+key_c 213 | if example_count_dict[key1] < 3: 214 | example_dict[key1].append([]) 215 | for span in relation.spans: 216 | example_dict[key1][-1].append(corpus.text[span[0]:span[-1]+1].encode('utf-8')) 217 | example_count_dict[key1] += 1 218 | if example_count_dict[key2] < 3: 219 | example_dict[key2].append([]) 220 | for span in relation.spans: 221 | example_dict[key2][-1].append(corpus.text[span[0]:span[-1]+1].encode('utf-8')) 222 | example_count_dict[key2] += 1 223 | 224 | for key in example_dict: 225 | print key.encode('utf-8') 226 | for example in example_dict[key]: 227 | for argument in example: 228 | print argument 229 | print '' 230 | 231 | relation_distribution_list = sorted(relation_distribution_dict.iteritems(), key=lambda x:x[0] ) 232 | for k_v in relation_distribution_list: 233 | print k_v[0].encode('utf-8')+':', k_v[1],',\t', 234 | 235 | def show_edu_punc_distribution_from_corpus_list(self, corpus_list): 236 | for corpus in corpus_list: 237 | for span in corpus.segment_spans: 238 | self.edu_punc_distrbution['punc_number'][corpus.text[span[-1]].encode('utf-8')] += 1 239 | for span in corpus.edu_spans: 240 | 
self.edu_punc_distrbution['edu_punc_number'][corpus.text[span[-1]].encode('utf-8')] += 1 241 | for key, v in self.edu_punc_distrbution['punc_number'].iteritems(): 242 | edu_v = self.edu_punc_distrbution['edu_punc_number'][key] 243 | print key, v, edu_v, float(edu_v)/v 244 | 245 | def show_analysis_from_corpus_list(self, corpus_list): 246 | edu_n = 0 247 | paragraph_n = 0 248 | relation_n = 0 249 | for corpus in corpus_list: 250 | edu_n += len(corpus.edu_spans) 251 | paragraph_n += 1 252 | relation_n += len(corpus.du_relations) 253 | print 'edu_n:', edu_n 254 | print 'paragraph_n:', paragraph_n 255 | print 'relation_n:', relation_n 256 | 257 | def show_pos_distribution_from_corpus_list(self, corpus_list): 258 | pos_distribution = defaultdict(int) 259 | for corpus in corpus_list: 260 | for word in corpus.words: 261 | pos_distribution[word.pos] += 1 262 | print pos_distribution -------------------------------------------------------------------------------- /input.txt: -------------------------------------------------------------------------------- 1 | 据统计,这些城市去年完成国内生产总值一百九十多亿元,比开放前的一九九一年增长九成多。国务院于一九九二年先后批准了黑河、凭祥、珲春、伊宁、瑞丽等十四个边境城市为对外开放城市,同时还批准这些城市设立十四个边境经济合作区。三年多来,这些城市社会经济发展迅速,地方经济实力明显增强;经济年平均增长百分之十七,高于全国年平均增长速度。 -------------------------------------------------------------------------------- /model_rvnn_lstm_20: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abccaba2000/discourse-parser/a7414e66e48621b3a1ae9bedd17b07a8be1487ce/model_rvnn_lstm_20 -------------------------------------------------------------------------------- /model_user.py: -------------------------------------------------------------------------------- 1 | import data_structure 2 | import rvnn 3 | import evaluate 4 | 5 | import sys 6 | import copy 7 | import time 8 | import random 9 | 10 | import torch 11 | import torch.autograd as autograd 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import torch.optim as optim 15 | 16 | 17 | #whether to reduce instances by sampling 18 | sample_flag = True 19 | 20 | SAVE_MODEL_PATH = './model_test_loss_batch10_1000_10' 21 | 22 | #for NN model config 23 | EMBEDDING_DIM = 64 24 | RVNN_HIDDEN_DIM = 64 25 | BILSTM_CRF_HIDDEN_DIM = 64 26 | LSTM_HIDDEN_DIM = 64 27 | LEARNING_RATE = 0.1 28 | SAMPLE_NUMBER = 30 29 | #multiple factor of the segment/DU level instances to sample word level instances 30 | SAMPLE_FACTOR = 1 31 | EPOCHS = 10 32 | BATCH_SIZE = 10 33 | GPU = True 34 | 35 | lstm_crf_only_flag = False 36 | lstm_crf_flag = False 37 | 38 | 39 | class Config(): 40 | def __init__(self, vocab_size, d_embedding, d_bilstm_crf_hidden, tag_size, d_lstm_hidden, d_rvhidden, d_struct, d_center, d_relation, gpu): 41 | self.vocab_size = vocab_size 42 | self.d_embedding = d_embedding 43 | self.d_bilstm_crf_hidden = d_bilstm_crf_hidden 44 | self.tag_size = tag_size 45 | self.d_lstm_hidden = d_lstm_hidden 46 | self.d_rvhidden = d_rvhidden 47 | self.d_struct = d_struct 48 | self.d_center = d_center 49 | self.d_relation = d_relation 50 | self.gpu = gpu 51 | 52 | def get_model(model_path=None): 53 | if model_path != None: 54 | model = torch.load(model_path) 55 | elif lstm_crf_only_flag: 56 | model = rvnn.BiLSTM_CRF(len(data_structure.word_to_ix), max(data_structure.SEQ_TAG_TO_LABEL.values())+1, EMBEDDING_DIM, EMBEDDING_DIM) 57 | else: 58 | config = Config(len(data_structure.word_to_ix), EMBEDDING_DIM, BILSTM_CRF_HIDDEN_DIM, max(data_structure.SEQ_TAG_TO_LABEL.values())+1, LSTM_HIDDEN_DIM, RVNN_HIDDEN_DIM, 
data_structure.STRUCT_LABEL_DIM, data_structure.CENTER_DIM, data_structure.SENSE_DIM, GPU) 59 | model = rvnn.RVNN(config) 60 | 61 | #gpu 62 | if GPU: 63 | model.cuda() 64 | 65 | return model 66 | 67 | def prepare_sequence(seq, to_ix): 68 | idxs = [to_ix[w] for w in seq] 69 | return torch.tensor(idxs, dtype=torch.long) 70 | 71 | def demo_predict(model, instance): 72 | instance.fragments = text_to_nn_word_list(instance.fragments) 73 | return_instance = model(instance) 74 | return return_instance 75 | 76 | def train_from_instances(model, train_instances): 77 | 78 | optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE) 79 | loss_function = nn.NLLLoss() 80 | if GPU: 81 | loss_function.cuda() 82 | for epoch in xrange(EPOCHS): 83 | start_time = time.time() 84 | #if sample_flag: 85 | # train_instances = sample_train_instances(train_instances) 86 | random.shuffle(train_instances) 87 | 88 | instance_count = 0 89 | 90 | struct_loss_sum = 0 91 | center_loss_sum = 0 92 | sense_loss_sum = 0 93 | tag_loss_sum = 0 94 | objective_loss_sum = 0 95 | 96 | objective_loss = 0 97 | 98 | for train_instance in train_instances: 99 | #if train_instance.id[0:18] != '593.xml-5-train-27-13737'[0:18]: 100 | # continue 101 | #data_structure.print_train_instance(train_instance) 102 | if instance_count%100 == 0: 103 | print('trained instances: %d'%instance_count) 104 | sys.stderr.write('trained instances: %d'%instance_count) 105 | print('instance id: %s\r'%train_instance.id) 106 | sys.stderr.write('instance id: %s\r'%train_instance.id) 107 | instance_count += 1 108 | #instance = train_instance 109 | instance = copy.deepcopy(train_instance) 110 | 111 | instance.fragments = text_to_nn_word_list(instance.fragments) 112 | model.zero_grad() 113 | ''' 114 | if lstm_crf_only_flag: 115 | for words in train_instance.words_list: 116 | tag_labels = [] 117 | for word in words: 118 | tag_labels.extend(data_structure.word_to_labels(word)) 119 | #tag_targets = torch.LongTensor(tags) 120 | #print tags 121 | tag_targets = prepare_tensor_train_ans(tag_labels) 122 | tag_loss = model.neg_log_likelihood(instance.fragments, tag_targets) 123 | #tag_loss = autograd.Variable(torch.LongTensor([0])) 124 | tag_loss_sum += tag_loss.data[0] 125 | objective_loss += tag_loss 126 | ''' 127 | #else: 128 | #scores_list: [, struct_scores, ] 129 | if instance.label == data_structure.STRUCT_LABEL_TRUE: 130 | instance, scores_list = model.forward_by_i_relations(instance, struct=True, center=True, sense=True) 131 | else: 132 | instance, scores_list = model.forward_by_i_relations(instance, struct=True, center=False, sense=False) 133 | 134 | score_idx = 0 135 | 136 | #for char tag 137 | if lstm_crf_flag: 138 | tag_score_list = scores_list[score_idx] 139 | tag_score_idx = 0 140 | instance1 = copy.deepcopy(train_instance) 141 | start_idx = instance1.segment_spans[0][0] 142 | for seg_idx in range(len(instance1.segment_spans)): 143 | #for gold tags 144 | tag_labels = [] 145 | words = instance1.words_list[seg_idx] 146 | for word in words: 147 | tag_labels.extend(data_structure.word_to_labels(word)) 148 | #for segment input 149 | span = instance1.segment_spans[seg_idx] 150 | sentence = text_to_nn_word_list(instance.fragments[span[0]-start_idx:span[-1]+1-start_idx]) 151 | 152 | tag_targets = prepare_tensor_train_ans(tag_labels) 153 | gold_tag_score = model.score_sentence(sentence, tag_targets) 154 | #according to the lstm_crf implementation in pytorch tutorial, this subtraction computes NLLLoss 155 | tag_loss = tag_score_list[tag_score_idx] - gold_tag_score
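# (A note on the subtraction above, assuming the pytorch-tutorial-style CRF in rvnn.py:
# for a CRF trained with log-potentials, NLL = log Z(x) - score(x, y_gold), so if
# tag_score_list carries the forward-algorithm partition scores log Z per segment,
# subtracting the gold path score reproduces BiLSTM_CRF.neg_log_likelihood.)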
156 | tag_score_idx += 1 157 | tag_loss_sum += tag_loss.data[0] 158 | objective_loss += tag_loss 159 | score_idx += 1 160 | 161 | if not(lstm_crf_flag and lstm_crf_only_flag): 162 | 163 | #for struct 164 | struct_target = prepare_train_ans([instance.label]) 165 | struct_loss = loss_function(scores_list[score_idx], struct_target) 166 | #print('struct_scores: ', scores_list[score_idx], 'struct_target: ', struct_target) 167 | struct_loss_sum += struct_loss.data[0] 168 | objective_loss += struct_loss 169 | score_idx += 1 170 | 171 | if instance.label == data_structure.STRUCT_LABEL_TRUE and instance.segment_flag: 172 | #for center and sense 173 | #we use index -1 since the i_relations is already in post order 174 | center_target = prepare_train_ans([instance.i_relations[-1].center]) 175 | center_loss = loss_function(scores_list[score_idx], center_target) 176 | center_loss_sum += center_loss.data[0] 177 | score_idx += 1 178 | #print('center_scores: ', scores_list[score_idx], 'center_target: ', center_target) 179 | #we use index -1 since the i_relations is already in post order 180 | sense_target = prepare_train_ans([instance.i_relations[-1].sense]) 181 | sense_loss = loss_function(scores_list[score_idx], sense_target) 182 | sense_loss_sum += sense_loss.data[0] 183 | #print('sense_scores: ', scores_list[score_idx], 'sense_target: ', sense_target) 184 | objective_loss += center_loss 185 | objective_loss += sense_loss 186 | 187 | if instance_count%BATCH_SIZE == 0: 188 | objective_loss_sum += objective_loss.data[0] 189 | objective_loss.backward() 190 | optimizer.step() 191 | objective_loss = 0 192 | end_time = time.time() 193 | print('use time: %d'%(end_time-start_time)) 194 | sys.stderr.write('use time: %d'%(end_time-start_time)) 195 | print('epoch %d loss: struct: %f, center: %f, sense: %f, tag: %f, objective: %f\n'%(epoch, struct_loss_sum, center_loss_sum, sense_loss_sum, tag_loss_sum, objective_loss_sum)) 196 | sys.stderr.write('epoch %d loss: struct: %f, center: %f, sense: %f, tag: %f, objective: %f\n'%(epoch, struct_loss_sum, center_loss_sum, sense_loss_sum, tag_loss_sum, objective_loss_sum)) 197 | 198 | 199 | if SAVE_MODEL_PATH != None: 200 | torch.save(model,SAVE_MODEL_PATH) 201 | 202 | return 203 | 204 | def test_from_corpus_list(model, corpus_list): 205 | evaluater = evaluate.Evaluater() 206 | evaluater2 = evaluate.Evaluater() 207 | evaluater3 = evaluate.Evaluater() 208 | count = 0 209 | print 'vocab_size', len(data_structure.word_to_ix) 210 | for corpus in corpus_list: 211 | 212 | #if count == int(sys.argv[1]): 213 | # continue 214 | #print corpus.id 215 | #count += 1 216 | test_instance = data_structure.corpus_to_test_instance(corpus, binary=True) 217 | # multinuclear gold instance 218 | gold_multi_instance = data_structure.corpus_to_test_instance(corpus, binary=False) 219 | gold_binary_instance = data_structure.corpus_to_test_instance(corpus, binary=True) 220 | #print test_instance.fragments 221 | 222 | test_instance.fragments = text_to_nn_word_list(test_instance.fragments) 223 | model.zero_grad() 224 | #print 'test_instance' 225 | #data_structure.print_test_instance(test_instance) 226 | if lstm_crf_only_flag: 227 | _, labels = model(test_instance.fragments) 228 | #return_instance = data_structure.labels_to_words_in_test_instance(labels, instance) 229 | 230 | else: 231 | return_instance = model(test_instance) 232 | #print 'corpus', 233 | #data_structure.print_corpus(corpus) 234 | return_instance.i_relations = data_structure.relations_to_post_order(return_instance.i_relations)
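# relations_to_post_order re-sorts the predicted relations so that, roughly speaking,
# every child relation precedes its parent (bottom-up), which is the ordering that
# forward_by_i_relations and the printing/eval helpers assume; e.g. a root relation
# over spans [[0,5],[6,9]] would come after the child relation covering [[0,2],[3,5]].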
235 | ''' 236 | print 'return_instance' 237 | data_structure.print_test_instance(return_instance) 238 | print 'gold_binary_instance' 239 | data_structure.print_test_instance(gold_binary_instance) 240 | print 'gold_multi_instance' 241 | data_structure.print_test_instance(gold_multi_instance) 242 | ''' 243 | 244 | evaluater_tmp1 = evaluate.Evaluater() 245 | evaluater_tmp1.collect_eval_data(gold_binary_instance, return_instance) 246 | evaluater_tmp1.show_single_eval_result(gold_binary_instance, return_instance) 247 | 248 | evaluater.collect_eval_data(gold_binary_instance, return_instance) 249 | evaluater2.collect_eval_data(gold_multi_instance, return_instance) 250 | return_instance.du_i_relations = \ 251 | data_structure.relations_binary_to_multi_preorder(return_instance.du_i_relations) 252 | evaluater3.collect_eval_data(gold_multi_instance, return_instance) 253 | 254 | evaluater_tmp2 = evaluate.Evaluater() 255 | evaluater_tmp2.collect_eval_data(gold_multi_instance, return_instance) 256 | evaluater_tmp2.show_single_eval_result(gold_multi_instance, return_instance) 257 | 258 | evaluater.show_eval_result() 259 | evaluater2.show_eval_result() 260 | evaluater3.show_eval_result() 261 | return 262 | 263 | def sample_train_instances(instances): 264 | 265 | pos_seg_du_instances = [] 266 | pos_w_instances = [] 267 | neg_seg_du_instances = [] 268 | neg_w_instances = [] 269 | 270 | for instance in instances: 271 | if instance.segment_flag: 272 | if instance.label == data_structure.STRUCT_LABEL_TRUE: 273 | pos_seg_du_instances.append(instance) 274 | else: 275 | neg_seg_du_instances.append(instance) 276 | else: 277 | if instance.label == data_structure.STRUCT_LABEL_TRUE: 278 | pos_w_instances.append(instance) 279 | else: 280 | neg_w_instances.append(instance) 281 | 282 | pos_seg_du_count = len(pos_seg_du_instances) 283 | neg_seg_du_count = len(neg_seg_du_instances) 284 | pos_w_count = len(pos_w_instances) 285 | neg_w_count = len(neg_w_instances) 286 | w_count = pos_w_count+neg_w_count 287 | seg_du_count = pos_seg_du_count + neg_seg_du_count 288 | 289 | 290 | if w_count > SAMPLE_FACTOR*seg_du_count: 291 | w_instances = random.sample(neg_w_instances+pos_w_instances, SAMPLE_FACTOR*seg_du_count) 292 | else: 293 | w_instances = neg_w_instances+pos_w_instances 294 | 295 | instances = pos_seg_du_instances+neg_seg_du_instances+w_instances 296 | #shuffle 297 | random.shuffle(instances) 298 | #rename instance id 299 | count = 1 300 | for instance in instances: 301 | length = len(instance.id.split('-')[-1]) 302 | instance.id = instance.id[0:-1*length] 303 | instance.id += str(count) 304 | count += 1 305 | 306 | return instances 307 | 308 | def print_loss(epoch, tag_loss_sum, struct_loss_sum, sense_loss_sum, center_loss_sum, objective_loss_sum): 309 | if lstm_crf_only_flag: 310 | print('epoch %d loss: tag: %f, objective: %f\n'%(epoch, tag_loss_sum.data[0], objective_loss_sum.data[0])) 311 | sys.stderr.write('epoch %d loss: tag: %f, objective: %f\n'%(epoch, tag_loss_sum.data[0], objective_loss_sum.data[0])) 312 | else: 313 | print('epoch %d loss: struct: %f, center: %f, sense: %f, tag: %f, objective: %f\n'%(epoch, struct_loss_sum.data[0], center_loss_sum.data[0], sense_loss_sum.data[0], tag_loss_sum.data[0], objective_loss_sum.data[0])) 314 | sys.stderr.write('epoch %d loss: struct: %f, center: %f, sense: %f, tag: %f, objective: %f\n'%(epoch, struct_loss_sum.data[0], center_loss_sum.data[0], sense_loss_sum.data[0], tag_loss_sum.data[0], objective_loss_sum.data[0])) 315 | return 316 | 317 | #four categories: whether they are above segment
level or they are positive 318 | def print_instnaces_categories(instances): 319 | pos_seg_du_instances = [] 320 | pos_w_instances = [] 321 | neg_seg_du_instances = [] 322 | neg_w_instances = [] 323 | 324 | for instance in instances: 325 | if instance.segment_flag: 326 | if instance.label == data_structure.STRUCT_LABEL_TRUE: 327 | pos_seg_du_instances.append(instance) 328 | else: 329 | neg_seg_du_instances.append(instance) 330 | else: 331 | if instance.label == data_structure.STRUCT_LABEL_TRUE: 332 | pos_w_instances.append(instance) 333 | else: 334 | neg_w_instances.append(instance) 335 | 336 | pos_seg_du_count = len(pos_seg_du_instances) 337 | neg_seg_du_count = len(neg_seg_du_instances) 338 | pos_w_count = len(pos_w_instances) 339 | neg_w_count = len(neg_w_instances) 340 | 341 | print pos_seg_du_count, neg_seg_du_count, pos_w_count, neg_w_count 342 | 343 | def text_to_nn_word_list(text): 344 | #for t in text: 345 | # print t.encode('utf-8'), 346 | #print '' 347 | idxs = map(lambda w: data_structure.word_to_ix.setdefault(w,data_structure.word_to_ix['oov']-1), text) 348 | tensor = torch.LongTensor(idxs) 349 | if GPU: 350 | return autograd.Variable(tensor).cuda() 351 | else: 352 | return autograd.Variable(tensor) 353 | 354 | #for returning tensor, not variable 355 | def prepare_tensor_train_ans(idxs): 356 | tensor = torch.LongTensor(idxs) 357 | #print tensor 358 | if GPU: 359 | return tensor.cuda() 360 | else: 361 | return tensor 362 | 363 | def prepare_train_ans(idxs): 364 | tensor = torch.LongTensor(idxs) 365 | #print tensor 366 | if GPU: 367 | return autograd.Variable(tensor).cuda() 368 | else: 369 | return autograd.Variable(tensor) 370 | 371 | 372 | 373 | -------------------------------------------------------------------------------- /output.txt: -------------------------------------------------------------------------------- 1 | {"EDUs": ["据统计,这些城市去年完成国内生产总值一百九十多亿元,", "比开放前的一九九一年增长九成多。", "国务院于一九九二年先后批准了黑河、凭祥、珲春、伊宁、瑞丽等十四个边境城市为对外开放城市,", "同时还批准这些城市设立十四个边境经济合作区。", "三年多来,这些城市社会经济发展迅速,", "地方经济实力明显增强;", "经济年平均增长百分之十七,", "高于全国年平均增长速度。"], "tree": {"args": [{"args": [{"args": ["据统计,这些城市去年完成国内生产总值一百九十多亿元,", "比开放前的一九九一年增长九成多。"], "center": "Front", "sense": "Coordination"}, {"args": ["国务院于一九九二年先后批准了黑河、凭祥、珲春、伊宁、瑞丽等十四个边境城市为对外开放城市,", "同时还批准这些城市设立十四个边境经济合作区。"], "center": "Equal", "sense": "Coordination"}], "center": "Equal", "sense": "Coordination"}, {"args": [{"args": ["三年多来,这些城市社会经济发展迅速,", "地方经济实力明显增强;"], "center": "Front", "sense": "Explanation"}, {"args": ["经济年平均增长百分之十七,", "高于全国年平均增长速度。"], "center": "Equal", "sense": "Coordination"}], "center": "Equal", "sense": "Coordination"}], "center": "Equal", "sense": "Coordination"}, "relations": [{"arg1": "据统计,这些城市去年完成国内生产总值一百九十多亿元,比开放前的一九九一年增长九成多。国务院于一九九二年先后批准了黑河、凭祥、珲春、伊宁、瑞丽等十四个边境城市为对外开放城市,同时还批准这些城市设立十四个边境经济合作区。", "arg2": "三年多来,这些城市社会经济发展迅速,地方经济实力明显增强;经济年平均增长百分之十七,高于全国年平均增长速度。", "center": "Equal", "sense": "Coordination"}, {"arg1": "据统计,这些城市去年完成国内生产总值一百九十多亿元,比开放前的一九九一年增长九成多。", "arg2": "国务院于一九九二年先后批准了黑河、凭祥、珲春、伊宁、瑞丽等十四个边境城市为对外开放城市,同时还批准这些城市设立十四个边境经济合作区。", "center": "Equal", "sense": "Coordination"}, {"arg1": "据统计,这些城市去年完成国内生产总值一百九十多亿元,", "arg2": "比开放前的一九九一年增长九成多。", "center": "Front", "sense": "Coordination"}, {"arg1": "国务院于一九九二年先后批准了黑河、凭祥、珲春、伊宁、瑞丽等十四个边境城市为对外开放城市,", "arg2": "同时还批准这些城市设立十四个边境经济合作区。", "center": "Equal", "sense": "Coordination"}, {"arg1": "三年多来,这些城市社会经济发展迅速,地方经济实力明显增强;", "arg2": "经济年平均增长百分之十七,高于全国年平均增长速度。", "center": "Equal", "sense": "Coordination"}, {"arg1": "三年多来,这些城市社会经济发展迅速,", "arg2": "地方经济实力明显增强;", "center": "Front", "sense": 
"Explanation"}, {"arg1": "经济年平均增长百分之十七,", "arg2": "高于全国年平均增长速度。", "center": "Equal", "sense": "Coordination"}]} -------------------------------------------------------------------------------- /parser.py: -------------------------------------------------------------------------------- 1 | import data_structure 2 | import pickle 3 | import evaluate 4 | import random 5 | 6 | import argparse 7 | 8 | all_xml_dir = 'raw_data/CDTB_data_repair' 9 | xml_dir = 'raw_data/train_repair' 10 | parse_dir = 'raw_data/train_parsed_repair' 11 | test_xml_dir = 'raw_data/test_repair' 12 | test_parse_dir = 'raw_data/test_parsed_repair' 13 | train_corpus_list_file = './train_corpus_list_file' 14 | train_instances_file = './train_instances_file' 15 | test_corpus_list_file = './test_corpus_list_file' 16 | 17 | 18 | 19 | 20 | def process_commands(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('-demo','--demonstration', action='store_true', 23 | help='demo mode') 24 | parser.add_argument('-i', '--input_file', 25 | help='input file') 26 | parser.add_argument('-o', '--output_file', 27 | help='output file') 28 | return parser.parse_args() 29 | 30 | def demo(): 31 | import model_user 32 | import json 33 | print 'demo mode' 34 | model = model_user.get_model('./model_rvnn_lstm_20') 35 | with open(args.input_file, "rb") as myFile: 36 | lines = myFile.readlines() 37 | text = '' 38 | for line in lines: 39 | text += line 40 | data_structure.load_word_to_ix() 41 | text = text.decode('utf-8') 42 | instance = data_structure.text_to_test_instance(text) 43 | return_instance = model_user.demo_predict(model, instance) 44 | output_data = {} 45 | output_data['EDUs'] = [] 46 | output_data['relations'] = [] 47 | edu_boundaries = set() 48 | for relation in return_instance.du_i_relations: 49 | relation_info = {} 50 | arg_count = 1 51 | for span in relation.spans: 52 | edu_boundaries.add(span[-1]) 53 | arg = text[span[0]:span[-1]+1].encode('utf-8') 54 | relation_info['arg'+str(arg_count)] = arg 55 | arg_count += 1 56 | relation_info['sense'] = data_structure.LABEL_TO_SENSE[relation.sense] 57 | relation_info['center'] = data_structure.LABEL_TO_CENTER[relation.center] 58 | output_data['relations'].append(relation_info) 59 | edu_boundary_list = sorted(edu_boundaries) 60 | start_idx = 0 61 | for boundary in edu_boundary_list: 62 | output_data['EDUs'].append(text[start_idx:boundary+1].encode('utf-8')) 63 | start_idx = boundary+1 64 | dict = data_structure.test_instance_and_text_to_dict(return_instance, text) 65 | 66 | output_data['tree'] = dict 67 | with open(args.output_file, 'w') as outfile: 68 | json.dump(output_data, outfile, ensure_ascii=False) 69 | #data_structure.print_test_instance(return_instance) 70 | 71 | 72 | 73 | 74 | def experiment(): 75 | import model_user 76 | ''' 77 | # train corpus 78 | corpus_list = data_structure.xml_dir_to_corpus_list(xml_dir) 79 | data_structure.build_word_to_ix_from_corpus_list(corpus_list) 80 | parseseg_dict = data_structure.parse_dir_to_parseseg_dict(parse_dir) 81 | corpus_list = corpus_list[:-1] 82 | for corpus in corpus_list: 83 | corpus = data_structure.add_corpus_seg_relations(corpus) 84 | corpus_list = data_structure.merge_parseseg_dict_to_corpus_list(corpus_list, parseseg_dict) 85 | 86 | #test_corpus 87 | test_corpus_list = data_structure.xml_dir_to_corpus_list(test_xml_dir) 88 | test_parseseg_dict = data_structure.parse_dir_to_parseseg_dict(test_parse_dir) 89 | for corpus in test_corpus_list: 90 | corpus = data_structure.add_corpus_seg_relations(corpus) 91 | test_corpus_list 
=\ 92 | data_structure.merge_parseseg_dict_to_corpus_list(test_corpus_list, test_parseseg_dict) 93 | 94 | 95 | data_structure.load_word_to_ix() 96 | 97 | corpus_list = corpus_list[:-1] 98 | #train instances 99 | instances = [] 100 | for corpus in corpus_list: 101 | #if corpus.id == '246.xml-2': 102 | # data_structure.print_corpus(corpus) 103 | instance_list = data_structure.corpus_to_train_instance_list(corpus) 104 | instances.extend(instance_list) 105 | 106 | 107 | #dump data to file 108 | with open(train_corpus_list_file, "wb") as myFile: 109 | pickle.dump(corpus_list, myFile) 110 | with open(test_corpus_list_file, "wb") as myFile: 111 | pickle.dump(test_corpus_list, myFile) 112 | with open(train_instances_file, "wb") as myFile: 113 | pickle.dump(instances, myFile) 114 | data_structure.save_word_to_ix() 115 | ''' 116 | 117 | #load data 118 | #data_structure.load_word_to_ix() 119 | ''' 120 | with open(train_corpus_list_file, "rb") as myFile: 121 | corpus_list = pickle.load(myFile) 122 | 123 | for corpus in corpus_list: 124 | if corpus.id == '593.xml-2': 125 | data_structure.print_corpus(corpus) 126 | ''' 127 | ''' 128 | with open(train_instances_file, "rb") as myFile: 129 | instances = pickle.load(myFile) 130 | #instances = instances[:1] 131 | instances = random.sample(instances, 1000) 132 | ''' 133 | print 'hello' 134 | data_structure.load_word_to_ix() 135 | 136 | #train 137 | #model = model_user.get_model(model_path=None) 138 | #model_user.train_from_instances(model, instances) 139 | #test 140 | 141 | with open(test_corpus_list_file, "rb") as myFile: 142 | test_corpus_list = pickle.load(myFile) 143 | model = model_user.get_model('./model_rvnn_lstm_20') 144 | model_user.test_from_corpus_list(model, test_corpus_list) 145 | 146 | 147 | def analysis(): 148 | 149 | corpus_list = data_structure.xml_dir_to_corpus_list(all_xml_dir) 150 | ''' 151 | for corpus in corpus_list: 152 | for span in corpus.edu_spans: 153 | print corpus.text[span[0]:span[-1]+1].encode('utf-8') 154 | ''' 155 | with open(train_corpus_list_file, "rb") as myFile: 156 | corpus_list = pickle.load(myFile) 157 | with open(test_corpus_list_file, "rb") as myFile: 158 | test_corpus_list = pickle.load(myFile) 159 | evaluater = evaluate.Evaluater() 160 | evaluater.show_relation_distribution_from_corpus_list(corpus_list) 161 | #evaluater.show_edu_punc_distribution_from_corpus_list(corpus_list) 162 | #evaluater.show_analysis_from_corpus_list(corpus_list) 163 | 164 | #corpus_list.extend(test_corpus_list) 165 | #evaluater.show_pos_distribution_from_corpus_list(corpus_list) 166 | def test(): 167 | ''' 168 | # train corpus 169 | corpus_list = data_structure.xml_dir_to_corpus_list(xml_dir) 170 | parseseg_dict = data_structure.parse_dir_to_parseseg_dict(parse_dir) 171 | corpus_list = corpus_list[:-1] 172 | for corpus in corpus_list: 173 | corpus = data_structure.add_corpus_seg_relations(corpus) 174 | corpus_list = data_structure.merge_parseseg_dict_to_corpus_list(corpus_list, parseseg_dict) 175 | 176 | #test_corpus 177 | test_corpus_list = data_structure.xml_dir_to_corpus_list(test_xml_dir) 178 | test_parseseg_dict = data_structure.parse_dir_to_parseseg_dict(test_parse_dir) 179 | for corpus in test_corpus_list: 180 | corpus = data_structure.add_corpus_seg_relations(corpus) 181 | test_corpus_list =\ 182 | data_structure.merge_parseseg_dict_to_corpus_list(test_corpus_list, test_parseseg_dict) 183 | 184 | with open(train_corpus_list_file, "wb") as myFile: 185 | pickle.dump(corpus_list, myFile) 186 | with open(test_corpus_list_file, "wb") as myFile: 187 | pickle.dump(test_corpus_list, myFile) 188 | '''
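# The commented-out block above rebuilds the pickled caches (train_corpus_list_file,
# test_corpus_list_file) from the raw XML/parse dirs; it only needs to be re-run
# when the raw data changes, after which the load calls below pick the caches up.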
189 | 190 | data_structure.load_word_to_ix() 191 | with open(train_corpus_list_file, "rb") as myFile: 192 | corpus_list = pickle.load(myFile) 193 | corpus_list = corpus_list[:-1] 194 | 195 | with open(test_corpus_list_file, "rb") as myFile: 196 | test_corpus_list = pickle.load(myFile) 197 | 198 | 199 | 200 | 201 | instances = [] 202 | for corpus in corpus_list: 203 | instance_list = data_structure.corpus_to_train_instance_list(corpus) 204 | #for instance in instance_list: 205 | #data_structure.print_train_instance(instance) 206 | instances.extend(instance_list) 207 | 208 | 209 | test_instances = [] 210 | for corpus in test_corpus_list: 211 | test_instance = data_structure.corpus_to_test_instance(corpus, binary=True) 212 | test_instances.append(test_instance) 213 | 214 | 215 | 216 | with open(train_instances_file, "wb") as myFile: 217 | pickle.dump(instances, myFile) 218 | 219 | 220 | args = process_commands() 221 | def main(): 222 | 223 | if args.demonstration: 224 | demo() 225 | #test() 226 | #experiment() 227 | #analysis() 228 | 229 | 230 | if __name__ == "__main__": 231 | main() -------------------------------------------------------------------------------- /rvnn.py: -------------------------------------------------------------------------------- 1 | # coding=UTF-8 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | import torch.autograd as autograd 6 | 7 | import data_structure 8 | 9 | 10 | if True: 11 | EDU_PUNCS = (u'。',) #trailing commas make these one-element tuples, not strings 12 | PRE_EDU_PUNCS = (u'、',) 13 | else: 14 | EDU_PUNCS = () 15 | PRE_EDU_PUNCS = () 16 | 17 | N_CKY_CANDIDATES = 2 18 | 19 | greedy_flag = False 20 | gold_edu_flag = False 21 | lstm_flag = True 22 | seg_lstm_flag = False 23 | left_seq_flag = True 24 | right_seq_flag = False 25 | GPU = True 26 | 27 | 28 | class CKY_Unit(): 29 | def __init__(self, cky_candidate_list): 30 | self.cky_candidate_list = cky_candidate_list 31 | 32 | class CKY_Candidate(): 33 | def __init__(self, cky_span_infos, representation, score, sense, center): 34 | #a list; each element contains "start_idx", "range" and "candidate_idx"; we need "candidate_idx" because the cky_unit we point to itself keeps a list of candidates 35 | self.cky_span_infos = cky_span_infos 36 | self.representation = representation 37 | self.score = score 38 | self.sense = sense 39 | self.center = center 40 | 41 | class CKY_Span_Info(): 42 | def __init__(self, start_idx, cky_range, candidate_idx): 43 | self.start_idx = start_idx 44 | self.cky_range = cky_range 45 | self.candidate_idx = candidate_idx 46 | 47 | class RVNN(nn.Module): 48 | def __init__(self, config): 49 | super(RVNN, self).__init__() 50 | self.word_embeddings = nn.Embedding(config.vocab_size, config.d_embedding) 51 | self.reduce = Reduce(config.d_rvhidden/2) 52 | self.struct_linear = nn.Linear(config.d_rvhidden, config.d_struct) 53 | self.center_linear = nn.Linear(config.d_rvhidden, config.d_center) 54 | self.sense_linear = nn.Linear(config.d_rvhidden, config.d_relation) 55 | #whether to use lstm to process segments 56 | if lstm_flag: 57 | self.d_lstm_hidden = config.d_rvhidden 58 | self.lstm = nn.LSTM(config.d_embedding, self.d_lstm_hidden) 59 | self.gpu = config.gpu 60 | if seg_lstm_flag: 61 | self.d_seg_lstm_hidden = self.d_lstm_hidden/2 62 | self.seg_lstm = nn.LSTM(self.d_lstm_hidden, self.d_seg_lstm_hidden, bidirectional=True) 63 | def init_hidden(self, hidden_dim, bidirectional=False): 64 | # Before we've done anything, we don't have any hidden state.
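# (For reference: nn.LSTM expects (h_0, c_0) each of shape
# (num_layers*num_directions, batch, hidden_size), which is why the `direction`
# factor below doubles for the bidirectional segment-level LSTM.)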
65 | # Refer to the Pytorch documentation to see exactly why they have this dimensionality. 66 | # The axes semantics are (num_layers, minibatch_size, hidden_dim) 67 | if bidirectional: 68 | direction = 2 69 | else: 70 | direction = 1 71 | #if self.gpu: 72 | if GPU: 73 | return (autograd.Variable(torch.zeros(direction, 1, hidden_dim)).cuda(), 74 | autograd.Variable(torch.zeros(direction, 1, hidden_dim)).cuda()) 75 | else: 76 | return (autograd.Variable(torch.zeros(direction, 1, hidden_dim)),#.cuda(), 77 | autograd.Variable(torch.zeros(direction, 1, hidden_dim)))#.cuda()) 78 | 79 | def reconstruct_i_relations_from_cky_table(self, instance, cky_table, cky_candidate): 80 | 81 | infos = cky_candidate.cky_span_infos 82 | #base case: the recursion reached a leaf unit 83 | if len(infos) == 1: 84 | return instance 85 | 86 | left_start_idx = infos[0].start_idx 87 | left_end_idx = infos[0].start_idx+infos[0].cky_range 88 | left_span = [ instance.fragment_spans[left_start_idx][0], instance.fragment_spans[left_end_idx][-1] ] 89 | 90 | right_start_idx = infos[-1].start_idx 91 | right_end_idx = infos[-1].start_idx+infos[-1].cky_range 92 | right_span = [ instance.fragment_spans[right_start_idx][0], instance.fragment_spans[right_end_idx][-1] ] 93 | 94 | i_relation = data_structure.I_Relation([left_span, right_span], cky_candidate.sense, cky_candidate.center, '') 95 | 96 | #modify the instance directly 97 | instance.i_relations.append(i_relation) 98 | 99 | #find next recursive target 100 | for info in infos: 101 | next_candidate = cky_table[info.start_idx][info.cky_range].cky_candidate_list[info.candidate_idx] 102 | self.reconstruct_i_relations_from_cky_table(instance, cky_table, next_candidate) 103 | return instance 104 | 105 | def forward_by_i_relations(self, instance, struct=False, center=False, sense=False): 106 | #convert segments to list of word embeddings 107 | if lstm_flag: 108 | self.lstm_hidden = self.init_hidden(self.d_lstm_hidden) 109 | fragments = [] 110 | #used to modify the index of spans to fit instance.fragments 111 | start_idx = instance.segment_spans[0][0] 112 | for span in instance.segment_spans: 113 | #print instance.fragments 114 | #print span 115 | fragment = instance.fragments[span[0]-start_idx:span[-1]+1-start_idx] 116 | #print fragment 117 | fragment = self.word_embeddings(fragment) 118 | #print fragment 119 | fragment = self.lstm(fragment.view(span[-1]-span[0]+1, 1, -1), self.lstm_hidden)[0] 120 | fragments.append(fragment[-1].view(1, -1)) 121 | if seg_lstm_flag: 122 | self.seg_hidden = self.init_hidden(self.d_seg_lstm_hidden, bidirectional=True) 123 | seg_embeds = torch.stack(fragments) 124 | fragments = list(self.seg_lstm(seg_embeds, self.seg_hidden)[0]) 125 | instance.fragments = fragments 126 | instance.fragment_spans = instance.segment_spans 127 | else: 128 | instance.fragments = list(self.word_embeddings(instance.fragments)) 129 | #print instance.fragments 130 | for i_relation in instance.i_relations: 131 | left_span = i_relation.spans[0] 132 | for idx in range(len(instance.fragment_spans)): 133 | #since the i_relations are in post order (bottom-up style), it is always possible to find a corresponding fragment to merge 134 | if left_span == instance.fragment_spans[idx]: 135 | #data_structure.print_train_instance(instance) 136 | #data_structure.print_i_relation(i_relation) 137 | left = instance.fragments[idx].view(1, -1) 138 | right = instance.fragments[idx+1].view(1, -1) 139 | reduced = self.reduce(left, right) 140 | instance.fragments[idx] = reduced 141 | del instance.fragments[idx+1] 142 | #modify the
segment span list to follow the change of segment list 143 | instance.fragment_spans[idx] = [ instance.fragment_spans[idx][0], instance.fragment_spans[idx+1][-1] ] 144 | del instance.fragment_spans[idx+1] 145 | #finish the merge, break to deal with next relation 146 | break 147 | final_representation = instance.fragments[0] 148 | 149 | scores_list = [] 150 | 151 | if struct: 152 | label_space = self.struct_linear(final_representation) 153 | struct_scores= F.log_softmax(label_space) 154 | scores_list.append(struct_scores) 155 | 156 | if center: 157 | label_space = self.center_linear(final_representation) 158 | center_scores = F.log_softmax(label_space) 159 | scores_list.append(center_scores) 160 | 161 | if sense: 162 | label_space = self.sense_linear(final_representation) 163 | sense_scores = F.log_softmax(label_space) 164 | scores_list.append(sense_scores) 165 | 166 | 167 | 168 | return instance, scores_list 169 | 170 | def forward_process_gold_edu(self, instance): 171 | #pre_edustage 172 | n_fragments = len(instance.fragments) 173 | #use flag instead of use torch Variable itself since the comparison limit, ex: can't write "if (torch Variable) == 0:" 174 | left_flag = False 175 | edus = [] 176 | edu_idx = 0 177 | left = None 178 | #from left to right 179 | for idx in range(n_fragments): 180 | edu_span = instance.edu_spans[edu_idx] 181 | # start a new edu or merge to previous edu 182 | if left_flag == False: 183 | left = instance.fragments[idx] 184 | left_span = instance.fragment_spans[idx] 185 | else: 186 | right = instance.fragments[idx] 187 | right_span = instance.fragment_spans[idx] 188 | left = self.reduce(left, right) 189 | merge_span = [left_span, right_span] 190 | left_span = [left_span[0], right_span[1]] 191 | #if reaching the edu boundary 192 | if left_span[-1] == edu_span[-1]: 193 | edus.append(left) 194 | if left_flag: 195 | i_relation = data_structure.I_Relation(merge_span,data_structure.SENSE_TO_LABEL[data_structure.EDU_SENSE] , data_structure.CENTER_TO_LABEL[data_structure.NON_CENTER], '') 196 | instance.i_relations.append(i_relation) 197 | left = None 198 | left_flag = False 199 | edu_idx += 1 200 | else: 201 | if left_flag: 202 | i_relation = data_structure.I_Relation(merge_span,data_structure.SENSE_TO_LABEL[data_structure.PRE_EDU_SENSE] , data_structure.CENTER_TO_LABEL[data_structure.NON_CENTER], '') 203 | instance.i_relations.append(i_relation) 204 | left_flag = True 205 | # set the edus, modify the instance segments to edu level 206 | instance.fragments = edus 207 | instance.fragment_spans = instance.edu_spans 208 | return instance 209 | 210 | def forward_sequence_predict_edu(self, instance): 211 | 212 | #pre_edustage 213 | n_fragments = len(instance.fragments) 214 | #use flag instead of use torch Variable itself since the comparison limit, ex: can't write "if (torch Variable) == 0:" 215 | left_flag = False 216 | right_flag = False 217 | edus = [] 218 | edus_span = [] 219 | #from left to right 220 | for i in range(n_fragments): 221 | #the code is written for left-to-right logic, when we need to process right to left, just modify the idx, other variables(inculding left, right) remain the same 222 | if left_seq_flag: 223 | idx = i 224 | else: 225 | idx = n_fragments-i-1 226 | punc = instance.puncs[idx] 227 | #merge or just set the left segment 228 | if not left_flag: 229 | left = instance.fragments[idx] 230 | left_span = instance.fragment_spans[idx] 231 | left_flag = True 232 | else: 233 | #merge, make sense 234 | right = instance.fragments[idx] 235 | right_span = 
instance.fragment_spans[idx] 236 | right_flag = True 237 | reduced = self.reduce(left, right) 238 | merge_span = [left_span[0], right_span[1]] 239 | label_space = self.sense_linear(reduced) 240 | label_score = F.log_softmax(label_space) 241 | sense_score = [ x.data[0] for x in label_score[0] ] 242 | merge_sense = sense_score.index( max(sense_score) ) 243 | 244 | #modify the instance information, no need to deal with the i_reltion 245 | if right_flag: 246 | if merge_sense == data_structure.SENSE_TO_LABEL[data_structure.EDU_SENSE] and punc not in PRE_EDU_PUNCS: 247 | edus.append(reduced) 248 | edus_span.append(merge_span) 249 | merge_sense = data_structure.SENSE_TO_LABEL[data_structure.EDU_SENSE] 250 | i_relation = data_structure.I_Relation([left_span, right_span], merge_sense, data_structure.CENTER_TO_LABEL[data_structure.NON_CENTER], '') 251 | instance.i_relations.append(i_relation) 252 | left_flag = False 253 | else: 254 | if merge_sense == data_structure.SENSE_TO_LABEL[data_structure.PRE_EDU_SENSE] or punc in PRE_EDU_PUNCS: 255 | merge_sense = data_structure.SENSE_TO_LABEL[data_structure.PRE_EDU_SENSE] 256 | i_relation = data_structure.I_Relation([left_span, right_span], merge_sense, data_structure.CENTER_TO_LABEL[data_structure.NON_CENTER], '') 257 | instance.i_relations.append(i_relation) 258 | left = reduced 259 | left_span = merge_span 260 | 261 | #condition that the sense is in du sense 262 | else: 263 | edus.append(left) 264 | edus_span.append(left_span) 265 | left = right 266 | left_span = right_span 267 | #left is already modified above 268 | if idx == n_fragments-1 or punc in EDU_PUNCS: 269 | edus.append(left) 270 | edus_span.append(left_span) 271 | left_flag = False 272 | 273 | else: 274 | if idx == n_fragments-1 or punc in EDU_PUNCS: 275 | edus.append(left) 276 | edus_span.append(left_span) 277 | left_flag = False 278 | 279 | right_flag = False 280 | 281 | # set the edus, modify the instance segments to edu level 282 | instance.fragments = edus 283 | instance.fragment_spans = edus_span 284 | return instance 285 | 286 | def forward_greedy_predict_structure(self, instance): 287 | i_relations = [] 288 | while len(instance.fragments) > 1: 289 | n_fragments = len(instance.fragments) 290 | scores = [] 291 | reduceds = [] 292 | for i in range(n_fragments-1): 293 | reduced = self.reduce(instance.fragments[i], instance.fragments[i+1]) 294 | reduceds.append(reduced) 295 | label_space = self.struct_linear(reduced) 296 | label_score = F.log_softmax(label_space) 297 | scores.append(label_score[0][1].data[0]) 298 | #for span 299 | idx = scores.index( max(scores) ) 300 | left_span = instance.fragment_spans[idx] 301 | right_span = instance.fragment_spans[idx+1] 302 | i_relation_span = [left_span, right_span] 303 | #print 'left_span: ', left_span, 'right_span: ', right_span, 'max score: ', max(scores) 304 | #for center 305 | label_space = self.center_linear(reduceds[idx]) 306 | label_score = F.log_softmax(label_space) 307 | center_score = [ x.data[0] for x in label_score[0] ] 308 | #exclude non du centers 309 | lowest = min(center_score) 310 | for i in range(len(center_score)): 311 | if i not in data_structure.DU_CENTER_LABEL: 312 | center_score[i] = lowest-1 313 | idx_center = center_score.index( max(center_score) ) 314 | 315 | 316 | 317 | #for sense 318 | label_space = self.sense_linear(reduceds[idx]) 319 | label_score = F.log_softmax(label_space) 320 | sense_score = [ x.data[0] for x in label_score[0] ] 321 | #exclude non du senses 322 | lowest = min(sense_score) 323 | for i in 
range(len(sense_score)): 324 | if i not in data_structure.DU_SENSE_LABEL: 325 | sense_score[i] = lowest-1 326 | idx_sense = sense_score.index( max(sense_score) ) 327 | 328 | #modify the instance information 329 | instance.fragments[idx] = reduceds[idx] 330 | instance.fragment_spans[idx] = [ i_relation_span[0][0], i_relation_span[-1][-1] ] 331 | del instance.fragments[idx+1] 332 | del instance.fragment_spans[idx+1] 333 | #make a new i_relation 334 | i_relation = data_structure.I_Relation(i_relation_span, idx_sense, idx_center, '') 335 | i_relations.append(i_relation) 336 | instance.i_relations = i_relations 337 | return instance 338 | 339 | def forward_cky_predict(self, instance, word_level_flag=True, du_level_flag=False): 340 | 341 | n_fragments = len(instance.fragments) 342 | #the first dimension corresponds to the start index 343 | #the second dimension corresponds to the range(from 0 to n_fragments-1) 344 | cky_table = [None for start_idx in range(n_fragments)] 345 | for start_idx in range(n_fragments): 346 | cky_table[start_idx] = [ None for cky_range in range(n_fragments)] 347 | 348 | #cky table initial condition 349 | for start_idx in range(n_fragments): 350 | #make a initial candidate 351 | #the cky_span has only one element and describe itself, it's only for the initialized case 352 | cky_span_infos = [ CKY_Span_Info(start_idx, 0, 0) ] 353 | 354 | #choose the basic initial sense, may be used if cky algorithm refer the sense of a possible child 355 | ''' 356 | if args.args.cky_predict_structure: 357 | initial_sense = data_structure.SENSE_TO_LABEL[EDU_SENSE] 358 | elif args.args.cky_predict_edu_and_structure: 359 | initial_sense = data_structure.SENSE_TO_LABEL[PRE_EDU_SENSE] 360 | ''' 361 | initial_sense = data_structure.SENSE_TO_LABEL[data_structure.PRE_EDU_SENSE] 362 | 363 | #view(1,-1) for Reduce function 364 | cky_candidate_list = [ CKY_Candidate( cky_span_infos, instance.fragments[start_idx].view(1,-1), 0, initial_sense, data_structure.CENTER_TO_LABEL[data_structure.NON_CENTER]) ] 365 | cky_table[start_idx][0] = CKY_Unit(cky_candidate_list) 366 | 367 | #cky algorithm 368 | for cky_range in range(1,n_fragments): 369 | for start_idx in range(0, n_fragments-cky_range): 370 | cky_table[start_idx][cky_range] = CKY_Unit([]) 371 | #each element is [merge_score, cky_candidate] tuple 372 | cky_candidate_score_list = [] 373 | end_idx = start_idx+cky_range 374 | #find candidates, middle_index is the end idx of left unit 375 | #start_idx+cky_range since at least 1 left for the right unit 376 | for middle_idx in range(start_idx, start_idx+cky_range): 377 | cky_unit_left = cky_table[start_idx][middle_idx-start_idx] 378 | cky_unit_right = cky_table[middle_idx+1][end_idx-middle_idx-1] 379 | for left_idx in range(len(cky_unit_left.cky_candidate_list)): 380 | for right_idx in range(len(cky_unit_right.cky_candidate_list)): 381 | 382 | left_candidate = cky_unit_left.cky_candidate_list[left_idx] 383 | right_candidate = cky_unit_right.cky_candidate_list[right_idx] 384 | left = left_candidate.representation 385 | right = right_candidate.representation 386 | reduced = self.reduce(left, right) 387 | label_space = self.struct_linear(reduced) 388 | label_score = F.log_softmax(label_space) 389 | #print 'label_score', label_score 390 | merge_score = label_score[0][1].data[0] 391 | #accumulate the probility scores of left and right 392 | struct_score = merge_score+left_candidate.score+right_candidate.score 393 | 394 | #make the cky_span_info of the now left and right candidates 395 | cky_span_infos = [ 
CKY_Span_Info(start_idx, middle_idx-start_idx, left_idx), CKY_Span_Info(middle_idx+1, end_idx-middle_idx-1, right_idx) ] 396 | #sense and center is temporarily initialized 397 | cky_candidate = CKY_Candidate(cky_span_infos, reduced, struct_score, data_structure.SENSE_TO_LABEL[data_structure.PSEUDO_SENSE], data_structure.CENTER_TO_LABEL[data_structure.NON_CENTER]) 398 | cky_candidate_score_list.append([merge_score, cky_candidate]) 399 | 400 | #sort according to the merge score 401 | sorted_cky_candidate_score_list = sorted(cky_candidate_score_list, key=lambda x: x[1].score, reverse=True) 402 | 403 | candidate_count = 0 404 | for cky_candidate_score in sorted_cky_candidate_score_list: 405 | cky_candidate = cky_candidate_score[1] 406 | infos = cky_candidate.cky_span_infos 407 | if word_level_flag: 408 | cky_candidate.sense = data_structure.SENSE_TO_LABEL[data_structure.PSEUDO_SENSE] 409 | cky_candidate.center = data_structure.CENTER_TO_LABEL[data_structure.NON_CENTER] 410 | else: 411 | #for center 412 | label_space = self.center_linear(cky_candidate.representation) 413 | label_score = F.log_softmax(label_space) 414 | center_score = [ x.data[0] for x in label_score[0] ] 415 | #exclude edu centers 416 | lowest = min(center_score) 417 | for idx in range(len(center_score)): 418 | if idx not in data_structure.DU_CENTER_LABEL: 419 | center_score[idx] = lowest-1 420 | idx_center = center_score.index( max(center_score) ) 421 | 422 | #for sense 423 | label_space = self.sense_linear(cky_candidate.representation) 424 | label_score = F.log_softmax(label_space) 425 | sense_score = [ x.data[0] for x in label_score[0] ] 426 | 427 | #for some condition, exclude pre edu senses 428 | ''' 429 | if args.args.cky_predict_structure: 430 | sense_score = sense_score[0:4] 431 | elif args.args.cky_predict_edu_and_structure: 432 | ''' 433 | cky_unit_left = cky_table[infos[0].start_idx][infos[0].cky_range] 434 | cky_unit_right = cky_table[infos[1].start_idx][infos[1].cky_range] 435 | left_candidate = cky_unit_left.cky_candidate_list[infos[0].candidate_idx] 436 | right_candidate = cky_unit_right.cky_candidate_list[infos[1].candidate_idx] 437 | left_sense = left_candidate.sense 438 | right_sense = right_candidate.sense 439 | # if one of left and right is not pre edu, exclude pre edu senses 440 | # or du_level_flag is true 441 | middle_punc = instance.puncs[infos[1].start_idx-1] 442 | right_punc = instance.puncs[infos[1].start_idx+infos[1].cky_range] 443 | if (left_sense in data_structure.DU_SENSE_LABEL or right_sense in data_structure.DU_SENSE_LABEL ) or middle_punc in EDU_PUNCS or du_level_flag: 444 | lowest = min(sense_score) 445 | for idx in range(len(sense_score)): 446 | if idx not in data_structure.DU_SENSE_LABEL or\ 447 | idx == data_structure.SENSE_TO_LABEL[data_structure.EDU_SENSE]: 448 | sense_score[idx] = lowest-1 449 | idx_sense = sense_score.index( max(sense_score) ) 450 | 451 | #force to PRE_EDU_SENSE if the end punc is of certain type 452 | elif right_punc in PRE_EDU_PUNCS: 453 | idx_sense = data_structure.SENSE_TO_LABEL[data_structure.PRE_EDU_SENSE] 454 | else: 455 | idx_sense = sense_score.index( max(sense_score) ) 456 | 457 | if idx_sense == data_structure.SENSE_TO_LABEL[data_structure.EDU_SENSE]: 458 | idx_center = data_structure.CENTER_TO_LABEL[data_structure.NON_CENTER] 459 | 460 | cky_candidate.sense = idx_sense 461 | cky_candidate.center = idx_center 462 | ''' 463 | print 'start_idx:', start_idx, 'cky_range: ', cky_range 464 | for info in infos: 465 | print '\t info_start_idx:', info.start_idx, 
'info_cky_range: ', info.cky_range, 'sense: ', cky_table[info.start_idx][info.cky_range].cky_candidate_list[info.candidate_idx] .sense 466 | print 'struct_score: ', cky_candidate_score[1].score, 'merge_score:', cky_candidate_score[0], 'sense: ', cky_candidate.sense 467 | ''' 468 | if candidate_count == N_CKY_CANDIDATES: 469 | #break 470 | continue 471 | candidate_count += 1 472 | cky_table[start_idx][cky_range].cky_candidate_list.append(cky_candidate) 473 | 474 | #report.reporter.output_cky_table(cky_table) 475 | #the final representation replace the original instance.fragments 476 | instance.fragments = [cky_table[0][-1].cky_candidate_list[0].representation] 477 | instance = self.reconstruct_i_relations_from_cky_table(instance, cky_table, cky_table[0][n_fragments-1].cky_candidate_list[0]) 478 | return instance 479 | 480 | def forward(self, instance): 481 | #list of list 482 | fragments = [] 483 | w_i_relations_list = [] 484 | 485 | if lstm_flag: 486 | for span in instance.segment_spans: 487 | fragment = instance.fragments[span[0]:span[-1]+1] 488 | #print fragment 489 | fragment = self.word_embeddings(fragment) 490 | #print fragment 491 | fragment = self.lstm(fragment.view(span[-1]-span[0]+1, 1, -1), self.lstm_hidden)[0] 492 | fragments.append(fragment[-1].view(1, -1)) 493 | print instance.id 494 | if seg_lstm_flag: 495 | self.seg_hidden = self.init_hidden(self.d_seg_lstm_hidden, bidirectional=True) 496 | seg_embeds = torch.stack(fragments) 497 | fragments = list(self.seg_lstm(seg_embeds, self.seg_hidden)[0]) 498 | else: 499 | #we construct each segment parsing tree first 500 | for seg_span in instance.segment_spans: 501 | seg_instance = data_structure.Instance() 502 | seg_instance.fragments = instance.fragments[seg_span[0]:seg_span[-1]+1] 503 | seg_instance.fragment_spans = [[i,i] for i in range(seg_span[0], seg_span[-1]+1)] 504 | #convert segments to list of word embeddings 505 | seg_instance.fragments = self.word_embeddings(seg_instance.fragments) 506 | seg_instance = self.forward_cky_predict(seg_instance, word_level_flag=True) 507 | #get the word level i_relations of the segment 508 | w_i_relations_list.append(seg_instance.i_relations) 509 | #get the final representation of the segment 510 | fragments.append(seg_instance.fragments[0]) 511 | 512 | instance.w_i_relations_list = w_i_relations_list 513 | #construct the discourse parsing tree based on segments 514 | instance.fragments = fragments 515 | instance.fragment_spans = instance.segment_spans 516 | #empty the i_relations before we fill it with the predicted result 517 | instance.i_relations = [] 518 | #data_structure.print_test_instance(instance) 519 | du_level_flag = False 520 | if gold_edu_flag: 521 | instance = self.forward_process_gold_edu(instance) 522 | du_level_flag = True 523 | elif left_seq_flag: 524 | #will append the resulting i_relations in instance.i_relations 525 | instance = self.forward_sequence_predict_edu(instance) 526 | du_level_flag = True 527 | #will append the resulting i_relations in instance.i_relations 528 | #data_structure.print_test_instance(instance) 529 | if greedy_flag: 530 | instance = self.forward_greedy_predict_structure(instance) 531 | else: 532 | instance = self.forward_cky_predict(instance, word_level_flag=False, du_level_flag=du_level_flag) 533 | #data_structure.print_test_instance(instance) 534 | #empty the relation before predicting 535 | instance.du_i_relations = [] 536 | instance.seg_i_relations = [] 537 | for i_relation in instance.i_relations: 538 | if i_relation.sense in 
data_structure.DU_SENSE_LABEL and\ 539 | i_relation.sense != data_structure.SENSE_TO_LABEL[data_structure.EDU_SENSE]: 540 | instance.du_i_relations.append(i_relation) 541 | else: 542 | instance.seg_i_relations.append(i_relation) 543 | 544 | return instance 545 | 546 | def tree_lstm(c1, c2, lstm_in): 547 | a, i, f1, f2, o = lstm_in.chunk(5, 1) 548 | c = a.tanh() * i.sigmoid() + f1.sigmoid() * c1 + f2.sigmoid() * c2 549 | h = o.sigmoid() * c.tanh() 550 | return h, c 551 | 552 | class Reduce(nn.Module): 553 | def __init__(self, size): 554 | super(Reduce, self).__init__() 555 | self.left = nn.Linear(size, 5 * size) 556 | self.right = nn.Linear(size, 5 * size, bias=False) 557 | 558 | def forward(self, left_in, right_in): 559 | left = torch.chunk(left_in, 2, 1) 560 | right = torch.chunk(right_in, 2, 1) 561 | lstm_in = self.left(left[0]) + self.right(right[0]) 562 | lstm_out = tree_lstm(left[1], right[1], lstm_in) 563 | return torch.cat(lstm_out, 1) 564 | 565 | class BiLSTM_CRF(nn.Module): 566 | def __init__(self, vocab_size, tagset_size, embedding_dim, hidden_dim): 567 | super(BiLSTM_CRF, self).__init__() 568 | self.embedding_dim = embedding_dim 569 | self.hidden_dim = hidden_dim 570 | self.vocab_size = vocab_size 571 | #+2 for start and stop 572 | self.tagset_size = tagset_size+2 573 | 574 | self.word_embeds = nn.Embedding(vocab_size, embedding_dim) 575 | self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, 576 | num_layers=1, bidirectional=True) 577 | 578 | # Maps the output of the LSTM into tag space. 579 | self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size) 580 | 581 | # Matrix of transition parameters. Entry i,j is the score of 582 | # transitioning *to* i *from* j. 583 | self.transitions = nn.Parameter( 584 | torch.randn(self.tagset_size, self.tagset_size)) 585 | 586 | self.START_TAG_LABEL = tagset_size 587 | self.STOP_TAG_LABEL = tagset_size+1 588 | 589 | # These two statements enforce the constraint that we never transfer 590 | # to the start tag and we never transfer from the stop tag 591 | self.transitions.data[self.START_TAG_LABEL, :] = -10000 592 | self.transitions.data[:, self.STOP_TAG_LABEL] = -10000 593 | 594 | self.hidden = self.init_hidden() 595 | 596 | 597 | 598 | def init_hidden(self): 599 | #return (torch.randn(2, 1, self.hidden_dim // 2), 600 | # torch.randn(2, 1, self.hidden_dim // 2)) 601 | if GPU: 602 | return (autograd.Variable(torch.zeros(2, 1, self.hidden_dim // 2)).cuda(), 603 | autograd.Variable(torch.zeros(2, 1, self.hidden_dim // 2)).cuda()) 604 | else: 605 | return (autograd.Variable(torch.zeros(2, 1, self.hidden_dim // 2)), 606 | autograd.Variable(torch.zeros(2, 1, self.hidden_dim // 2))) 607 | 608 | 609 | def _forward_alg(self, feats): 610 | # Do the forward algorithm to compute the partition function 611 | #init_alphas = [-10000.]*self.tagset_size 612 | #init_alphas = torch.Tensor(init_alphas).view(1, -1) 613 | init_alphas = torch.Tensor(1, self.tagset_size).fill_(-10000.) 614 | 615 | # START_TAG has all of the score. 616 | init_alphas[0][self.START_TAG_LABEL] = 0. 
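# From here the loop below implements the standard forward recursion in log space,
# alpha_t(j) = emit_t(j) + log_sum_exp_i( alpha_{t-1}(i) + trans(j, i) ),
# so the returned `alpha` is log Z(x), the log partition over all tag sequences.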
617 | 
618 | 
619 |         # Wrap in a variable so that we will get automatic backprop
620 |         #forward_var = init_alphas
621 |         if GPU:
622 |             forward_var = autograd.Variable(init_alphas).cuda()
623 |         else:
624 |             forward_var = autograd.Variable(init_alphas)
625 | 
626 |         # Iterate through the sentence
627 |         for feat in feats:
628 |             alphas_t = []  # The forward tensors at this timestep
629 |             for next_tag in range(self.tagset_size):
630 |                 # broadcast the emission score: it is the same regardless of
631 |                 # the previous tag
632 |                 emit_score = feat[next_tag].view(1, -1).expand(1, self.tagset_size)
633 |                 # the ith entry of trans_score is the score of transitioning to
634 |                 # next_tag from i
635 |                 trans_score = self.transitions[next_tag].view(1, -1)
636 |                 # The ith entry of next_tag_var is the value for the
637 |                 # edge (i -> next_tag) before we do log-sum-exp
638 | 
639 |                 next_tag_var = forward_var + trans_score + emit_score
640 |                 # The forward variable for this tag is log-sum-exp of all the
641 |                 # scores.
642 |                 alphas_t.append(log_sum_exp(next_tag_var).view(1))
643 |             forward_var = torch.cat(alphas_t).view(1, -1)
644 |         terminal_var = forward_var + self.transitions[self.STOP_TAG_LABEL]
645 |         alpha = log_sum_exp(terminal_var)
646 | 
647 |         return alpha
648 | 
649 |     def _get_lstm_features(self, sentence):
650 |         self.hidden = self.init_hidden()
651 |         embeds = self.word_embeds(sentence).view(len(sentence), 1, -1)
652 |         lstm_out, self.hidden = self.lstm(embeds, self.hidden)
653 |         lstm_out = lstm_out.view(len(sentence), self.hidden_dim)
654 |         lstm_feats = self.hidden2tag(lstm_out)
655 |         return lstm_feats
656 | 
657 |     def _score_sentence(self, feats, tags):
658 |         # Gives the score of a provided tag sequence
659 |         #score = torch.zeros(1)
660 |         #tags = torch.cat([torch.tensor([self.START_TAG_LABEL], dtype=torch.long), tags])
661 | 
662 |         if GPU:
663 |             score = autograd.Variable(torch.Tensor([0])).cuda()
664 |             tags = torch.cat([torch.LongTensor([self.START_TAG_LABEL]).cuda(), tags])
665 |         else:
666 |             score = autograd.Variable(torch.Tensor([0]))
667 |             tags = torch.cat([torch.LongTensor([self.START_TAG_LABEL]), tags])
668 | 
669 |         #tags = torch.cat([torch.LongTensor([self.START_TAG_LABEL]), tags])
670 | 
671 |         for i, feat in enumerate(feats):
672 |             score = score + \
673 |                 self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]]
674 |         score = score + self.transitions[self.STOP_TAG_LABEL, tags[-1]]
675 |         return score
676 | 
677 |     def _viterbi_decode(self, feats):
678 |         backpointers = []
679 | 
680 |         # Initialize the viterbi variables in log space
681 |         #init_vvars = [-10000.]*self.tagset_size
682 |         #init_vvars = torch.Tensor(init_vvars).view(1, -1)
683 |         #init_vvars = torch.full((1, self.tagset_size), -10000.)
684 |         init_vvars = torch.Tensor(1, self.tagset_size).fill_(-10000.)
685 |         init_vvars[0][self.START_TAG_LABEL] = 0
686 | 
687 |         # forward_var at step i holds the viterbi variables for step i-1
688 |         if GPU:
689 |             forward_var = autograd.Variable(init_vvars).cuda()
690 |         else:
691 |             forward_var = autograd.Variable(init_vvars)
692 | 
693 | 
694 |         for feat in feats:
695 |             bptrs_t = []  # holds the backpointers for this step
696 |             viterbivars_t = []  # holds the viterbi variables for this step
697 | 
698 |             for next_tag in range(self.tagset_size):
699 |                 # next_tag_var[i] holds the viterbi variable for tag i at the
700 |                 # previous step, plus the score of transitioning
701 |                 # from tag i to next_tag.
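                # i.e. next_tag_var[0][i] = forward_var[0][i] + self.transitions[next_tag][i]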
702 |                 # We don't include the emission scores here because the max
703 |                 # does not depend on them (we add them in below)
704 |                 next_tag_var = forward_var + self.transitions[next_tag]
705 |                 best_tag_id = argmax(next_tag_var)
706 |                 bptrs_t.append(best_tag_id)
707 |                 viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
708 |             # Now add in the emission scores, and assign forward_var to the set
709 |             # of viterbi variables we just computed
710 |             forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1)
711 |             backpointers.append(bptrs_t)
712 | 
713 |         # Transition to STOP_TAG
714 |         terminal_var = forward_var + self.transitions[self.STOP_TAG_LABEL]
715 |         best_tag_id = argmax(terminal_var)
716 |         path_score = terminal_var[0][best_tag_id]
717 | 
718 |         # Follow the back pointers to decode the best path.
719 |         best_path = [best_tag_id]
720 |         for bptrs_t in reversed(backpointers):
721 |             best_tag_id = bptrs_t[best_tag_id]
722 |             best_path.append(best_tag_id)
723 |         # Pop off the start tag (we don't want to return that to the caller)
724 |         start = best_path.pop()
725 |         assert start == self.START_TAG_LABEL  # Sanity check
726 |         best_path.reverse()
727 |         return path_score, best_path
728 | 
729 |     def neg_log_likelihood(self, sentence, tags):
730 |         feats = self._get_lstm_features(sentence)
731 |         forward_score = self._forward_alg(feats)
732 |         gold_score = self._score_sentence(feats, tags)
733 |         return forward_score - gold_score
734 |         #return forward_score
735 | 
736 |     def forward(self, sentence):  # don't confuse this with _forward_alg above.
737 |         # Get the emission scores from the BiLSTM
738 |         lstm_feats = self._get_lstm_features(sentence)
739 | 
740 |         # Find the best path, given the features.
741 |         score, tag_seq = self._viterbi_decode(lstm_feats)
742 |         return score, tag_seq
743 | 
744 | def argmax(vec):
745 |     # return the argmax as a python int
746 |     _, idx = torch.max(vec, 1)
747 |     #return idx.item()
748 |     return to_scalar(idx)
749 | 
750 | # Compute log sum exp in a numerically stable way for the forward algorithm
751 | def log_sum_exp(vec):
752 |     max_score = vec[0, argmax(vec)]
753 |     max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
754 |     return max_score + \
755 |         torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))
756 | 
757 | def to_scalar(var):  #var is a Variable holding a single value
758 |     # returns that value as a Python scalar
759 |     return var.view(-1).data.tolist()[0]
--------------------------------------------------------------------------------
/word_to_ix:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abccaba2000/discourse-parser/a7414e66e48621b3a1ae9bedd17b07a8be1487ce/word_to_ix
--------------------------------------------------------------------------------
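A minimal, hypothetical usage sketch of the BiLSTM_CRF defined in rvnn.py, for reference only: the 2-tag scheme, the toy training_data, and every dimension below are illustrative assumptions, and only neg_log_likelihood and forward (Viterbi decoding) come from the class itself. Depending on the PyTorch version, inputs may first need to be wrapped in autograd.Variable.

    import torch
    import torch.optim as optim
    from rvnn import BiLSTM_CRF  # assumes the class above is importable from rvnn.py

    # hypothetical toy data: (word-index LongTensor, tag LongTensor) pairs,
    # with tag 1 marking a boundary and tag 0 everything else
    training_data = [(torch.LongTensor([3, 17, 5]), torch.LongTensor([0, 0, 1]))]

    model = BiLSTM_CRF(vocab_size=100, tagset_size=2, embedding_dim=50, hidden_dim=64)
    optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-4)

    for sentence, tags in training_data:
        model.zero_grad()
        loss = model.neg_log_likelihood(sentence, tags)  # log Z minus gold-path score
        loss.backward()
        optimizer.step()

    score, tag_seq = model(training_data[0][0])  # best tag sequence via Viterbi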