├── .gitignore ├── README.md ├── arc_eager_transition_system.py ├── arc_standard_transition_system.py ├── conll_utils.py ├── corpus ├── dsindex_postagfix.py ├── en-ud-dev.conllu ├── en-ud-test.conllu └── en-ud-train.conllu ├── decoded_parse_reader.py ├── dep_parser.py ├── feature_extractor.py ├── feature_map.py ├── gold_parse_reader.py ├── lexicon.py ├── model_parameters.py ├── other └── transition_system_test_framework.py ├── parser-config.sample ├── parser_state.py ├── projectivize_filter.py ├── sentence_batch.py ├── training_test.sh ├── utils.py └── well_formed_filter.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__ 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ash-parser 2 | 3 | This was originally for a class project. 4 | 5 | Implements a [Chen and Manning (2014)](http://cs.stanford.edu/people/danqi/papers/emnlp2014.pdf)-style neural network parser in Python and TensorFlow. Many elements mimic [SyntaxNet](https://github.com/tensorflow/models/tree/master/syntaxnet). 6 | 7 | I analyze [SyntaxNet's Architecture](http://andrewmatteson.name/index.php/2017/02/04/inside-syntaxnet/) here. 8 | 9 | A parsing-config file must be created in the model directory before execution (see parser-config.sample for an example). 10 | 11 | Run training_test.sh for an example of how to train a model. Evaluation during training works as well, but there is no API yet for tagging new input or serving a model. 12 | 13 | External dependencies: 14 | - NumPy 15 | - TensorFlow 1.0 16 | 17 | Similarities to SyntaxNet: 18 | - Same embedding system (configurable per-feature-group deep embedding) 19 | - Same optimizer (Momentum with exponential moving average) 20 | - Lexicon builder is identical for words, tags, and labels 21 | - Map files output by SyntaxNet and AshParser should be identical 22 | - Evaluation metric is identical (SyntaxNet's corresponds to AshParser's UAS) 23 | - Feature system is almost identical (except perhaps some very rare corner cases) 24 | - Due to the same architecture, accuracy should be very close to greedy SyntaxNet 25 | 26 | Differences from SyntaxNet: 27 | - Arc-Eager transition system also supported 28 | - No context file with redundant or boilerplate information is needed 29 | - GPU support: the training phase can complete in minutes 30 | - Pure Python 3 implementation; no need for Bazel 31 | - LAS (Labeled Attachment Score) is printed during evaluation 32 | - Precalculation and caching of feature bags, which makes it easier to train multiple models with the same token features but different hyperparameters 33 | - No support for structured (beam) parsing; an LSTM or something simpler and faster is being considered for the future. The resulting accuracy loss should be in the ballpark of 1-2%. 
34 | - Feature groups are automatically created by groups of tag, word, and label rather than by grouping together with semicolon in a context file 35 | - Only support for the transition parser, not the POS tagger, morphological analyzer, or tokenizer 36 | - ngrams, punctuation_amount, morph tags and other features not yet implemented 37 | -------------------------------------------------------------------------------- /arc_eager_transition_system.py: -------------------------------------------------------------------------------- 1 | from conll_utils import ParsedConllSentence, ParsedConllToken 2 | from parser_state import ParserState 3 | 4 | class ArcEagerTransitionState(object): 5 | def __init__(self, state): 6 | state.push(-1) # ROOT node 7 | 8 | def toString(self, state): 9 | s = '[' 10 | i = state.stackSize() - 1 11 | while i >= 0: 12 | word = state.getToken(state.stack(i)).FORM 13 | if i != state.stackSize() - 1: 14 | s += ' ' 15 | 16 | # only for internal ROOT token at start of stack 17 | if word == None: 18 | s += 'ROOT' 19 | else: 20 | s += word 21 | 22 | i -= 1 23 | s += ']' 24 | 25 | i = state.next() 26 | while i < state.numTokens(): 27 | s += ' ' + state.getToken(i).FORM 28 | i += 1 29 | return s 30 | 31 | class ArcEagerTransitionSystem(object): 32 | SHIFT = 0 33 | REDUCE = 1 34 | LEFT_ARC = 2 35 | RIGHT_ARC = 3 36 | 37 | Transitions = [SHIFT, REDUCE, LEFT_ARC, RIGHT_ARC] 38 | 39 | def __init__(self): 40 | pass 41 | 42 | def shiftAction(self): 43 | return ArcEagerTransitionSystem.SHIFT 44 | 45 | def reduceAction(self): 46 | return ArcEagerTransitionSystem.REDUCE 47 | 48 | def leftArcAction(self, label): 49 | return 2 + (label << 1) 50 | 51 | def rightArcAction(self, label): 52 | return 2 + ((label << 1) | 1) 53 | 54 | def label(self, action): 55 | if action < 2: 56 | return -1 57 | else: 58 | return (action - 2) >> 1 59 | 60 | def actionType(self, action): 61 | if action < 2: 62 | return action 63 | else: 64 | return 2 + (~(action-1) & 1) 65 | 66 | def numActionTypes(self): 67 | return 4 68 | 69 | def numActions(self, numLabels): 70 | return 2 + 2 * numLabels 71 | 72 | def getDefaultAction(self, state): 73 | if(not state.endOfInput()): 74 | return self.shiftAction() 75 | 76 | return self.rightArcAction(2) 77 | 78 | def getDepRelation(self, idx_parent, idx_child, state): 79 | if idx_child == -1: 80 | return None # root word 81 | assert idx_child >= 0 82 | 83 | if state.goldHead(idx_child) == idx_parent: 84 | # fixme: if label is -1 then?? 
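# Illustration of the integer action encoding defined above (derived from
# leftArcAction/rightArcAction/label/actionType): SHIFT and REDUCE keep their
# raw codes 0 and 1, while labeled arcs are packed as 2 + 2*label (LEFT_ARC)
# and 3 + 2*label (RIGHT_ARC). For example, with label index 1:
#   leftArcAction(1)  == 4  ->  actionType(4) == LEFT_ARC,  label(4) == 1
#   rightArcAction(1) == 5  ->  actionType(5) == RIGHT_ARC, label(5) == 1
# and a 40-label set yields numActions(40) == 2 + 2*40 == 82 distinct actions.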
85 | return state.goldLabel(idx_child) 86 | else: 87 | return None 88 | 89 | def getNextGoldAction(self, state): 90 | b0 = state.input(0) 91 | 92 | bInput = -2 93 | bInc = 0 94 | bItems = [] 95 | while bInput != -2: 96 | bInput = state.input(bInc) 97 | if bInput == -2: 98 | break 99 | bItems.append(bInput) 100 | bInc += 1 101 | 102 | #print('B:', bItems) 103 | 104 | if state.stackSize() > 0: 105 | #print('S:', state.stack_) 106 | 107 | s0 = state.stack(0) 108 | rel = self.getDepRelation(b0, s0, state) 109 | if rel is not None: 110 | #print('return L-A action', self.leftArcAction(rel)) 111 | return self.leftArcAction(rel) 112 | 113 | rel = self.getDepRelation(s0, b0, state) 114 | if rel is not None: 115 | #print('return R-A action', self.rightArcAction(rel)) 116 | return self.rightArcAction(rel) 117 | 118 | flag = False 119 | for k in range(-1, s0): # s0 goes as low as -1 unlike NLTK 120 | if self.getDepRelation(k, b0, state) is not None: 121 | flag = True 122 | if self.getDepRelation(b0, k, state) is not None: 123 | flag = True 124 | 125 | if flag: 126 | #print('return R action', self.reduceAction()) 127 | return self.reduceAction() 128 | 129 | # S|i 130 | # nothing else we can do except shift to the next required arc 131 | # (S, i|B) => (S|i, B) 132 | #print('return S action', self.shiftAction()) 133 | return self.shiftAction() 134 | 135 | def doneChildrenRightOf(self, state, head): 136 | index = state.next() 137 | num_tokens = state.numTokens() 138 | 139 | while(index < num_tokens): 140 | actual_head = state.goldHead(index) 141 | if(actual_head == head): 142 | return False 143 | 144 | if(actual_head > index): 145 | index = actual_head 146 | else: 147 | index += 1 148 | 149 | return True 150 | 151 | def isAllowedAction(self, action, state): 152 | if(self.actionType(action) == ArcEagerTransitionSystem.SHIFT): 153 | return self.isAllowedShift(state) 154 | elif(self.actionType(action) == ArcEagerTransitionSystem.REDUCE): 155 | return self.isAllowedReduce(state) 156 | elif(self.actionType(action) == ArcEagerTransitionSystem.LEFT_ARC): 157 | return self.isAllowedLeftArc(state) 158 | elif(self.actionType(action) == ArcEagerTransitionSystem.RIGHT_ARC): 159 | return self.isAllowedRightArc(state) 160 | else: 161 | assert None 162 | 163 | def isAllowedShift(self, state): 164 | return (not state.endOfInput()) 165 | 166 | def isAllowedReduce(self, state): 167 | if state.stackSize() == 0: 168 | return False 169 | 170 | idx_wi = state.stack(0) 171 | flag = False 172 | for (idx_parent, r, idx_child) in state.arcs_: 173 | if idx_child == idx_wi: 174 | flag = True 175 | if not flag: 176 | return False 177 | 178 | return True 179 | 180 | def isAllowedLeftArc(self, state): 181 | if state.endOfInput() or state.stackSize() == 0: 182 | return False 183 | 184 | # this is the root element 185 | if state.input(0) == -1: 186 | return False 187 | 188 | # in nltk code, 0 if root node 189 | # here, -1 if root node 190 | idx_wi = state.stack(0) 191 | flag = True 192 | 193 | # but the problem is, root node is not in head_ 194 | # or label_ 195 | # (they start from node 0, not -1) 196 | # but head and label are initialized to -1 always, 197 | # so it's confusing. 
198 | # store arcs_ separately based on transitions made 199 | # to state 200 | for (idx_parent, r, idx_child) in state.arcs_: 201 | if idx_child == idx_wi: 202 | flag = False 203 | if not flag: 204 | return False 205 | 206 | return True 207 | 208 | def isAllowedRightArc(self, state): 209 | if state.endOfInput() or state.stackSize() == 0: 210 | return False 211 | 212 | return True 213 | 214 | def performAction(self, action, state): 215 | self.performActionWithoutHistory(action, state) 216 | 217 | def performActionWithoutHistory(self, action, state): 218 | if self.actionType(action) == ArcEagerTransitionSystem.SHIFT: 219 | self.performShift(state) 220 | elif self.actionType(action) == ArcEagerTransitionSystem.REDUCE: 221 | self.performReduce(state) 222 | elif self.actionType(action) == ArcEagerTransitionSystem.LEFT_ARC: 223 | self.performLeftArc(state, self.label(action)) 224 | elif self.actionType(action) == ArcEagerTransitionSystem.RIGHT_ARC: 225 | self.performRightArc(state, self.label(action)) 226 | else: 227 | assert(None) 228 | 229 | def performShift(self, state): 230 | assert self.isAllowedShift(state) 231 | state.push(state.next()) 232 | state.advance() 233 | 234 | def performReduce(self, state): 235 | assert self.isAllowedReduce(state) 236 | state.pop() 237 | 238 | # Arc Eager 239 | # (S|i, j|B) => (S, j|B) 240 | # add left arc j->i: 241 | def performLeftArc(self, state, label): 242 | assert self.isAllowedLeftArc(state) 243 | s_j = state.next() 244 | s_i = state.pop() 245 | state.addArc(s_i, s_j, label) 246 | 247 | # Arc Eager 248 | # (S|i, j|B) => (S|i|j, B) 249 | # add right arc i->j: 250 | def performRightArc(self, state, label): 251 | assert self.isAllowedRightArc(state) 252 | s_j = state.next() 253 | s_i = state.stack(0) 254 | state.addArc(s_j, s_i, label) 255 | 256 | # next token 257 | # add to S 258 | state.push(s_j) 259 | # signify that we've pushed the next token onto the stack 260 | # remove from B 261 | state.advance() 262 | 263 | def isFinalState(self, state): 264 | return state.endOfInput() 265 | 266 | def actionAsTuple(self, action): 267 | if(self.actionType(action) == ArcEagerTransitionSystem.SHIFT): 268 | return (ArcEagerTransitionSystem.SHIFT,) 269 | if(self.actionType(action) == ArcEagerTransitionSystem.REDUCE): 270 | return (ArcEagerTransitionSystem.REDUCE,) 271 | elif(self.actionType(action) == ArcEagerTransitionSystem.LEFT_ARC): 272 | return (ArcEagerTransitionSystem.LEFT_ARC, self.label(action)) 273 | elif(self.actionType(action) == ArcEagerTransitionSystem.RIGHT_ARC): 274 | return (ArcEagerTransitionSystem.RIGHT_ARC, self.label(action)) 275 | else: 276 | return None 277 | 278 | def actionAsString(self, action, state, feature_maps): 279 | if(self.actionType(action) == ArcEagerTransitionSystem.SHIFT): 280 | return 'SHIFT' 281 | if(self.actionType(action) == ArcEagerTransitionSystem.REDUCE): 282 | return 'REDUCE' 283 | elif(self.actionType(action) == ArcEagerTransitionSystem.LEFT_ARC): 284 | return 'LEFT_ARC(' + \ 285 | feature_maps['label'].indexToValue(self.label(action)) + ')' 286 | elif(self.actionType(action) == ArcEagerTransitionSystem.RIGHT_ARC): 287 | return 'RIGHT_ARC(' + \ 288 | feature_maps['label'].indexToValue(self.label(action)) + ')' 289 | else: 290 | return 'UNKNOWN' 291 | 292 | ''' 293 | Dynamic Oracle 294 | ''' 295 | @staticmethod 296 | def legal(state): 297 | transitions = ArcEagerTransitionSystem.Transitions 298 | shift_ok = True 299 | right_ok = True 300 | left_ok = True 301 | reduce_ok = True 302 | state_buffer = list(range(state.next_, 
state.numTokens())) 303 | 304 | if len(state_buffer) == 1: 305 | right_ok = shift_ok = False 306 | 307 | if state.stackSize() == 0: 308 | left_ok = right_ok = reduce_ok = False 309 | else: 310 | s = state.stack(0) 311 | 312 | # if the s is already a dependent, we cannot left-arc 313 | # arcs_ storage: (parent, rel, child) 314 | if len(list(filter(lambda A: s == A[2], state.arcs_))) > 0: 315 | left_ok = False 316 | else: 317 | reduce_ok = False 318 | 319 | ok = [shift_ok, right_ok, left_ok, reduce_ok] 320 | 321 | legal_transitions = [] 322 | for it in range(len(transitions)): 323 | if ok[it] is True: 324 | legal_transitions.append(it) 325 | 326 | return legal_transitions 327 | 328 | @staticmethod 329 | def dynamicOracle(state, legal_transitions): 330 | options = [] 331 | if ArcEagerTransitionSystem.SHIFT in legal_transitions \ 332 | and ArcEagerTransitionSystem.zeroCostShift(state): 333 | options.append(ArcEagerTransitionSystem.SHIFT) 334 | if ArcEagerTransitionSystem.RIGHT_ARC in legal_transitions \ 335 | and ArcEagerTransitionSystem.zeroCostRight(state): 336 | options.append(ArcEagerTransitionSystem.RIGHT_ARC) 337 | if ArcEagerTransitionSystem.LEFT_ARC in legal_transitions \ 338 | and ArcEagerTransitionSystem.zeroCostLeft(state): 339 | options.append(ArcEagerTransitionSystem.LEFT_ARC) 340 | if ArcEagerTransitionSystem.REDUCE in legal_transitions \ 341 | and ArcEagerTransitionSystem.zeroCostReduce(state): 342 | options.append(ArcEagerTransitionSystem.REDUCE) 343 | return options 344 | 345 | @staticmethod 346 | def zeroCostShift(state): 347 | state_buffer = list(range(state.next_, state.numTokens())) 348 | if len(state_buffer) <= 1: 349 | return False 350 | b = state.input(0) 351 | 352 | for si in state.stack_: 353 | if state.goldHead(si) == b or state.goldHead(b) == si: 354 | return False 355 | return True 356 | 357 | @staticmethod 358 | def zeroCostRight(state): 359 | if state.stackSize() == 0 or state.endOfInput(): 360 | return False 361 | 362 | s = state.stack(0) 363 | b = state.input(0) 364 | 365 | # should be fine? k = b in gold_conf.heads and gold_conf.heads[b] or -1 366 | # but why would there be no head for B? 
makes no sense 367 | k = state.goldHead(b) 368 | if k == s: 369 | return True 370 | 371 | state_buffer = list(range(state.next_, state.numTokens())) 372 | 373 | k_b_costs = k in state.stack_ or k in state_buffer 374 | 375 | k_heads = dict((child, parent) for (parent, rel, child) in state.arcs_) 376 | 377 | b_deps = state.goldDeps(b) 378 | 379 | # (b, k) and k in S 380 | b_k_in_stack = filter(lambda dep: dep in state.stack_, b_deps) 381 | b_k_final = filter(lambda dep: dep not in k_heads, b_k_in_stack) 382 | if k not in state_buffer and k not in state.stack_ and len(list(b_k_in_stack)) is 0: 383 | return True 384 | 385 | if k_b_costs: 386 | return False 387 | 388 | return len(list(b_k_final)) == 0 389 | 390 | @staticmethod 391 | def zeroCostLeft(state): 392 | if state.stackSize() == 0 or state.endOfInput(): 393 | return False 394 | 395 | s = state.stack(0) 396 | b = state.input(0) 397 | 398 | for bi in range(b, state.numTokens()): 399 | if state.goldHead(bi) == s: 400 | return False 401 | if b != bi and state.goldHead(s) == bi: 402 | return False 403 | return True 404 | 405 | @staticmethod 406 | def zeroCostReduce(state): 407 | if state.stackSize() == 0 or state.endOfInput(): 408 | return False 409 | 410 | s = state.stack(0) 411 | b = state.input(0) 412 | 413 | for bi in range(b, state.numTokens()): 414 | if state.goldHead(bi) == s: 415 | return False 416 | 417 | return True 418 | -------------------------------------------------------------------------------- /arc_standard_transition_system.py: -------------------------------------------------------------------------------- 1 | # Translation of arc_standard_transitions.cc from SyntaxNet 2 | 3 | from conll_utils import ParsedConllSentence, ParsedConllToken 4 | from parser_state import ParserState 5 | 6 | class ArcStandardTransitionState(object): 7 | def __init__(self, state): 8 | state.push(-1) # ROOT node 9 | 10 | def toString(self, state): 11 | s = '[' 12 | i = state.stackSize() - 1 13 | while i >= 0: 14 | word = state.getToken(state.stack(i)).FORM 15 | if i != state.stackSize() - 1: 16 | s += ' ' 17 | 18 | # only for internal ROOT token at start of stack 19 | if word == None: 20 | s += 'ROOT' 21 | else: 22 | s += word 23 | 24 | i -= 1 25 | s += ']' 26 | 27 | i = state.next() 28 | while i < state.numTokens(): 29 | s += ' ' + state.getToken(i).FORM 30 | i += 1 31 | return s 32 | 33 | class ArcStandardTransitionSystem(object): 34 | SHIFT = 0 35 | LEFT_ARC = 1 36 | RIGHT_ARC = 2 37 | 38 | Transitions = [SHIFT, LEFT_ARC, RIGHT_ARC] 39 | 40 | def __init__(self): 41 | pass 42 | 43 | def shiftAction(self): 44 | return ArcStandardTransitionSystem.SHIFT 45 | 46 | def leftArcAction(self, label): 47 | return 1 + (label << 1) 48 | 49 | def rightArcAction(self, label): 50 | return 1 + ((label << 1) | 1) 51 | 52 | def label(self, action): 53 | if action < 1: 54 | return -1 55 | else: 56 | return (action - 1) >> 1 57 | 58 | def actionType(self, action): 59 | if action < 1: 60 | return action 61 | else: 62 | return 1 + (~action & 1) 63 | 64 | def numActionTypes(self): 65 | return 3 66 | 67 | def numActions(self, numLabels): 68 | return 1 + 2 * numLabels 69 | 70 | def getDefaultAction(self, state): 71 | if(not state.endOfInput()): 72 | return self.shiftAction() 73 | 74 | return self.rightArcAction(2) 75 | 76 | def getNextGoldAction(self, state): 77 | # nothing else we can do except shift to the end 78 | # (leaving us with only the remaining 'ROOT' element) 79 | if(state.stackSize() < 2): 80 | assert not state.endOfInput() 81 | return self.shiftAction() 82 
| 83 | # S|i|j 84 | # if HEAD(j) == i... (and we are done with children to the right of j) 85 | # add right arc i->j: 86 | # (S|i|j, B) => (S|i, B) 87 | if(state.goldHead(state.stack(0)) == state.stack(1) and \ 88 | self.doneChildrenRightOf(state, state.stack(0))): 89 | gold_label = state.goldLabel(state.stack(0)) 90 | return self.rightArcAction(gold_label) 91 | 92 | # S|i|j 93 | # if HEAD(i) == j... add left arc j->i: 94 | # (S|i|j, B) => (S|j, B) 95 | if(state.goldHead(state.stack(1)) == state.top()): 96 | gold_label = state.goldLabel(state.stack(1)) 97 | return self.leftArcAction(gold_label) 98 | 99 | # S|i 100 | # nothing else we can do except shift to the next required arc 101 | # (S, i|B) => (S|i, B) 102 | return self.shiftAction() 103 | 104 | ''' 105 | def getNextGoldAction(self, state): 106 | # nothing else we can do except shift to the end 107 | # (leaving us with only the remaining 'ROOT' element) 108 | 109 | o0 = state.stack(0) 110 | o1 = state.stack(1) 111 | 112 | if(state.stackSize() < 2): 113 | assert not state.endOfInput() 114 | #print('STACK TOO SMALL: RETURN DEFAULT SHIFT ACTION') 115 | return self.shiftAction() 116 | 117 | # S|i|j 118 | # if HEAD(i) == j... add left arc j->i: 119 | # (S|i|j, B) => (S|j, B) 120 | if(state.goldHead(o1) == o0 and \ 121 | self.doneChildrenRightOf(state, o1)): 122 | gold_label = state.goldLabel(o1) 123 | return self.leftArcAction(gold_label) 124 | 125 | # S|i|j 126 | # if HEAD(j) == i... (and we are done with children to the right of j) 127 | # add right arc i->j: 128 | # (S|i|j, B) => (S|i, B) 129 | if(state.goldHead(o0) == o1 and \ 130 | self.doneChildrenRightOf(state, o0)): 131 | gold_label = state.goldLabel(o0) 132 | return self.rightArcAction(gold_label) 133 | 134 | # S|i 135 | # nothing else we can do except shift to the next required arc 136 | # (S, i|B) => (S|i, B) 137 | return self.shiftAction() 138 | ''' 139 | 140 | def doneChildrenRightOf(self, state, head): 141 | index = state.next() 142 | num_tokens = state.numTokens() 143 | 144 | while(index < num_tokens): 145 | actual_head = state.goldHead(index) 146 | if(actual_head == head): 147 | return False 148 | 149 | if(actual_head > index): 150 | index = actual_head 151 | else: 152 | index += 1 153 | 154 | return True 155 | 156 | def isAllowedAction(self, action, state): 157 | if(self.actionType(action) == ArcStandardTransitionSystem.SHIFT): 158 | return self.isAllowedShift(state) 159 | elif(self.actionType(action) == ArcStandardTransitionSystem.LEFT_ARC): 160 | return self.isAllowedLeftArc(state) 161 | elif(self.actionType(action) == ArcStandardTransitionSystem.RIGHT_ARC): 162 | return self.isAllowedRightArc(state) 163 | else: 164 | assert None 165 | 166 | def isAllowedShift(self, state): 167 | return (not state.endOfInput()) 168 | 169 | def isAllowedLeftArc(self, state): 170 | # Left-arc requires two or more tokens on the stack but the first token 171 | # is the root and we do not want a left arc to the root. 172 | return (state.stackSize() > 2) 173 | 174 | def isAllowedRightArc(self, state): 175 | # Right arc requires three or more tokens on the stack. 
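# (Note: the check below admits a right arc once two items are on the stack,
# i.e. the internal ROOT marker pushed at initialization plus one token; that
# is exactly the configuration that produces the final attachment to ROOT.)
#
# Worked example of the gold oracle above (illustrative labels): for the
# two-token sentence "He eats", where the gold head of "He" is "eats" and the
# gold head of "eats" is ROOT, getNextGoldAction() produces the sequence
#   SHIFT, SHIFT, LEFT_ARC(nsubj), RIGHT_ARC(root)
# starting from the stack [ROOT]; afterwards only ROOT remains on the stack
# and the input is exhausted, so isFinalState() holds.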
176 | return (state.stackSize() > 1) 177 | 178 | def performAction(self, action, state): 179 | self.performActionWithoutHistory(action, state) 180 | 181 | def performActionWithoutHistory(self, action, state): 182 | if self.actionType(action) == ArcStandardTransitionSystem.SHIFT: 183 | self.performShift(state) 184 | elif self.actionType(action) == ArcStandardTransitionSystem.LEFT_ARC: 185 | self.performLeftArc(state, self.label(action)) 186 | elif self.actionType(action) == ArcStandardTransitionSystem.RIGHT_ARC: 187 | self.performRightArc(state, self.label(action)) 188 | else: 189 | assert(None) 190 | 191 | def performShift(self, state): 192 | assert self.isAllowedShift(state) 193 | state.push(state.next()) 194 | state.advance() 195 | 196 | # S|i|j 197 | # if HEAD(i) == j... add left arc j->i: 198 | # (S|i|j, B) => (S|j, B) 199 | def performLeftArc(self, state, label): 200 | assert self.isAllowedLeftArc(state) 201 | s_j = state.pop() 202 | s_i = state.pop() 203 | state.addArc(s_i, s_j, label) 204 | state.push(s_j) 205 | 206 | # S|i|j 207 | # if HEAD(j) == i... (and we are done with children to the right of j) 208 | # add right arc i->j: 209 | # (S|i|j, B) => (S|i, B) 210 | def performRightArc(self, state, label): 211 | assert self.isAllowedRightArc(state) 212 | s_j = state.pop() 213 | s_i = state.pop() 214 | state.addArc(s_j, s_i, label) 215 | state.push(s_i) 216 | 217 | def isDeterministicState(self, state): 218 | return state.stackSize() < 2 and (not state.endOfInput()) 219 | 220 | def isFinalState(self, state): 221 | return state.endOfInput() and (state.stackSize() < 2) 222 | 223 | def actionAsTuple(self, action): 224 | if(self.actionType(action) == ArcStandardTransitionSystem.SHIFT): 225 | return (ArcStandardTransitionSystem.SHIFT,) 226 | elif(self.actionType(action) == ArcStandardTransitionSystem.LEFT_ARC): 227 | return (ArcStandardTransitionSystem.LEFT_ARC, self.label(action)) 228 | elif(self.actionType(action) == ArcStandardTransitionSystem.RIGHT_ARC): 229 | return (ArcStandardTransitionSystem.RIGHT_ARC, self.label(action)) 230 | else: 231 | return None 232 | 233 | def actionAsString(self, action, state, feature_maps): 234 | if(self.actionType(action) == ArcStandardTransitionSystem.SHIFT): 235 | return 'SHIFT' 236 | elif(self.actionType(action) == ArcStandardTransitionSystem.LEFT_ARC): 237 | return 'LEFT_ARC(' + \ 238 | feature_maps['label'].indexToValue(self.label(action)) + ')' 239 | elif(self.actionType(action) == ArcStandardTransitionSystem.RIGHT_ARC): 240 | return 'RIGHT_ARC(' + \ 241 | feature_maps['label'].indexToValue(self.label(action)) + ')' 242 | else: 243 | return 'UNKNOWN' 244 | -------------------------------------------------------------------------------- /conll_utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | A set of classes to handle input and output of CoNLL-U files 3 | 4 | http://universaldependencies.org/docs/format.html 5 | 6 | The Parsed* classes are useful to store extra properties needed during 7 | the parsing process that are external to the Conll instances themselves 8 | ''' 9 | 10 | import logging 11 | import well_formed_filter 12 | 13 | def encodeNoneAsUnderscore(s): 14 | if s == None: 15 | return '_' 16 | else: 17 | return s 18 | 19 | def encodeNoneAsUnderscore_Int(i): 20 | if i == None: 21 | return '_' 22 | else: 23 | return str(i) 24 | 25 | ''' 26 | Represents a CoNLL token and all its properties (except index) 27 | ''' 28 | class ConllToken(object): 29 | def __init__(self): 30 | self.FORM = None 
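# Illustration (example values, not taken from the corpus): the tab-separated
# CoNLL-U line
#   1  The  the  DET  DT  Definite=Def|PronType=Art  2  det  _  _
# is read by ConllFile.read() below into FORM='The', LEMMA='the',
# UPOSTAG='DET', XPOSTAG='DT', FEATS=['Definite=Def', 'PronType=Art'],
# HEAD=1 (the file value 2 minus one, so that ROOT becomes -1) and
# DEPREL='det'.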
31 | self.LEMMA = None 32 | self.UPOSTAG = None 33 | self.XPOSTAG = None 34 | self.FEATS = [] 35 | 36 | ''' 37 | Make sure to subtract one from the HEAD value in the file 38 | Root becomes -1 39 | 40 | HEAD then becomes n, which refers to the n'th 0-based index entry 41 | in the parent ConllSentence 42 | 43 | Our parser also requires this to start at -1 44 | ''' 45 | self.HEAD = None 46 | 47 | self.DEPREL = None 48 | self.DEPS = None 49 | self.MISC = None 50 | 51 | def __str__(self): 52 | return self.toFileOutput('_') 53 | 54 | def __repr__(self): 55 | return self.__str__() 56 | 57 | def toFileOutput(self, ID): 58 | def checkTab(s): 59 | assert '\t' not in s, 'field must not contain a tab: ' + s 60 | return s 61 | 62 | def checkPipe(s): 63 | assert '|' not in s, 'field must not contain a pipe: ' + s 64 | return s 65 | 66 | assert self.FORM != None 67 | assert type(self.FEATS) is list 68 | 69 | cols = [str(ID), 70 | checkTab(self.FORM), 71 | checkTab(encodeNoneAsUnderscore(self.LEMMA)), 72 | checkTab(encodeNoneAsUnderscore(self.UPOSTAG)), 73 | checkTab(encodeNoneAsUnderscore(self.XPOSTAG)), 74 | '|'.join(checkPipe(checkTab(f)) for f in self.FEATS), 75 | encodeNoneAsUnderscore_Int(self.HEAD+1), # +1 when writing as file 76 | checkTab(encodeNoneAsUnderscore(self.DEPREL)), 77 | checkTab(encodeNoneAsUnderscore(self.DEPS)), # TODO 78 | checkTab(encodeNoneAsUnderscore(self.MISC))] 79 | 80 | return '\t'.join(cols) 81 | 82 | ''' 83 | Represents a ConllToken, as parsed 84 | ''' 85 | class ParsedConllToken(ConllToken): 86 | def __init__(self): 87 | super().__init__() 88 | self.parsedLabel = None 89 | self.parsedHead = None 90 | self.HEAD = -1 # match default value in sentence.proto 91 | 92 | def setParsedLabel(self, label): 93 | self.parsedLabel = label 94 | 95 | def setParsedHead(self, head): 96 | self.parsedHead = head 97 | 98 | def clearParsedHead(self): 99 | self.parsedHead = -1 # match ParserState: always use -1 as 100 | 101 | ''' 102 | Stores an ordered list of CoNLL tokens 103 | ''' 104 | class ConllSentence(object): 105 | def __init__(self): 106 | self.tokens = [] 107 | 108 | ''' 109 | Convert to file output representation 110 | ''' 111 | def toFileOutput(self): 112 | return '\n'.join(self.tokens[ID-1].toFileOutput(ID) \ 113 | for ID in range(1, len(self.tokens)+1)) 114 | 115 | def genSyntaxNetJson(self, token, break_level=None, start_index=0): 116 | break_contents = '' 117 | if break_level: 118 | break_contents = \ 119 | ''' 120 | break_level : %s''' % break_level 121 | 122 | return \ 123 | '''token: { 124 | word : "%s" 125 | start : %d 126 | end : %d 127 | head : %d 128 | tag : "%s" 129 | category: "%s" 130 | label : "%s"%s 131 | }''' % (token.FORM, start_index, start_index+len(token.FORM)-1, token.HEAD, token.XPOSTAG, token.UPOSTAG, token.DEPREL, break_contents) 132 | 133 | def genSyntaxNetTextHeader(self): 134 | return 'text : "%s"' % (' '.join(t.FORM for t in self.tokens)) 135 | 136 | ''' 137 | Convert to SyntaxNet JSON format 138 | ''' 139 | def toSyntaxNetJson(self): 140 | out = [] 141 | start_index = 0 142 | out.append(self.genSyntaxNetTextHeader()) 143 | for i in range(len(self.tokens)): 144 | if i == 0: 145 | out.append(self.genSyntaxNetJson(self.tokens[i], break_level='SENTENCE_BREAK', start_index=start_index)) 146 | else: 147 | out.append(self.genSyntaxNetJson(self.tokens[i], start_index=start_index)) 148 | start_index += len(self.tokens[i].FORM) + 1 # assume space 149 | return '\n'.join(out) 150 | 151 | ''' 152 | Output the token separated by spaces 153 | ''' 154 | def 
toSimpleRepresentation(self): 155 | return ' '.join(t.FORM for t in self.tokens) 156 | 157 | class ParsedConllSentence(ConllSentence): 158 | def __init__(self, docid): 159 | super().__init__() 160 | self.docid_ = docid 161 | 162 | def docid(self): 163 | return self.docid_ 164 | 165 | ## checked accessor 166 | def mutableToken(self, i): 167 | assert i >= 0 168 | assert i < len(self.tokens) 169 | return self.tokens[i] 170 | 171 | def tokenSize(self): 172 | return len(self.tokens) 173 | 174 | ''' 175 | Stores an ordered list of sentences within a CoNLL file 176 | 177 | keepMalformed: 178 | Whether to retain non-projective and invalid examples 179 | 180 | projectivize: 181 | Whether to retain non-projective examples by projectivizing them 182 | 183 | logStats: 184 | Log statistics about the corpus 185 | ''' 186 | class ConllFile(object): 187 | def __init__(self, parsed=False, keepMalformed=False, projectivize=False, 188 | logStats=False): 189 | #self.sentenceIndex = None 190 | self.sentences = [] 191 | # use parsed variant of structures 192 | self.parsed = parsed 193 | self.logger = logging.getLogger('ConllUtils') 194 | self.keepMalformed = keepMalformed 195 | self.projectivize = projectivize 196 | self.logStats = logStats 197 | 198 | 199 | ''' 200 | Read CoNLL-U from the given string 201 | 202 | excludeCols: CoNLL column indices to exclude from reading 203 | sometimes we just want to get rid of certain 204 | attributes of a token 205 | 1-based index 206 | ''' 207 | def read(self, s, excludeCols=[]): 208 | assert 1 not in excludeCols, 'cannot exclude reading of ID' 209 | assert 2 not in excludeCols, 'cannot exclude reading of FORM' 210 | 211 | well_formed_inst = well_formed_filter.WellFormedFilter() 212 | 213 | # arbitrary ID that can be used with parser 214 | if self.parsed: 215 | docid = 0 216 | 217 | ln_num = 0 218 | 219 | current_sentence = None 220 | 221 | # if we encounter an error during processing a sentence 222 | invalid_sentence = False 223 | 224 | # set up iterator 225 | # if there is no iterator, set one up 226 | # if there was an iterator, leave it at its current position 227 | #if self.sentenceIndex == None: 228 | # self.sentenceIndex = len(self.sentences) 229 | 230 | def commit(s): 231 | # if we're even getting rid of malformed sentences in the first 232 | # place... 233 | if not self.keepMalformed: 234 | if not well_formed_inst.isWellFormed(s, 235 | projectivize=self.projectivize): 236 | # if the sentence is non-projective and projectivize 237 | # is enabled, the sentence will be fixed and not discarded 238 | self.logger.debug('line %d: discarding malformed or non' \ 239 | '-projective sentence: "%s"' % \ 240 | (ln_num, s.toSimpleRepresentation())) 241 | # as long as we discard the sentence here, 242 | # discarded sentences' words, tags, and labels 243 | # won't be added to the lexicon, which is exactly the 244 | # behavior we want. 
245 | return 246 | 247 | self.sentences.append(s) 248 | 249 | def processUnderscore(s): 250 | if s == '_': 251 | return None 252 | else: 253 | return s 254 | 255 | # token index (to check that it's in order) 256 | current_ID = 0 257 | 258 | lines = s.split('\n') 259 | for ln in lines: 260 | ln_num += 1 261 | ln = ln.strip() 262 | if not ln: 263 | # a completely blank line indicates we need to commit the 264 | # current sentence 265 | if current_sentence != None: 266 | if not invalid_sentence: 267 | commit(current_sentence) 268 | 269 | current_sentence = None 270 | current_ID = 0 271 | invalid_sentence = False 272 | continue 273 | if ln[0] == '#': # ignore comments completely 274 | continue 275 | if invalid_sentence: # don't process invalid sentences 276 | continue 277 | cols = [x.strip() for x in ln.split('\t')] 278 | assert len(cols) >= 2, \ 279 | 'line %d: must have at least ID and FORM: ' % ln_num + str(cols) 280 | 281 | if '-' in cols[0] or '.' in cols[0]: 282 | self.logger.warning('line %d: not implemented: ID=%s, ' \ 283 | 'invalidating sentence' % (ln_num, cols[0])) 284 | invalid_sentence = True 285 | continue 286 | else: 287 | ID = int(cols[0]) 288 | assert ID==current_ID+1, 'line %d: token IDs must be in order' \ 289 | ' and increment by one' % ln_num 290 | 291 | current_ID = ID 292 | 293 | if current_ID == 1: 294 | if self.parsed: 295 | current_sentence = ParsedConllSentence(docid) 296 | docid += 1 297 | else: 298 | current_sentence = ConllSentence() 299 | 300 | if self.parsed: 301 | current_token = ParsedConllToken() 302 | else: 303 | current_token = ConllToken() 304 | 305 | #if self.parsed: 306 | # current_token.FORM = normalizeDigits(cols[1]) 307 | #else: 308 | # current_token.FORM = cols[1] 309 | 310 | # for SyntaxNet, 311 | # normalization ONLY happens in lexicon builder 312 | # yet numbers and up as during training 313 | # interesting... 314 | 315 | # let this be underscore if needed (don't call processUnderscore()) 316 | current_token.FORM = cols[1] 317 | 318 | if len(cols) > 2 and (3 not in excludeCols): 319 | # let this be underscore if needed 320 | # (don't call processUnderscore()) 321 | current_token.LEMMA = cols[2] 322 | if len(cols) > 3 and (4 not in excludeCols): 323 | current_token.UPOSTAG = processUnderscore(cols[3]) 324 | if len(cols) > 4 and (5 not in excludeCols): 325 | current_token.XPOSTAG = processUnderscore(cols[4]) 326 | if len(cols) > 5 and (6 not in excludeCols): 327 | if processUnderscore(cols[5]): 328 | current_token.FEATS = \ 329 | [x.strip() for x in cols[5].split('|')] 330 | else: 331 | current_token.FEATS = [] 332 | if len(cols) > 6 and (7 not in excludeCols): 333 | current_token.HEAD = processUnderscore(cols[6]) 334 | if current_token.HEAD != None: 335 | if '-' in current_token.HEAD or '.' 
in current_token.HEAD: 336 | self.logger.warning('line %d: not implemented: HEAD=%s,' 337 | ' invalidating sentence' % (ln_num, \ 338 | current_token.HEAD)) 339 | 340 | invalid_sentence = True 341 | continue 342 | else: 343 | # it's important for parsing that HEAD start at -1 344 | current_token.HEAD = int(current_token.HEAD)-1 345 | if len(cols) > 7 and (8 not in excludeCols): 346 | current_token.DEPREL = processUnderscore(cols[7]) 347 | if len(cols) > 8 and (9 not in excludeCols): 348 | # TODO 349 | current_token.DEPS = processUnderscore(cols[8]) 350 | if len(cols) > 9 and (10 not in excludeCols): 351 | current_token.MISC = processUnderscore(cols[9]) 352 | 353 | current_sentence.tokens.append(current_token) 354 | 355 | # an EOF indicates we need to commit the current sentence 356 | if current_sentence != None: 357 | if not invalid_sentence: 358 | commit(current_sentence) 359 | 360 | current_sentence = None 361 | current_ID = 0 362 | invalid_sentence = False 363 | 364 | if self.logStats: 365 | self.logger.info('Projectivized %d/%d non-projective sentences' \ 366 | ' (%.2f%% of set)' % \ 367 | (well_formed_inst.projectivizedCount, \ 368 | well_formed_inst.nonProjectiveCount, 369 | 100.0 * float(well_formed_inst.projectivizedCount) \ 370 | / float(len(self.sentences)) 371 | )) 372 | 373 | # if we're even getting rid of malformed sentences in the first place... 374 | if not self.keepMalformed: 375 | if self.projectivize: 376 | # the definition of this variable changes when projectivize is on 377 | self.logger.info('Discarded %d non-well-formed sentences' % \ 378 | (well_formed_inst.nonWellFormedCount)) 379 | else: 380 | self.logger.info('Discarded %d non-well-formed and ' \ 381 | ' non-projective sentences' % \ 382 | (well_formed_inst.nonWellFormedCount)) 383 | 384 | self.logger.info('%d valid sentences processed in total' % \ 385 | len(self.sentences)) 386 | 387 | ''' 388 | Write the current CoNLL-U data to the specified file descriptor 389 | ''' 390 | def write(self, fd): 391 | data = [s.toFileOutput() for s in self.sentences] 392 | fd.write('\n\n'.join(data)) 393 | fd.flush() 394 | 395 | def __iter__(self): 396 | index = 0 397 | while index < len(self.sentences): 398 | yield self.sentences[index] 399 | index += 1 400 | 401 | class ParsedConllFile(ConllFile): 402 | def __init__(self, keepMalformed=False, projectivize=False, logStats=False): 403 | super().__init__(parsed=True, keepMalformed=keepMalformed, 404 | projectivize=projectivize, logStats=logStats) 405 | -------------------------------------------------------------------------------- /corpus/dsindex_postagfix.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*- coding: utf8 -*- 3 | 4 | import os 5 | import sys 6 | from optparse import OptionParser 7 | import time 8 | 9 | # --verbose 10 | VERBOSE = 0 11 | 12 | if __name__ == '__main__': 13 | 14 | parser = OptionParser() 15 | parser.add_option("--verbose", action="store_const", const=1, dest="verbose", help="verbose mode") 16 | (options, args) = parser.parse_args() 17 | 18 | if options.verbose == 1 : VERBOSE = 1 19 | 20 | startTime = time.time() 21 | 22 | data = {} 23 | 24 | while 1 : 25 | try : line = sys.stdin.readline() 26 | except KeyboardInterrupt : break 27 | if not line : break 28 | 29 | line = line.strip() 30 | if not line : 31 | print '' 32 | continue 33 | if line[0] == '#' : 34 | print line 35 | continue 36 | 37 | tokens = line.split('\t') 38 | id = tokens[0] 39 | if '.' 
in id : continue # ex) 8.1 40 | if tokens[4] == '_' : 41 | tokens[4] = tokens[3] # there is no XPOS 42 | else : 43 | tokens[3] = tokens[4] # UPOS <- XPOS 44 | print '\t'.join(tokens) 45 | 46 | durationTime = time.time() - startTime 47 | sys.stderr.write("duration time = %f\n" % durationTime) 48 | -------------------------------------------------------------------------------- /decoded_parse_reader.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from feature_extractor import SparseFeatureExtractor 3 | from sentence_batch import SentenceBatch 4 | from parser_state import ParserState 5 | from arc_standard_transition_system import ArcStandardTransitionSystem, \ 6 | ArcStandardTransitionState 7 | from arc_eager_transition_system import ArcEagerTransitionSystem, \ 8 | ArcEagerTransitionState 9 | 10 | ''' 11 | Provide a batch of decoded sentences to the trainer 12 | 13 | Maintains batch_size slots of sentences, each one with its own parser state 14 | ''' 15 | class DecodedParseReader(object): 16 | def __init__(self, input_corpus, batch_size, feature_strings, feature_maps, 17 | transition_system, epoch_print = True): 18 | self.input_corpus = input_corpus 19 | self.batch_size = batch_size 20 | self.feature_strings = feature_strings 21 | self.feature_maps = feature_maps 22 | self.epoch_print = epoch_print 23 | self.feature_extractor = SparseFeatureExtractor(self.feature_strings, 24 | self.feature_maps) 25 | self.sentence_batch = SentenceBatch(input_corpus, self.batch_size) 26 | self.parser_states = [None for i in range(self.batch_size)] 27 | self.arc_states = [None for i in range(self.batch_size)] 28 | 29 | if transition_system == 'arc-standard': 30 | self.transition_system = ArcStandardTransitionSystem() 31 | self.transition_system_class = ArcStandardTransitionSystem 32 | self.transition_state_class = ArcStandardTransitionState 33 | elif transition_system == 'arc-eager': 34 | self.transition_system = ArcEagerTransitionSystem() 35 | self.transition_system_class = ArcEagerTransitionSystem 36 | self.transition_state_class = ArcEagerTransitionState 37 | else: 38 | assert None, 'transition system must be arc-standard or arc-eager' 39 | 40 | self.logger = logging.getLogger('DecodedParseReader') 41 | self.num_epochs = 0 42 | 43 | self.docids_ = [] 44 | # map docid to sentence 45 | self.sentence_map_ = dict() 46 | 47 | def state(self, i): 48 | assert i >= 0 and i < self.batch_size 49 | return self.parser_states[i] 50 | 51 | ''' 52 | Advance the sentence for slot i 53 | ''' 54 | def advanceSentence(self, i): 55 | assert i >= 0 and i < self.batch_size 56 | if(self.sentence_batch.advanceSentence(i)): 57 | self.parser_states[i] = ParserState(self.sentence_batch.sentence(i), 58 | self.feature_maps) 59 | # necessary for initializing and pushing root 60 | # keep arc_states in sync with parser_states 61 | self.arc_states[i] = \ 62 | self.transition_state_class(self.parser_states[i]) 63 | else: 64 | self.parser_states[i] = None 65 | self.arc_states[i] = None 66 | if self.state(i) != None: 67 | self.docids_.insert(0, self.state(i).sentence().docid()) 68 | 69 | ''' 70 | Perform the next best decoded action for each state 71 | 72 | scores[i][k]: probability of each action k for token state i 73 | as far as I know, raw logits 74 | 75 | filled_count: number of items of scores filled (if 0, forces SHIFT 76 | for the first time). 
otherwise, should be greater than 0 an 77 | less than or equal to batch_size 78 | ''' 79 | def performActions(self, scores, filled_count): 80 | for batch_index in range(self.batch_size): 81 | state = self.state(batch_index) 82 | if state != None: 83 | # default action if none given 84 | bestAction = self.transition_system_class.SHIFT 85 | bestScore = float('-inf') 86 | 87 | # check to make sure decisions are filled for this batch i 88 | if filled_count > batch_index: 89 | # look through top k estimated transition actions and 90 | # pick most suitable one 91 | for action in range(len(scores[batch_index])): 92 | score = scores[batch_index][action] 93 | if self.transition_system \ 94 | .isAllowedAction(action, state): 95 | self.logger.debug('Slot(%d): action candidate:' 96 | ' %s, score=%.8f - allowed' % (batch_index, \ 97 | self.transition_system.actionAsString( 98 | action, state, 99 | self.feature_maps), score)) 100 | 101 | if score > bestScore: 102 | bestAction = action 103 | bestScore = score 104 | else: 105 | self.logger.debug('Slot(%d): action candidate:' 106 | ' %s, score=%.8f - unallowed' % (batch_index, \ 107 | self.transition_system.actionAsString( 108 | action, state, 109 | self.feature_maps), score)) 110 | 111 | self.logger.debug('Slot(%d): perform action %s, score=%.8f' % 112 | (batch_index, self.transition_system.actionAsString( 113 | bestAction, state, \ 114 | self.feature_maps), bestScore)) 115 | 116 | try: 117 | self.transition_system.performAction( 118 | bestAction, state) 119 | except: 120 | self.logger.debug( 121 | 'Slot(%d): invalid action at batch slot' % batch_index) 122 | 123 | self.transition_system.performAction( 124 | action=self.transition_system.getDefaultAction( 125 | state), state=state) 126 | 127 | if self.transition_system.isFinalState(state): 128 | #self.computeTokenAccuracy(state) 129 | self.sentence_map_ \ 130 | [state.sentence().docid()] = state.sentence() 131 | 132 | self.logger.debug('Slot(%d): final state reached' \ 133 | % batch_index) 134 | 135 | self.addParseToDocument(state, True, \ 136 | self.sentence_map_[state.sentence().docid()]) 137 | 138 | ''' 139 | Concatenate and return feature bags for all sentence slots, grouped 140 | by feature major type 141 | 142 | Returns (None, None, None, ...) 
if no sentences left 143 | ''' 144 | def nextFeatureBags(self, scores, filled_count): 145 | self.performActions(scores, filled_count) 146 | 147 | for i in range(self.batch_size): 148 | if self.state(i) == None: 149 | continue 150 | 151 | while(self.transition_system.isFinalState(self.state(i))): 152 | self.logger.debug('Advancing sentence %d' % i) 153 | self.advanceSentence(i) 154 | if self.state(i) == None: 155 | break 156 | 157 | if self.sentence_batch.size() == 0: 158 | self.num_epochs += 1 159 | if self.epoch_print: 160 | self.logger.info('Starting epoch %d' % self.num_epochs) 161 | self.sentence_batch.rewind() 162 | for i in range(self.batch_size): 163 | self.advanceSentence(i) 164 | 165 | # a little bit different from SyntaxNet: 166 | # we don't support feature groups 167 | # we automatically group together the similar types 168 | # features_output = [[] for i in range(self.feature_strings)] 169 | features_major_types = None 170 | features_output = None 171 | 172 | filled_count = 0 173 | 174 | # Populate feature outputs 175 | for i in range(self.batch_size): 176 | if self.state(i) == None: 177 | continue 178 | 179 | fvec = self.feature_extractor.extract(self.state(i)) 180 | assert len(fvec.types) == len(self.feature_strings) 181 | major_types, ids = fvec.concatenateSimilarTypes() 182 | 183 | if features_output == None: 184 | features_major_types = [t for t in major_types] 185 | features_output = [[] for t in major_types] 186 | else: 187 | assert len(features_major_types) == len(major_types) 188 | assert len(features_output) == len(major_types) 189 | 190 | for k in range(len(features_major_types)): 191 | features_output[k] += ids[k] 192 | 193 | filled_count += 1 194 | 195 | return features_major_types, features_output, self.num_epochs, \ 196 | filled_count 197 | 198 | ''' 199 | Adds transition state specific annotations to the document 200 | ''' 201 | def addParseToDocument(self, state, rewrite_root_labels, sentence): 202 | for i in range(state.numTokens()): 203 | token = sentence.mutableToken(i) 204 | 205 | try: 206 | token.setParsedLabel( 207 | self.feature_maps['label'].indexToValue(state.label(i))) 208 | except: 209 | # label failure (happens often in ARC-EAGER due to SHIFT/ 210 | # REDUCE sequences) 211 | # TODO: dis-allow REDUCE if no labels assigned? 212 | pass 213 | 214 | if (state.head(i) != -1): 215 | token.setParsedHead(state.head(i)) 216 | else: 217 | token.clearParsedHead() 218 | if rewrite_root_labels: 219 | token.setParsedLabel('ROOT') 220 | 221 | ''' 222 | Concatenate and return sentence annotations for all sentence slots 223 | 224 | Returns (None, None, None, ...) 
if no sentences left 225 | ''' 226 | def getNextAnnotations(self): 227 | sentences = [] 228 | while (len(self.docids_) > 0) and \ 229 | (self.docids_[-1] in self.sentence_map_): 230 | 231 | self.logger.debug('Sentence(%d): %s' % (self.docids_[-1], \ 232 | str(self.sentence_map_[self.docids_[-1]].tokens))) 233 | 234 | sentences.append(self.sentence_map_[self.docids_[-1]]) 235 | del self.sentence_map_[self.docids_[-1]] 236 | self.docids_.pop() 237 | 238 | return sentences 239 | -------------------------------------------------------------------------------- /dep_parser.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import hashlib 6 | import sys 7 | import pickle 8 | import copy 9 | import logging 10 | import math 11 | import tensorflow as tf 12 | import numpy as np 13 | import sys 14 | import os 15 | import argparse 16 | import json 17 | import random 18 | 19 | from model_parameters import * 20 | from lexicon import * 21 | from utils import * 22 | from conll_utils import * 23 | from feature_extractor import SparseFeatureExtractor 24 | from parser_state import ParserState 25 | from arc_standard_transition_system import ArcStandardTransitionState, \ 26 | ArcStandardTransitionSystem 27 | from arc_eager_transition_system import ArcEagerTransitionState, \ 28 | ArcEagerTransitionSystem 29 | from gold_parse_reader import GoldParseReader 30 | from decoded_parse_reader import DecodedParseReader 31 | from tensorflow.python.ops import state_ops 32 | 33 | logger = logging.getLogger('DepParser') 34 | 35 | parser = argparse.ArgumentParser( 36 | description='Train a Chen and Manning-style neural network dependency' \ 37 | ' parser') 38 | 39 | # Required positional argument 40 | parser.add_argument('model_folder', type=str, 41 | help='Folder in which to load or save model') 42 | parser.add_argument('training_file', type=str, 43 | help='CoNLL-U format tagged training corpus (UTF-8)') 44 | parser.add_argument('testing_file', type=str, 45 | help='CoNLL-U format tagged evaluation corpus (UTF-8)') 46 | parser.add_argument('--train', action='store_true', default=False, 47 | help='Training a new model or continue training of an ' 48 | 'old model') 49 | parser.add_argument('--evaluate', action='store_true', default=False, 50 | help='Evaluate an existing model') 51 | parser.add_argument('--debug', action='store_true', default=False, 52 | help='Enable verbose debug lines') 53 | parser.add_argument('--restart', action='store_true', default=False, 54 | help='Re-train model from scratch instead of restoring ' 55 | 'a previously saved model') 56 | parser.add_argument('--epochs', type=int, default=10, 57 | help='Number of epochs to run (run-throughs over all ' 58 | 'training corpus feature bags). Default 10') 59 | parser.add_argument('--scoring-strategy', type=str, default='default', 60 | help='Choices: "default", "conllx", "ignore_parens"') 61 | #parser.add_argument('--feature-bag', type=str, 62 | # help='Specify pre-created feature bag file to save' \ 63 | # ' computation time (saved in model dir by default') 64 | #parser.add_argument('--epochs', type=int, default=10, 65 | # help='Training epochs (default 10). 
Shuffle sentences ' 66 | # ' and re-train during each training epoch.') 67 | 68 | ## TODO: 69 | # add param: use pretrained word/sense embeddings gensim/Mikolov 70 | 71 | ## FIXME: 72 | # trying to continue training with different corpus should throw better error 73 | 74 | args = parser.parse_args() 75 | 76 | if args.debug: 77 | logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', 78 | level=logging.DEBUG) 79 | else: 80 | logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', 81 | level=logging.INFO) 82 | 83 | try: 84 | os.makedirs(args.model_folder) 85 | except: 86 | pass 87 | 88 | if (args.train and args.evaluate) or ((not args.train) and (not args.evaluate)): 89 | print('Please specify either training or evaluation mode ' 90 | '(--train/--evaluate)') 91 | sys.exit(1) 92 | 93 | if not (args.scoring_strategy == 'default' or \ 94 | args.scoring_strategy == 'conllx' or \ 95 | args.scoring_strategy == 'ignore_parens'): 96 | print('Unknown scoring strategy "%s"' % args.scoring_strategy) 97 | sys.exit(1) 98 | 99 | def batchedSparseToDense(sparse_indices, output_size): 100 | """Batch compatible sparse to dense conversion. 101 | 102 | This is useful for one-hot coded target labels. 103 | 104 | Args: 105 | sparse_indices: [batch_size] tensor containing one index per batch 106 | output_size: needed in order to generate the correct dense output 107 | 108 | Returns: 109 | A [batch_size, output_size] dense tensor. 110 | """ 111 | eye = tf.diag(tf.fill([output_size], tf.constant(1, tf.float32))) 112 | return tf.nn.embedding_lookup(eye, sparse_indices) 113 | 114 | def embeddingLookupFeatures(params, ids): 115 | """Computes embeddings for each entry of sparse features sparse_features. 116 | 117 | Args: 118 | params: list of 2D tensors containing vector embeddings 119 | sparse_features: 1D tensor of strings. Each entry is a string encoding of 120 | dist_belief.SparseFeatures, and represents a variable length list of 121 | feature ids, and optionally, corresponding weights values. 122 | allow_weights: boolean to control whether the weights returned from the 123 | SparseFeatures are used to multiply the embeddings. 124 | 125 | Returns: 126 | A tensor representing the combined embeddings for the sparse features. 127 | For each entry s in sparse_features, the function looks up the embeddings 128 | for each id and sums them into a single tensor weighing them by the 129 | weight of each id. It returns a tensor with each entry of sparse_features 130 | replaced by this combined embedding. 131 | """ 132 | if not isinstance(params, list): 133 | params = [params] 134 | 135 | # Lookup embeddings. 
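# Shape sketch (illustrative numbers): for a feature group with 20 feature
# columns, an embedding size of 64 and a batch of 32 parser states,
# addEmbedding() flattens the int32 feature matrix [32, 20] into ids == [640];
# the lookup on the next line then returns [640, 64], which addEmbedding()
# reshapes to a [32, 20*64] == [32, 1280] slice of the network input.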
136 | embeddings = tf.nn.embedding_lookup(params, ids) 137 | 138 | return embeddings 139 | 140 | ''' 141 | Takes an SHA-1 hash of a file 142 | (Useful for hashing training corpus) 143 | ''' 144 | def fileHash(fname): 145 | fd = open(fname, 'rb') 146 | retval = hashlib.sha1(fd.read()).hexdigest() 147 | fd.close() 148 | return retval 149 | 150 | ''' 151 | Entry point for dependency parser 152 | ''' 153 | class Parser(object): 154 | def __init__(self, modelParams): 155 | self.logger = logging.getLogger('Parser') 156 | self.modelParams = modelParams 157 | 158 | self.variables = {} 159 | self.params = {} 160 | self.trainableParams = [] 161 | self.inits = {} 162 | self.averaging = {} 163 | self.averaging_decay = self.modelParams.cfg['averagingDecay'] 164 | self.use_averaging = True 165 | self.check_parameters = True 166 | self.training = {} 167 | self.evaluation = {} 168 | 169 | with tf.name_scope('params') as self._param_scope: 170 | pass 171 | 172 | #self.trainingCorpus = None 173 | #self.testingCorpus = None 174 | 175 | def getStep(self): 176 | def onesInitializer(shape, dtype=tf.float32, partition_info=None): 177 | return tf.ones(shape, dtype) 178 | return self.addVariable([], tf.int32, 'step', onesInitializer) 179 | 180 | def incrementCounter(self, counter): 181 | return state_ops.assign_add(counter, 1, use_locking=True) 182 | 183 | def addLearningRate(self, initial_learning_rate, decay_steps): 184 | """Returns a learning rate that decays by 0.96 every decay_steps. 185 | 186 | Args: 187 | initial_learning_rate: initial value of the learning rate 188 | decay_steps: decay by 0.96 every this many steps 189 | 190 | Returns: 191 | learning rate variable. 192 | """ 193 | step = self.getStep() 194 | return cf.with_dependencies( 195 | [self.incrementCounter(step)], 196 | tf.train.exponential_decay(initial_learning_rate, 197 | step, 198 | decay_steps, 199 | 0.96, 200 | staircase=True)) 201 | 202 | def addVariable(self, shape, dtype, name, initializer=None): 203 | if name in self.variables: 204 | return self.variables[name] 205 | self.variables[name] = tf.get_variable(name, shape, dtype, initializer) 206 | if initializer is not None: 207 | self.inits[name] = state_ops.init_variable(self.variables[name], 208 | initializer) 209 | return self.variables[name] 210 | 211 | ''' 212 | Don't use variable_scope, as param names will overwrite each other 213 | ''' 214 | def addParam(self, shape, dtype, name, initializer=None, 215 | return_average=False): 216 | # this isn't a problem. we reload variables if they already exist. 217 | #if name in self.params: 218 | # self.logger.warning(name + ' already exists!') 219 | 220 | if name not in self.params: 221 | step = tf.cast(self.getStep(), tf.float32) 222 | with tf.name_scope(self._param_scope): 223 | # Put all parameters and their initializing ops in their own 224 | # scope irrespective of the current scope (training or eval). 
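# Note on the parameter averaging applied a few lines below (summary of
# tf.train.ExponentialMovingAverage semantics): each apply() call updates
#   avg := decay * avg + (1 - decay) * param,
# where decay is min(averaging_decay, (1 + step) / (10 + step)) because
# num_updates=step is passed. When averagingDecay == 1, the code instead uses
# decay = step / (step + 1), i.e. a plain running mean ("vanilla averaging").
# Evaluation graphs request return_average=True and therefore read these
# averaged copies rather than the raw trained parameters.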
225 | self.params[name] = tf.get_variable(name, shape, dtype, 226 | initializer) 227 | param = self.params[name] 228 | 229 | if initializer is not None: 230 | self.inits[name] = state_ops.init_variable(param, 231 | initializer) 232 | if self.averaging_decay == 1: 233 | self.logging.info('Using vanilla averaging of parameters.') 234 | ema = tf.train.ExponentialMovingAverage( 235 | decay=(step / (step + 1.0)), num_updates=None) 236 | else: 237 | ema = tf.train.ExponentialMovingAverage( 238 | decay=self.averaging_decay, num_updates=step) 239 | 240 | self.averaging[name + '_avg_update'] = ema.apply([param]) 241 | self.variables[name + '_avg_var'] = ema.average(param) 242 | self.inits[name + '_avg_init'] = state_ops.init_variable( 243 | ema.average(param), tf.zeros_initializer()) 244 | return (self.variables[name + '_avg_var'] if return_average else 245 | self.params[name]) 246 | 247 | def addEmbedding(self, features, num_features, num_ids, embedding_size, 248 | major_type, return_average=False): 249 | initializer = tf.random_normal_initializer( 250 | stddev=1.0 / embedding_size**.5, \ 251 | seed=0) 252 | 253 | embedding_matrix = self.addParam( 254 | [num_ids, embedding_size], 255 | tf.float32, 256 | 'embedding_matrix_%s' % major_type, 257 | initializer, 258 | return_average=return_average) 259 | 260 | embedding = embeddingLookupFeatures(embedding_matrix, 261 | tf.reshape(features, 262 | [-1], 263 | name='feature_%s' % major_type)) 264 | 265 | return tf.reshape(embedding, [-1, num_features * embedding_size]) 266 | 267 | ''' 268 | Setup transition and action system and feature maps 269 | (necessary whether training or evaluating) 270 | ''' 271 | def setupParser(self, mode): 272 | hiddenLayerSizes = self.modelParams.cfg['hiddenLayerSizes'] 273 | featureStrings = self.modelParams.cfg['featureStrings'] 274 | embeddingSizes = self.modelParams.cfg['embeddingSizes'] 275 | batchSize = self.modelParams.cfg['batchSize'] 276 | transitionSystem = self.modelParams.cfg['transitionSystem'] 277 | 278 | if transitionSystem == 'arc-standard': 279 | self.transitionSystem = ArcStandardTransitionSystem() 280 | elif transitionSystem == 'arc-eager': 281 | self.transitionSystem = ArcEagerTransitionSystem() 282 | else: 283 | assert None, 'transition system must be arc-standard or arc-eager' 284 | 285 | assert len(hiddenLayerSizes) > 0, 'must have at least one hidden layer' 286 | assert len(featureStrings) == len(set(featureStrings)), \ 287 | 'duplicate feature string detected' 288 | 289 | if mode == 'train': 290 | # determine if we have to compute or read the lexicon 291 | self.logger.info('Computing lexicon from training corpus...') 292 | self.modelParams.lexicon.compute() 293 | self.logger.info('Done building lexicon') 294 | self.modelParams.lexicon.write() 295 | elif mode == 'evaluate': 296 | self.logger.info('Reading lexicon from trained model...') 297 | self.modelParams.lexicon.read() 298 | else: 299 | assert None, 'invalid mode: ' + mode 300 | 301 | self.featureMaps = self.modelParams.lexicon.getFeatureMaps() 302 | 303 | self.logger.info('Feature strings: ' + str(featureStrings)) 304 | 305 | # Get major type groups in sorted order by contructing null parser 306 | # state and extracting features, and then concatenating the similar 307 | # types 308 | fvec = SparseFeatureExtractor(featureStrings, self.featureMaps) \ 309 | .extract(ParserState(ParsedConllSentence(docid=None), 310 | self.featureMaps), doLogging=False) 311 | 312 | featureTypeInstances = fvec.types 313 | self.featureMajorTypeGroups, _ = 
fvec.concatenateSimilarTypes() 314 | 315 | # index: major feature type index 316 | # values: feature names under that type 317 | self.featureNames = [[] for t in self.featureMajorTypeGroups] 318 | 319 | self.logger.info('Detected major feature groups (in alphabetical ' 320 | 'order): ' + str(self.featureMajorTypeGroups)) 321 | 322 | self.featureDomainSizes = [] 323 | #self.featureEmbeddings = [] 324 | 325 | # For now, use all same embedding sizes 326 | self.featureEmbeddingSizes = \ 327 | [embeddingSizes[t] for t in self.featureMajorTypeGroups] 328 | 329 | self.BAG_OF_FEATURES_LEN = 0 330 | 331 | for i in range(len(featureTypeInstances)): 332 | major_type = featureTypeInstances[i].major_type 333 | major_type_index = self.featureMajorTypeGroups.index(major_type) 334 | 335 | self.featureNames[major_type_index].append( 336 | featureTypeInstances[i].name) 337 | 338 | self.BAG_OF_FEATURES_LEN += \ 339 | (self.featureEmbeddingSizes[major_type_index]) 340 | 341 | for i in range(len(self.featureMajorTypeGroups)): 342 | major_type = self.featureMajorTypeGroups[i] 343 | 344 | self.logger.info('') 345 | self.logger.info('Feature group \'%s\'' % major_type) 346 | self.logger.info('... domain size: %d' % \ 347 | (self.featureMaps[major_type].getDomainSize( \ 348 | includeSpecial=True))) 349 | self.logger.info('... embedding size: %d' % \ 350 | (self.featureEmbeddingSizes[i])) 351 | #self.logger.info('... feature count: %d' % \ 352 | # (len(self.featureNames[i]))) 353 | self.logger.info('... features') 354 | 355 | for fname in self.featureNames[i]: 356 | self.logger.info('....... %s' % (fname)) 357 | 358 | self.logger.info('... total group embedding size: %d' % \ 359 | (len(self.featureNames[i]) * self.featureEmbeddingSizes[i])) 360 | 361 | self.logger.info('... initializing random normal embeddings...') 362 | self.featureDomainSizes.append( 363 | self.featureMaps[major_type].getDomainSize( \ 364 | includeSpecial=True)) 365 | 366 | assert len(self.featureDomainSizes) == len(self.featureEmbeddingSizes) 367 | #assert len(self.featureDomainSizes) == len(self.featureEmbeddings) 368 | assert len(self.featureDomainSizes) == len(self.featureNames) 369 | 370 | self.logger.info('') 371 | self.logger.info('Batch size (number of parser states): %d' % batchSize) 372 | self.logger.info('Total feature count: %d' % \ 373 | (len(featureTypeInstances))) 374 | self.logger.info('Total bag of features length per state: %d' % \ 375 | (self.BAG_OF_FEATURES_LEN)) 376 | self.logger.info('Total features input size: %d' % \ 377 | (batchSize*self.BAG_OF_FEATURES_LEN)) 378 | 379 | # for actions, we don't encode UNKNOWN, ROOT, or OUTSIDE 380 | # we only encode the number of base values 381 | self.ACTION_COUNT = self.transitionSystem.numActions( 382 | self.featureMaps['label'].getDomainSize(includeSpecial=False)) 383 | 384 | self.logger.info('Total action count: %d' % self.ACTION_COUNT) 385 | 386 | ''' 387 | Setup TensorFlow Variables in model 388 | ''' 389 | def buildNetwork(self, mode='train'): 390 | assert mode == 'train' or mode == 'eval' 391 | 392 | if mode == 'train': 393 | return_average = False 394 | nodes = self.training 395 | else: 396 | return_average = self.use_averaging 397 | nodes = self.evaluation 398 | 399 | learningRate = self.modelParams.cfg['learningRate'] 400 | decaySteps = self.modelParams.cfg['decaySteps'] 401 | # FIXME: does momentum/learning rate reload properly when retraining? 
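        # Rough shape of the graph this method builds (a summary of the code
        # below, no extra logic):
        #   per-group int32 placeholders -> embedding lookups -> concat
        #   -> one relu layer per entry in hiddenLayerSizes
        #   -> linear output layer of size ACTION_COUNT (the logits)
        # For training, cost = softmax cross entropy averaged over batchSize
        # plus 1e-4 * L2 on every layer except the output, minimized with
        # MomentumOptimizer under the decaying rate from addLearningRate().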
402 | momentum = self.modelParams.cfg['momentum'] 403 | topK = self.modelParams.cfg['topK'] 404 | hiddenLayerSizes = self.modelParams.cfg['hiddenLayerSizes'] 405 | batchSize = self.modelParams.cfg['batchSize'] 406 | 407 | with tf.name_scope(mode): 408 | weights = [] 409 | biases = [] 410 | embeddings = [] 411 | nodes['feature_endpoints'] = [] 412 | 413 | for i in range(len(self.featureMajorTypeGroups)): 414 | major_type = self.featureMajorTypeGroups[i] 415 | # shape will be [-1, number of sparse integer features in group] 416 | nodes['feature_endpoints'].append(tf.placeholder(tf.int32, \ 417 | [None, len(self.featureNames[i])], 418 | name="ph_feature_endpoints_%s" % major_type)) 419 | embeddings.append(self.addEmbedding( \ 420 | nodes['feature_endpoints'][i], 421 | len(self.featureNames[i]), 422 | self.featureDomainSizes[i], 423 | self.featureEmbeddingSizes[i], 424 | major_type, 425 | return_average=return_average)) 426 | 427 | # Input layer 428 | last_layer = tf.concat(embeddings, 1) 429 | last_layer_size = self.BAG_OF_FEATURES_LEN 430 | 431 | # Hidden layers 432 | for i in range(len(hiddenLayerSizes)): 433 | h = hiddenLayerSizes[i] 434 | 435 | weights.append(self.addParam( 436 | [last_layer_size, h], 437 | tf.float32, 438 | 'layer_%d_weights' % i, 439 | tf.random_normal_initializer(stddev=1e-4, seed=0), 440 | return_average=return_average)) 441 | 442 | biases.append(self.addParam( 443 | [h], 444 | tf.float32, 445 | 'layer_%d_biases' % i, 446 | tf.constant_initializer(0.2), 447 | return_average=return_average)) 448 | 449 | last_layer = tf.nn.relu_layer(last_layer, 450 | weights[-1], 451 | biases[-1], 452 | name='layer_%d' % i) 453 | last_layer_size = h 454 | 455 | # Output layer 456 | weights.append(self.addParam( 457 | [last_layer_size, self.ACTION_COUNT], 458 | tf.float32, 459 | 'softmax_weights', 460 | tf.random_normal_initializer(stddev=1e-4, seed=0), 461 | return_average=return_average)) 462 | 463 | biases.append(self.addParam( 464 | [self.ACTION_COUNT], 465 | tf.float32, 466 | 'softmax_biases', 467 | tf.zeros_initializer(), 468 | return_average=return_average)) 469 | 470 | logits = tf.nn.xw_plus_b(last_layer, 471 | weights[-1], 472 | biases[-1], 473 | name='logits') 474 | 475 | if mode == 'train': 476 | nodes['gold_actions'] = tf.placeholder(tf.int32, [None], \ 477 | name='ph_gold_actions') 478 | nodes['filled_slots'] = tf.placeholder(tf.int32, \ 479 | name='ph_filled_slots') 480 | 481 | # one-hot encoding for each batch 482 | dense_golden = batchedSparseToDense(nodes['gold_actions'], \ 483 | self.ACTION_COUNT) 484 | 485 | #cross_entropy = tf.div( 486 | # tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits( 487 | # logits=logits, labels=dense_golden)), 488 | # tf.cast(nodes['filled_slots'], tf.float32)) 489 | 490 | # we should divide by batch size here, not filled slots 491 | # seems to fix the accuracy issue for whatever reason, 492 | # even though cost seems to go crazy momentarily 493 | # (plummets because only a few slots are filled) 494 | cross_entropy = tf.div( 495 | tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits( 496 | logits=logits, labels=dense_golden)), 497 | batchSize) 498 | 499 | # regularize all parameters except output layer 500 | regularized_params = [tf.nn.l2_loss(p) for p in weights[:-1]] 501 | regularized_params += [tf.nn.l2_loss(p) for p in biases[:-1]] 502 | 503 | l2_loss = 1e-4 * tf.add_n(regularized_params) \ 504 | if regularized_params else 0 505 | 506 | cost = tf.add(cross_entropy, l2_loss, name='cost') 507 | 508 | lr = 
self.addLearningRate(learningRate, decaySteps) 509 | 510 | optimizer = tf.train.MomentumOptimizer(lr, 511 | momentum, 512 | use_locking=False) 513 | 514 | trainableParams = self.params.values() 515 | 516 | train_op = optimizer.minimize(cost, var_list=trainableParams) 517 | 518 | for param in trainableParams: 519 | slot = optimizer.get_slot(param, 'momentum') 520 | self.inits[slot.name] = state_ops.init_variable(slot, 521 | tf.zeros_initializer()) 522 | self.variables[slot.name] = slot 523 | 524 | numerical_checks = [ 525 | tf.check_numerics(param, 526 | message='Parameter is not finite.') 527 | for param in trainableParams 528 | if param.dtype.base_dtype in [tf.float32, tf.float64] 529 | ] 530 | check_op = tf.group(*numerical_checks) 531 | avg_update_op = tf.group(*self.averaging.values()) 532 | train_ops = [train_op] 533 | if self.check_parameters: 534 | train_ops.append(check_op) 535 | if self.use_averaging: 536 | train_ops.append(avg_update_op) 537 | 538 | nodes['train_op'] = tf.group(*train_ops, name='train_op') 539 | nodes['cost'] = cost 540 | nodes['logits'] = logits 541 | #nodes['softmax'] = tf.nn.softmax(logits) 542 | else: 543 | nodes['logits'] = logits 544 | #nodes['softmax'] = tf.nn.softmax(logits) 545 | 546 | ''' 547 | Serialize the feature definitions 548 | (so that we can determine when they change) 549 | ''' 550 | def serializeFeatureDef(self): 551 | d = [] 552 | 553 | bs = self.modelParams.cfg['batchSize'] 554 | d.append(bs) 555 | 556 | # when transition system changes, so do the gold actions 557 | ts = self.modelParams.cfg['transitionSystem'] 558 | d.append(ts) 559 | 560 | # if projectivize parameter is changed, we may have to recalculate 561 | # features as well (in case there are non-projective sentences) 562 | p = self.modelParams.cfg['projectivizeTrainingSet'] 563 | d.append(p) 564 | 565 | fs = self.modelParams.cfg['featureStrings'] 566 | # order doesn't matter 567 | fs.sort() 568 | d.append(fs) 569 | 570 | e = [] 571 | # because dictionaries aren't ordered... 
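        # Illustrative output (hedged; the exact feature list depends on the
        # active config): with the values in parser-config.sample this
        # serializes to something like
        #   [64, "arc-standard", true, ["input.tag", ...],
        #    [["label", 32], ["tag", 32], ["word", 64]]]
        # Any change to these fields invalidates the cached feature bags
        # checked in obtainFeatureBags().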
572 | for (k, v) in self.modelParams.cfg['embeddingSizes'].items(): 573 | e.append((k,v)) 574 | # sort by key 575 | e.sort() 576 | 577 | d.append(e) 578 | return json.dumps(d) 579 | 580 | ''' 581 | Generate or load pre-computed feature bags 582 | ''' 583 | def obtainFeatureBags(self, trainingFileName): 584 | batchSize = self.modelParams.cfg['batchSize'] 585 | projectivizeTrainingSet = self.modelParams.cfg \ 586 | ['projectivizeTrainingSet'] 587 | transitionSystem = self.modelParams.cfg['transitionSystem'] 588 | 589 | activeFeatureDef = self.serializeFeatureDef().strip() 590 | activeCorpusHash = fileHash(trainingFileName) 591 | 592 | cachedFeatureDef = None 593 | try: 594 | fd = open(self.modelParams.getFilePath('feature-def'), 'r', 595 | encoding='utf-8') 596 | cachedFeatureDef = fd.read().strip() 597 | fd.close() 598 | except: 599 | cachedFeatureDef = None 600 | 601 | cachedCorpusHash = None 602 | try: 603 | fd = open(self.modelParams.getFilePath('training-corpus-hash'), 'r', 604 | encoding='utf-8') 605 | cachedCorpusHash = fd.read().strip() 606 | self.logger.debug('Cached corpus hash: %s' % cachedCorpusHash) 607 | fd.close() 608 | except: 609 | cachedCorpusHash = None 610 | 611 | self.logger.debug('Training corpus hash: %s' % activeCorpusHash) 612 | self.logger.debug('Cached corpus hash: %s' % cachedCorpusHash) 613 | 614 | self.logger.debug('Active feature definition: %s' % activeFeatureDef) 615 | self.logger.debug('Cached feature definition: %s' % cachedFeatureDef) 616 | 617 | if activeFeatureDef == cachedFeatureDef and \ 618 | activeCorpusHash == cachedCorpusHash: 619 | self.logger.info('Loading pre-existing feature bags...') 620 | fd = open(self.modelParams.getFilePath('feature-bag-bin'), 'rb') 621 | batches = pickle.load(fd) 622 | fd.close() 623 | else: 624 | featureStrings = self.modelParams.cfg['featureStrings'] 625 | self.logger.info('Feature bag needs recalculation (first training' \ 626 | ' or features changed)') 627 | 628 | # parameters here must match parameters during lexicon generation 629 | trainingCorpus = ParsedConllFile(keepMalformed=False, 630 | projectivize=projectivizeTrainingSet) 631 | 632 | trainingCorpus.read(open(self.modelParams.trainingFile, 'r', 633 | encoding='utf-8').read()) 634 | 635 | # Start getting sentence batches... 636 | reader = GoldParseReader(trainingCorpus, batchSize, \ 637 | featureStrings, self.featureMaps, transitionSystem, 638 | epoch_print=False) 639 | 640 | batches = [] 641 | 642 | i = 0 643 | while(True): 644 | self.logger.info('Generating feature bag #%d...' 
% (i+1)) 645 | reader_output = reader.nextFeatureBags() 646 | if reader_output[0] == None: 647 | self.logger.debug('Iter(%d): reader output is None' % i) 648 | break 649 | 650 | features_major_types, features_output, gold_actions, \ 651 | epoch_num = reader_output 652 | 653 | if epoch_num > 1: 654 | # don't make more than one epoch 655 | break 656 | 657 | batches.append(reader_output) 658 | i += 1 659 | 660 | self.logger.info('Saving feature bags...') 661 | 662 | fd = open(self.modelParams.getFilePath('feature-bag-bin'), 'wb') 663 | pickle.dump(batches, fd) 664 | fd.close() 665 | 666 | fd = open(self.modelParams.getFilePath('feature-def'), 'w', 667 | encoding='utf-8') 668 | fd.write(activeFeatureDef) 669 | fd.close() 670 | 671 | fd = open(self.modelParams.getFilePath('training-corpus-hash'), 'w', 672 | encoding='utf-8') 673 | fd.write(activeCorpusHash) 674 | fd.close() 675 | 676 | return batches 677 | 678 | ''' 679 | Start training from scratch, or from where we left off 680 | ''' 681 | def startTraining(self, sess, epochs_to_run=10, restart=False): 682 | batchSize = self.modelParams.cfg['batchSize'] 683 | featureStrings = self.modelParams.cfg['featureStrings'] 684 | 685 | ckpt_dir = fixPath(self.modelParams.modelFolder) + '/' 686 | saver = tf.train.Saver() 687 | 688 | if restart: 689 | self.logger.info('Start fitting') 690 | else: 691 | ckpt = tf.train.get_checkpoint_state(ckpt_dir) 692 | if ckpt and ckpt.model_checkpoint_path: 693 | # Restore variables from disk. 694 | saver.restore(sess, ckpt.model_checkpoint_path) 695 | self.logger.info('Model restored') 696 | self.logger.info('Continue fitting') 697 | else: 698 | self.logger.info('Start fitting') 699 | 700 | print_freq = 10 701 | 702 | save_freq = 500 703 | #eval_freq = 200 704 | 705 | batches = self.obtainFeatureBags(self.modelParams.trainingFile) 706 | 707 | if epochs_to_run <= 0: 708 | # just do attachment metric if epochs is 0 709 | self.attachmentMetric(sess, runs=200, mode='testing') 710 | return 711 | 712 | epoch_num = 0 713 | 714 | while epoch_num < epochs_to_run: 715 | i = 0 716 | while i < len(batches): 717 | reader_output = batches[i] 718 | if reader_output[0] == None: 719 | self.logger.debug('Iter(%d): reader output is None' % i) 720 | break 721 | 722 | ''' 723 | epoch_num refers to the number of run-throughs through the 724 | whole training corpus, whereas `i` is just the batch 725 | iteration number 726 | ''' 727 | 728 | features_major_types, features_output, gold_actions, \ 729 | _ = reader_output 730 | 731 | filled_count = len(gold_actions) 732 | if filled_count < batchSize: 733 | # break out (partial batches seem to completely ruin the 734 | # model for whatever reason) 735 | # use continue because in case we shuffle the outer 736 | # dimension, we might get the partial batches in the 737 | # middle 738 | # FIXME: investigate what SyntaxNet does in this case 739 | # have a feeling this might be negatively affecting 740 | # attachmentMetric() function as well, which does process 741 | # partial batches 742 | i += 1 743 | continue 744 | pass 745 | 746 | #print('feature(0) len: %d' % len(features_output[0])) 747 | #print('feature(1) len: %d' % len(features_output[1])) 748 | #print('feature(2) len: %d' % len(features_output[2])) 749 | 750 | # debug: print out first 40 actions (useful to compare with 751 | # SyntaxNet) 752 | self.logger.debug('gold_actions: %s' % \ 753 | str(gold_actions[:40])) 754 | 755 | assert len(self.training['feature_endpoints']) == \ 756 | len(features_output) 757 | 758 | feed_dict = {} 759 | for k 
in range(len(self.training['feature_endpoints'])): 760 | features_output[k] = np.asarray(features_output[k]) 761 | feed_dict[self.training['feature_endpoints'][k]] = \ 762 | features_output[k].reshape( \ 763 | [-1, len(self.featureNames[k])]) 764 | 765 | feed_dict[self.training['filled_slots']] = filled_count 766 | feed_dict[self.training['gold_actions']] = gold_actions 767 | 768 | c, _ = sess.run([self.training['cost'], 769 | self.training['train_op']], 770 | feed_dict=feed_dict) 771 | 772 | if i > 0 and i % print_freq == 0: 773 | self.logger.info('Epoch: %04d Iter: %06d cost=%s' % \ 774 | (epoch_num+1, i+1, "{:.2f}".format(c))) 775 | #self.quickEvaluationMetric(sess, mode='training') 776 | # reset avg 777 | #avg_cost = 0.0 778 | 779 | if i > 0 and i % save_freq == 0: 780 | save_path = saver.save(sess, ckpt_dir + 'model.ckpt') 781 | self.logger.info('Model saved to file: %s' % save_path) 782 | #self.attachmentMetric(sess, runs=100, mode='training') 783 | #self.attachmentMetric(sess, runs=100, mode='testing') 784 | 785 | #if i > 0 and i % eval_freq == 0: 786 | # self.attachmentMetric(sess, runs=200) 787 | 788 | i += 1 789 | 790 | epoch_num += 1 791 | 792 | if epoch_num < epochs_to_run: 793 | # evaluate now. otherwise evaluate after training 794 | # complete message is shown 795 | #self.attachmentMetric(sess, runs=100, mode='training') 796 | pass 797 | else: 798 | self.logger.info('Training is complete (%d epochs)' % \ 799 | epochs_to_run) 800 | save_path = saver.save(sess, ckpt_dir + 'model.ckpt') 801 | self.logger.info('Model saved to file: %s' % save_path) 802 | self.attachmentMetric(sess, runs=200, mode='testing') 803 | return 804 | 805 | ''' 806 | Runs features through the network and gets logits 807 | 808 | 'features' must be a list with length being 809 | the number of major feature groups 810 | - Each major index will represent feature group 811 | - Each minor index will represent an id in that feature group 812 | ''' 813 | def feedForward(self, sess, features, mode): 814 | assert mode == 'train' or mode == 'eval' 815 | 816 | nodes = None 817 | if mode == 'train': 818 | # training feed-forward never returns exponentially averaged value 819 | nodes = self.training 820 | else: 821 | # evaluation returns exponentially averaged value if enabled 822 | nodes = self.evaluation 823 | 824 | if len(nodes) == 0: 825 | # if not already built... 
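            # (Lazy build: the first feedForward() call for a mode constructs
            # that mode's graph. For 'eval' this passes
            # return_average=self.use_averaging down to addParam, so the
            # logits come from the moving-average shadow parameters rather
            # than the raw trained weights.)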
826 | self.buildNetwork(mode) 827 | 828 | assert len(nodes['feature_endpoints']) == \ 829 | len(features), 'feature group count must match' 830 | 831 | feed_dict = {} 832 | for k in range(len(nodes['feature_endpoints'])): 833 | feed_dict[nodes['feature_endpoints'][k]] = \ 834 | np.asarray(features[k]).reshape( \ 835 | [-1, len(self.featureNames[k])]) 836 | 837 | logits = sess.run(nodes['logits'], feed_dict=feed_dict) 838 | return np.asarray(logits) 839 | 840 | def attachmentMetric(self, sess, runs=200, mode='testing'): 841 | batchSize = self.modelParams.cfg['batchSize'] 842 | transitionSystem = self.modelParams.cfg['transitionSystem'] 843 | 844 | #batchSize = 128 # let's try a smaller batch for evaluation 845 | featureStrings = self.modelParams.cfg['featureStrings'] 846 | topK = self.modelParams.cfg['topK'] 847 | 848 | assert mode == 'testing' or mode == 'training' 849 | 850 | testingCorpus = ParsedConllFile() 851 | if mode == 'testing': 852 | testingCorpus.read(open(self.modelParams.testingFile, 'r', 853 | encoding='utf-8').read()) 854 | elif mode == 'training': 855 | testingCorpus.read(open(self.modelParams.trainingFile, 'r', 856 | encoding='utf-8').read()) 857 | 858 | # evaluate sentence-wide accuracy by UAS and LAS 859 | # of course, token errors can accumulate and this is why sentence-wide 860 | # accuracy is lower than token-only accuracy given by 861 | # quickEvaluationMetric() 862 | 863 | # batch size set at one temporarily 864 | test_reader_decoded = DecodedParseReader(testingCorpus, \ 865 | batchSize, featureStrings, self.featureMaps, transitionSystem, 866 | epoch_print=False) 867 | 868 | correctActions = 0 869 | correctElems = 0 870 | totalElems = 0 871 | 872 | outputs = [] 873 | 874 | filled_count = 0 875 | 876 | # eventually will be (filled_count, num_actions) 877 | logits = np.asarray([]) 878 | 879 | test_runs = runs 880 | for i in range(test_runs): 881 | logger.debug('Evaluation(batch %d)' % i) 882 | test_reader_output = test_reader_decoded.nextFeatureBags( 883 | logits, filled_count) 884 | 885 | if test_reader_output[0] == None: 886 | logger.critical('Reader error') 887 | return 888 | 889 | features_major_types, features_output, epochs, \ 890 | filled_count = test_reader_output 891 | 892 | logits = self.feedForward(sess=sess, features=features_output, 893 | mode='eval') 894 | 895 | logger.info('Evaluating batch %d/%d...' 
% (i+1, test_runs)) 896 | 897 | sentences = test_reader_decoded.getNextAnnotations() 898 | outputs.append(sentences) 899 | 900 | token_count = 0 901 | deprel_correct = 0 902 | head_correct = 0 903 | deprel_and_head_correct = 0 904 | 905 | for sentences in outputs: 906 | logger.info('-'*20) 907 | for sentence in sentences: 908 | logger.info('-'*20) 909 | #logger.info([w for w in sentence.tokens]) 910 | for w in sentence.tokens: 911 | suffix = '' 912 | 913 | gold_head = w.HEAD 914 | gold_deprel = w.DEPREL 915 | if gold_head == -1: 916 | gold_deprel = 'ROOT' 917 | 918 | if w.parsedHead == -1: 919 | # make it simple 920 | w.parsedLabel = 'ROOT' 921 | 922 | if shouldScoreToken(w.FORM, w.UPOSTAG, 923 | self.modelParams.scoring_strategy): 924 | if w.parsedLabel == gold_deprel: 925 | deprel_correct += 1 926 | else: 927 | suffix = 'L' 928 | 929 | if w.parsedHead == gold_head: 930 | head_correct += 1 931 | else: 932 | suffix += 'H' 933 | 934 | if w.parsedLabel == gold_deprel and \ 935 | w.parsedHead == gold_head: 936 | deprel_and_head_correct += 1 937 | # mark both correct 938 | suffix = 'O' 939 | 940 | token_count += 1 941 | 942 | if w.parsedHead == -1: 943 | logger.info('%-20s%-10s%-5d%-5s' % \ 944 | (w.FORM, 'ROOT', w.parsedHead, suffix)) 945 | else: 946 | logger.info('%-20s%-10s%-5d%-5s' % \ 947 | (w.FORM, w.parsedLabel, w.parsedHead, suffix)) 948 | else: 949 | logger.debug('Not scoring token: form="%s", tag="%s"' \ 950 | % (w.FORM, w.UPOSTAG)) 951 | 952 | if token_count <= 0: 953 | logger.warning('No tokens to calculate Attachment Error Metric') 954 | return 955 | 956 | # errors that accumulate (tokens are tested based on previous decoded 957 | # decisions, which could screw up shifting and arcing, etc) 958 | # SyntaxNet uses UAS (HEAD-only) for its evaluation during training! 
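        # Worked example of the three ratios printed below (illustrative
        # numbers only): if 100 tokens are scored, 85 get the right HEAD,
        # 88 get the right DEPREL, and 80 get both right, then
        # UAS = 85/100 = 85.00%, DepRel = 88/100 = 88.00% and
        # LAS = 80/100 = 80.00%; LAS can never exceed either of the others.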
959 | logger.info('Attachment Error Metric (%s_set)' % mode) 960 | logger.info('Scoring Strategy: %s' % \ 961 | self.modelParams.scoring_strategy) 962 | logger.info('Accuracy(UAS): %d/%d (%.2f%%)' % \ 963 | (head_correct, token_count, 964 | 100.0 * float(head_correct) / float(token_count))) 965 | logger.info('Accuracy(LAS): %d/%d (%.2f%%)' % \ 966 | (deprel_and_head_correct, token_count, 967 | 100.0 * float(deprel_and_head_correct) / float(token_count))) 968 | logger.info('Accuracy(DepRel): %d/%d (%.2f%%)' % \ 969 | (deprel_correct, token_count, 970 | 100.0 * float(deprel_correct) / float(token_count))) 971 | 972 | def __main__(): 973 | modelParams = ModelParameters(args.model_folder) 974 | modelParams.trainingFile = args.training_file 975 | modelParams.testingFile = args.testing_file 976 | 977 | # set variables from parser-config and isolate them in a separate namespace 978 | # to avoid collisions with this code 979 | fd = open(modelParams.getFilePath('parser-config'), 'r', \ 980 | encoding='utf-8') 981 | configFile = fd.read() 982 | fd.close() 983 | 984 | fd = open(modelParams.getFilePath('trained-config'), 'w', \ 985 | encoding='utf-8') 986 | fd.write(configFile) 987 | fd.close() 988 | 989 | compile(configFile, '', 'exec') 990 | configNamespace = {} 991 | exec(configFile, configNamespace) 992 | 993 | requiredFields = ['learningRate', 'batchSize', 'topK', 994 | 'hiddenLayerSizes', 'embeddingSizes', 'featureStrings', 995 | 'momentum', 'projectivizeTrainingSet', 'transitionSystem'] 996 | for field in requiredFields: 997 | assert configNamespace[field] != None, 'please set %s in config' % field 998 | 999 | modelParams.cfg = configNamespace 1000 | modelParams.lexicon = Lexicon(modelParams) 1001 | modelParams.scoring_strategy = args.scoring_strategy 1002 | 1003 | if args.train: 1004 | config = tf.ConfigProto() 1005 | #config.gpu_options.allow_growth=True 1006 | #config.gpu_options.per_process_gpu_memory_fraction=1.0 1007 | 1008 | # very important for Parser to be under a session scope 1009 | with tf.Session(config=config) as sess: 1010 | parser = Parser(modelParams) 1011 | #print(parser.inits.values()) 1012 | 1013 | # perform variable initialization 1014 | 1015 | parser.setupParser('train') 1016 | parser.buildNetwork('train') 1017 | sess.run(list(parser.inits.values())) 1018 | 1019 | writer = tf.summary.FileWriter(modelParams.modelFolder, \ 1020 | graph=tf.get_default_graph()) 1021 | 1022 | parser.startTraining(sess, epochs_to_run=args.epochs, 1023 | restart=args.restart) 1024 | else: 1025 | assert None, 'evaluation mode not implemented' 1026 | 1027 | __main__() 1028 | -------------------------------------------------------------------------------- /feature_extractor.py: -------------------------------------------------------------------------------- 1 | # based on parser_features.cc 2 | 3 | import logging 4 | from conll_utils import ParsedConllSentence, ParsedConllToken 5 | from parser_state import ParserState 6 | 7 | GlobalFeatureStringCache = dict() 8 | 9 | ''' 10 | Represents a feature (input.tag, stack.child(-1).sibling(1).label, etc) 11 | 12 | Currently only tag, word, and label are possible 13 | ''' 14 | class FeatureType(object): 15 | KNOWN_FEATURE_TYPES = ['tag', 'label', 'word'] 16 | 17 | def __init__(self, feature_major_type, feature_name): 18 | assert feature_major_type in FeatureType.KNOWN_FEATURE_TYPES, \ 19 | 'unsupported feature major type ' + str(feature_major_type) 20 | 21 | # 'label', etc 22 | self.major_type = feature_major_type 23 | 24 | # 
'stack.child(-1).sibling(1).label', etc 25 | self.name = feature_name 26 | 27 | ''' 28 | Decodes a feature separated between a dot value into a feature name and 29 | argument list 30 | 31 | Only supports integer arguments for now 32 | 33 | Input: 34 | FeatureString: label 35 | 36 | Output: 37 | FeatureName: label 38 | FeatureArgs: [] 39 | -- 40 | Input: 41 | FeatureString: input(0) 42 | 43 | Output: 44 | FeatureName: input 45 | FeatureArgs: [0] 46 | -- 47 | Input: 48 | FeatureString: xxx(0,5) 49 | 50 | Output: 51 | FeatureName: xxx 52 | FeatureArgs: [0, 5] 53 | ''' 54 | def decodeFeatureString(featureString): 55 | if '(' in featureString: 56 | featureName = featureString.split('(')[0] 57 | tmp = featureString.split('(')[1].split(')')[0].split(',') 58 | featureArgs = [] 59 | 60 | for t in tmp: 61 | t = t.strip() 62 | assert t.lstrip('-').isdecimal() 63 | featureArgs.append(int(t)) 64 | 65 | return featureName, featureArgs 66 | else: 67 | return featureString, [] 68 | 69 | ''' 70 | Represents all feature groups' values retrieved for one parser state at once 71 | ''' 72 | class FeatureVector(object): 73 | def __init__(self): 74 | self.types = [] 75 | self.values = [] 76 | 77 | ''' 78 | Returns values concatenated for similar feature major types 79 | (like feature groups per-token) 80 | 81 | However in reality we usually concatenate features with each other at a 82 | batch-level, not at a token-level 83 | ''' 84 | def concatenateSimilarTypes(self): 85 | all_major_types = set() 86 | for t in self.types: 87 | all_major_types.add(t.major_type) 88 | 89 | all_major_types = list(all_major_types) 90 | # for consistency 91 | all_major_types.sort() 92 | 93 | concat_major_types = [] 94 | concat_values = [] 95 | 96 | for t in all_major_types: 97 | concat_major_types.append(t) 98 | concat_values.append([]) 99 | for i in range(len(self.types)): 100 | if self.types[i].major_type == t: 101 | concat_values[-1].append(self.values[i]) 102 | return concat_major_types, concat_values 103 | 104 | ''' 105 | Given feature strings, returns FeatureVector for a particular parser state 106 | ''' 107 | class SparseFeatureExtractor(object): 108 | def __init__(self, feature_strings, feature_maps): 109 | self.feature_strings = feature_strings 110 | self.feature_maps = feature_maps 111 | self.logger = logging.getLogger('SparseFeatureExtractor') 112 | 113 | ''' 114 | doLogging=False: don't log if we're just in init mode where we determine 115 | major feature types during initialization, etc... 116 | ''' 117 | def extract(self, parser, doLogging=True): 118 | fvec = FeatureVector() 119 | for fstr in self.feature_strings: 120 | ftype, fval = self.extractOne(parser, fstr, doLogging=doLogging) 121 | fvec.types.append(ftype) 122 | fvec.values.append(fval) 123 | return fvec 124 | 125 | ''' 126 | featureString: stack(1).child(-1).sibling(1).word 127 | doLogging=False: don't log if we're just in init mode 128 | 129 | This function was optimized for speed, so although previously 130 | each Locator was a separate class, they have now been 131 | inlined into this function. 
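    As a walked example (descriptive only, no behaviour change):
    'stack(1).child(-1).word' first sets the focus to the token one below
    the top of the stack, then moves the focus to that token's leftmost
    child (a negative child() argument means leftmost, positive means
    rightmost), and finally looks the focused token's FORM up in the 'word'
    feature map. Out-of-range focuses collapse to the sentinel -2, and both
    that case and out-of-vocabulary forms fall back to the special unknown
    index.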
132 | ''' 133 | def extractOne(self, parser, featureString, doLogging=True): 134 | global GlobalFeatureStringCache 135 | 136 | if featureString not in GlobalFeatureStringCache: 137 | featureParts = featureString.split('.') 138 | 139 | # must reference at least one focus and at least one feature 140 | # (tag/label/etc), therefore, at least two elements 141 | assert len(featureParts) >= 2 142 | 143 | # featureParts: ['stack(1)', 'child(-1)', 'sibling(1)', 'word'] 144 | decodedParts = [] 145 | for p in featureParts: 146 | p = p.strip() 147 | decodedParts.append(decodeFeatureString(p)) 148 | # decodedParts: [('stack', [1]), ('child', [-1]), ('sibling', [1]), 149 | # ('word', [])] 150 | assert len(decodedParts) >= 2 151 | assert(decodedParts[0][0] == 'input' or \ 152 | decodedParts[0][0] == 'stack') 153 | 154 | GlobalFeatureStringCache[featureString] = decodedParts 155 | else: 156 | decodedParts = GlobalFeatureStringCache[featureString] 157 | 158 | # start setting focus and follow focus modifiers until real feature 159 | # (tag/label/etc) 160 | focus = None 161 | feature_name = featureString 162 | feature_major_type = None #featureString.split('.')[-1] 163 | feature_index = None 164 | 165 | for d in decodedParts: 166 | fname = d[0] 167 | fargs = d[1] 168 | if fname == 'input': 169 | ''' 170 | InputParserLocator 171 | args[0]: optional: n index of input(n) 172 | if not specified, index 0 is looked up 173 | ''' 174 | assert feature_index == None, \ 175 | 'can\'t update focus if feature is already set' 176 | if len(fargs) == 0: 177 | fargs=[0] 178 | focus = parser.input(fargs[0]) 179 | elif fname == 'stack': 180 | ''' 181 | StackParserLocator 182 | args[0]: optional: n index of stack(n) 183 | if not specified, index 0 is looked up 184 | ''' 185 | assert feature_index == None, \ 186 | 'can\'t update focus if feature is already set' 187 | if len(fargs) == 0: 188 | fargs=[0] 189 | focus = parser.stack(fargs[0]) 190 | elif fname == 'head': 191 | ''' 192 | HeadFeatureLocator 193 | Arguments: args[0]: number of times to call head() function 194 | ''' 195 | assert focus != None, 'can\'t take HEAD of null focus' 196 | assert feature_index == None, \ 197 | 'can\'t update focus if feature is already set' 198 | 199 | assert len(fargs) == 1 200 | levels = fargs[0] 201 | assert levels >= 1 202 | 203 | # same logic as SyntaxNet 204 | if (focus < -1) or (focus >= parser.numTokens()): 205 | focus = -2 206 | else: 207 | focus = parser.parent(focus, levels) 208 | elif fname == 'child': 209 | ''' 210 | ChildFeatureLocator 211 | Arguments: args[0]: get n'th child 212 | (< 0 indicates leftmost, > 0 indicates rightmost) 213 | ''' 214 | 215 | assert focus != None, 'can\'t take CHILD of null focus' 216 | assert feature_index == None, \ 217 | 'can\'t update focus if feature is already set' 218 | 219 | levels = fargs[0] 220 | assert levels != 0 221 | 222 | # same logic as SyntaxNet 223 | if (focus < -1) or (focus >= parser.numTokens()): 224 | if doLogging: 225 | self.logger.debug('ChildFeatureLocator: focus=-2') 226 | focus = -2 227 | else: 228 | oldfocus = focus 229 | if (levels < 0): 230 | focus = parser.leftmostChild(focus, -levels) 231 | if doLogging: 232 | self.logger.debug( \ 233 | 'ChildFeatureLocator: leftmostChild: ' \ 234 | ' levels=%d,' 235 | ' focus=%d->%d' % (levels, oldfocus, focus)) 236 | else: 237 | focus = parser.rightmostChild(focus, levels) 238 | if doLogging: 239 | self.logger.debug( \ 240 | 'ChildFeatureLocator: rightmostChild: ' \ 241 | ' levels=%d,' \ 242 | ' focus=%d->%d' % (levels, oldfocus, 
focus)) 243 | elif fname == 'sibling': 244 | ''' 245 | SiblingFeatureLocator 246 | Arguments: args[0]: get n'th sibling 247 | (< 0 indicates to left, > 0 indicates to right) 248 | ''' 249 | 250 | assert focus != None, 'can\'t take SIBLING of null focus' 251 | assert feature_index == None, \ 252 | 'can\'t update focus if feature is already set' 253 | position = fargs[0] 254 | assert position != 0 255 | 256 | # same logic as SyntaxNet 257 | if (focus < -1) or (focus >= parser.numTokens()): 258 | if doLogging: 259 | self.logger.debug('SiblingFeatureLocator: focus=-2') 260 | focus = -2 261 | else: 262 | oldfocus = focus 263 | if (position < 0): 264 | focus = parser.leftSibling(focus, -position) 265 | if doLogging: 266 | self.logger.debug( \ 267 | 'SiblingFeatureLocator: leftSibling: ' \ 268 | 'position=%d, ' \ 269 | 'focus=%d->%d' % (position, oldfocus, focus)) 270 | else: 271 | focus = parser.rightSibling(focus, position) 272 | if doLogging: 273 | self.logger.debug( \ 274 | 'SiblingFeatureLocator: rightSibling: ' \ 275 | 'position=%d, ' \ 276 | 'focus=%d->%d' % (position, oldfocus, focus)) 277 | else: 278 | assert focus != None, 'can\'t request feature of null focus' 279 | assert feature_index == None, \ 280 | 'can\'t request feature when feature is already set; ' \ 281 | 'nested features not supported' 282 | 283 | if doLogging: 284 | self.logger.debug('focus: %d' % focus) 285 | 286 | if fname == 'label': 287 | feature_major_type = 'label' 288 | if focus == -1: 289 | feature_index = \ 290 | self.feature_maps[feature_major_type] \ 291 | .valueToIndex('') 292 | 293 | if doLogging: 294 | self.logger.debug('%s: %d (%s)' % \ 295 | (feature_name, feature_index, '')) 296 | 297 | elif focus < -1 or focus >= parser.numTokens(): 298 | feature_index = \ 299 | self.feature_maps[feature_major_type] \ 300 | .valueToIndex('') 301 | 302 | if doLogging: 303 | self.logger.debug('%s: %d (%s)' % \ 304 | (feature_name, feature_index, '')) 305 | 306 | else: 307 | # pulls label from parser itself, which means it won't 308 | # be gold as long as parser wasn't initialized with 309 | # gold labels 310 | feature_index = parser.label(focus) 311 | 312 | if feature_index == -1: 313 | feature_index = \ 314 | self.feature_maps[feature_major_type] \ 315 | .valueToIndex('') 316 | 317 | if doLogging: 318 | self.logger.debug('%s: %d (%s)' % \ 319 | (feature_name, feature_index, '')) 320 | else: 321 | if doLogging: 322 | self.logger.debug('%s: %d (%s)' % \ 323 | (feature_name, feature_index, \ 324 | self.feature_maps[feature_major_type] \ 325 | .indexToValue(parser.label(focus)))) 326 | 327 | elif fname == 'word': 328 | feature_major_type = 'word' 329 | if focus < 0 or focus >= parser.numTokens(): 330 | feature_index = \ 331 | self.feature_maps[feature_major_type] \ 332 | .valueToIndex('') 333 | 334 | if doLogging: 335 | self.logger.debug('%s: %d (%s)' % \ 336 | (feature_name, feature_index, '')) 337 | else: 338 | try: 339 | feature_index = \ 340 | self.feature_maps[feature_major_type] \ 341 | .valueToIndex(parser.getToken( 342 | focus).FORM) 343 | 344 | if doLogging: 345 | self.logger.debug('%s: %d (%s)' % \ 346 | (feature_name, feature_index, \ 347 | parser.getToken(focus).FORM)) 348 | except: # Out of Vocabulary 349 | feature_index = \ 350 | self.feature_maps[feature_major_type] \ 351 | .valueToIndex('') 352 | 353 | if doLogging: 354 | self.logger.debug('%s: %d (%s)' % \ 355 | (feature_name, feature_index, '')) 356 | elif fname == 'tag': 357 | feature_major_type = 'tag' 358 | if focus < 0 or focus >= parser.numTokens(): 359 | 
feature_index = self.feature_maps[feature_major_type] \ 360 | .valueToIndex('') 361 | 362 | if doLogging: 363 | self.logger.debug('%s: %d (%s)' % (feature_name, \ 364 | feature_index, '')) 365 | else: 366 | try: 367 | feature_index = \ 368 | self.feature_maps[feature_major_type] \ 369 | .valueToIndex(parser.getToken( 370 | focus).XPOSTAG) 371 | 372 | if doLogging: 373 | self.logger.debug('%s: %d (%s)' % \ 374 | (feature_name, feature_index, 375 | parser.getToken(focus).XPOSTAG)) 376 | except: # Out of Vocabulary 377 | feature_index = \ 378 | self.feature_maps[feature_major_type] \ 379 | .valueToIndex('') 380 | 381 | if doLogging: 382 | self.logger.debug('%s: %d (%s)' % \ 383 | (feature_name, feature_index, '')) 384 | else: 385 | assert None, 'unknown feature name \'' + fname + '\'' 386 | 387 | assert feature_name != None, 'feature name undetermined' 388 | assert feature_major_type != None, 'feature major type undetermined' 389 | assert feature_index != None, 'focus set but feature never requested' 390 | return FeatureType(feature_major_type, feature_name), feature_index 391 | -------------------------------------------------------------------------------- /feature_map.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Represents a feature encoded as a sparse index 3 | ''' 4 | class UnsortedIndexEncodedFeatureMap(object): 5 | def __init__(self): 6 | # base value doesn't exist 7 | # (don't set to 0, as 0 should be a valid index in that case) 8 | self.lastBaseValue = -1 9 | self.isFinalized = False 10 | self.indexToValueMap = dict() 11 | self.valueToIndexMap = dict() 12 | 13 | ''' 14 | Increment frequency for the specified term 15 | If we've never seen this term before, make an entry for it 16 | ''' 17 | def addTerm(self, term): 18 | assert not self.isFinalized 19 | 20 | self.lastBaseValue += 1 21 | self.indexToValueMap[self.lastBaseValue] = term 22 | self.valueToIndexMap[term] = self.lastBaseValue 23 | 24 | ''' 25 | Finalize indices and sort by descending frequency of each term, and then 26 | alphabetically 27 | Index 0 will be the most frequent term 28 | ''' 29 | def finalizeBaseValues(self): 30 | assert not self.isFinalized 31 | assert len(self.valueToIndexMap) == len(self.indexToValueMap), \ 32 | 'index<->value map length mismatch' 33 | self.isFinalized = True 34 | 35 | ''' 36 | Append special value after finalization, like , etc... 
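    (The lexicon appends these after finalizeBaseValues(); they are the
    unknown/root/outside style markers referred to elsewhere as UNKNOWN,
    ROOT and OUTSIDE, so they always occupy the highest indices and are
    excluded when getDomainSize(includeSpecial=False) is requested.)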
37 | ''' 38 | def appendSpecialValue(self, term): 39 | assert self.isFinalized 40 | if term in self.valueToIndexMap: 41 | return # no need to add another index for it 42 | 43 | newTermIndex = len(self.valueToIndexMap) 44 | assert newTermIndex not in self.indexToValueMap 45 | self.indexToValueMap[newTermIndex] = term 46 | self.valueToIndexMap[term] = newTermIndex 47 | 48 | assert len(self.valueToIndexMap) == len(self.indexToValueMap) 49 | 50 | def valueToIndex(self, v): 51 | assert self.isFinalized 52 | return self.valueToIndexMap[v] 53 | 54 | def indexToValue(self, i): 55 | assert self.isFinalized 56 | return self.indexToValueMap[i] 57 | 58 | ''' 59 | Get the number of possible unique values for this feature 60 | (optionally excluding special features) 61 | ''' 62 | def getDomainSize(self, includeSpecial=True): 63 | assert self.isFinalized 64 | if includeSpecial: 65 | return len(self.valueToIndexMap) 66 | else: 67 | return self.lastBaseValue + 1 68 | 69 | ''' 70 | Represents a feature encoded as a sparse index 71 | 72 | Sorts base values by frequency in descending order and then name in ascending 73 | order 74 | 75 | Sorting ensures equivalent behavior per run 76 | ''' 77 | class IndexEncodedFeatureMap(UnsortedIndexEncodedFeatureMap): 78 | def __init__(self): 79 | super().__init__() 80 | self.freq = dict() 81 | 82 | def addTerm(self, term): 83 | assert None, 'addTerm() not allowed in IndexEncodedFeatureMap' 84 | 85 | ''' 86 | Increment frequency for the specified term 87 | If we've never seen this term before, make an entry for it 88 | ''' 89 | def incrementTerm(self, term): 90 | assert not self.isFinalized 91 | if term not in self.freq: 92 | self.freq[term] = 0 93 | self.freq[term] += 1 94 | 95 | def loadFrom(self, fname): 96 | assert not self.isFinalized 97 | fd = open(fname, 'r', encoding='utf-8') 98 | contents = fd.read() 99 | fd.close() 100 | 101 | ln_num = 0 102 | itemCount = 0 103 | currentItem = 0 104 | for ln in contents: 105 | ln = ln.strip() 106 | if not ln: 107 | continue 108 | ln_num += 1 109 | if ln_num == 1: 110 | itemCount = int(ln) 111 | else: 112 | assert ln.count(' ') == 2 113 | 114 | term, freq = ln.split() 115 | term = term.strip() 116 | freq = int(freq) 117 | 118 | assert term not in self.freq, 'term already loaded' 119 | 120 | self.freq[term] = freq 121 | 122 | currentItem += 1 123 | if currentItem >= itemCount: 124 | break 125 | 126 | assert currentItem == itemCount, 'not all items loaded properly' 127 | 128 | # caller should do finalization 129 | # sets finalize flag and counts base values 130 | # self.finalizeBaseValues() 131 | 132 | def writeTo(self, fname): 133 | assert self.isFinalized 134 | fd = open(fname, 'w', encoding='utf-8') 135 | 136 | itemCount = self.lastBaseValue+1 # 0-based indexing 137 | fd.write('%d\n' % itemCount) 138 | 139 | for (i, term) in self.indexToValueMap.items(): 140 | if i > self.lastBaseValue: # don't write special values 141 | break 142 | fd.write('%s %d\n' % (term, self.freq[term])) 143 | 144 | fd.close() 145 | 146 | ''' 147 | Finalize indices and sort by descending frequency of each term, and then 148 | alphabetically 149 | Index 0 will be the most frequent term 150 | ''' 151 | def finalizeBaseValues(self): 152 | assert not self.isFinalized 153 | round1Items = [] 154 | 155 | allFreqs = set() 156 | for (termName, termFreq) in self.freq.items(): 157 | round1Items.append((termFreq, termName)) 158 | allFreqs.add(termFreq) 159 | allFreqs = list(allFreqs) 160 | allFreqs.sort(reverse=True) 161 | 162 | sortFinal = [] 163 | # iterate 
frequencies in descending order 164 | for f in allFreqs: 165 | round2Tmp = [] 166 | # find all items with this frequency and sort them by name, 167 | # ascending 168 | for (termFreq, termName) in round1Items: 169 | if termFreq == f: 170 | round2Tmp.append(termName) 171 | assert len(round2Tmp) > 0, 'term not found' 172 | round2Tmp.sort() 173 | # append all term names with this frequency, sorted by name 174 | sortFinal += round2Tmp 175 | 176 | assert len(self.freq) == len(sortFinal), 'missing items detected' 177 | assert len(set(sortFinal)) == len(sortFinal), 'duplicates detected' 178 | 179 | i = 0 180 | self.indexToValueMap = dict() 181 | for v in sortFinal: 182 | self.indexToValueMap[i] = v 183 | self.lastBaseValue = i 184 | i += 1 185 | 186 | self.valueToIndexMap = dict() 187 | for (i, v) in self.indexToValueMap.items(): 188 | self.valueToIndexMap[v] = i 189 | 190 | assert len(self.valueToIndexMap) == len(self.indexToValueMap), \ 191 | 'index<->value map length mismatch' 192 | self.isFinalized = True 193 | -------------------------------------------------------------------------------- /gold_parse_reader.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from feature_extractor import SparseFeatureExtractor 3 | from sentence_batch import SentenceBatch 4 | from parser_state import ParserState 5 | from arc_standard_transition_system import ArcStandardTransitionSystem, \ 6 | ArcStandardTransitionState 7 | from arc_eager_transition_system import ArcEagerTransitionState, \ 8 | ArcEagerTransitionSystem 9 | 10 | ''' 11 | Verify that GoldParseReader parsed a sentence properly 12 | ''' 13 | def verifyGoldSentenceIntegrity(state): 14 | for k in range(state.numTokens()): 15 | assert state.head(k) == state.goldHead(k), '%d, %s, %d!=%d' % \ 16 | (k, state.getToken(k).FORM, state.head(k), state.goldHead(k)) 17 | 18 | ''' 19 | Provide a batch of gold sentences to the trainer 20 | 21 | Maintains batch_size slots of sentences, each one with its own parser state 22 | ''' 23 | class GoldParseReader(object): 24 | def __init__(self, input_corpus, batch_size, feature_strings, feature_maps, 25 | transition_system, epoch_print = True): 26 | self.input_corpus = input_corpus 27 | self.batch_size = batch_size 28 | self.feature_strings = feature_strings 29 | self.feature_maps = feature_maps 30 | self.epoch_print = epoch_print 31 | self.feature_extractor = SparseFeatureExtractor(self.feature_strings, 32 | self.feature_maps) 33 | self.sentence_batch = SentenceBatch(input_corpus, self.batch_size) 34 | self.parser_states = [None for i in range(self.batch_size)] 35 | self.arc_states = [None for i in range(self.batch_size)] 36 | 37 | if transition_system == 'arc-standard': 38 | self.transition_system = ArcStandardTransitionSystem() 39 | self.transition_state_class = ArcStandardTransitionState 40 | elif transition_system == 'arc-eager': 41 | self.transition_system = ArcEagerTransitionSystem() 42 | self.transition_state_class = ArcEagerTransitionState 43 | else: 44 | assert None, 'transition system must be arc-standard or arc-eager' 45 | 46 | self.logger = logging.getLogger('GoldParseReader') 47 | self.num_epochs = 0 48 | 49 | def state(self, i): 50 | assert i >= 0 and i < self.batch_size 51 | return self.parser_states[i] 52 | 53 | ''' 54 | Advance the sentence for slot i 55 | ''' 56 | def advanceSentence(self, i): 57 | self.logger.debug('Slot(%d): advance sentence' % i) 58 | assert i >= 0 and i < self.batch_size 59 | if(self.sentence_batch.advanceSentence(i)): 60 | 
self.parser_states[i] = ParserState(self.sentence_batch.sentence(i), 61 | self.feature_maps) 62 | # necessary for initializing and pushing root 63 | # keep arc_states in sync with parser_states 64 | self.arc_states[i] = \ 65 | self.transition_state_class(self.parser_states[i]) 66 | else: 67 | self.parser_states[i] = None 68 | self.arc_states[i] = None 69 | 70 | ''' 71 | Perform the next gold action for each state 72 | ''' 73 | def performActions(self): 74 | for i in range(self.batch_size): 75 | if self.state(i) != None: 76 | self.logger.debug('Slot(%d): perform actions' % i) 77 | 78 | nextGoldAction = \ 79 | self.transition_system.getNextGoldAction(self.state(i)) 80 | 81 | #print('nextGoldAction:', nextGoldAction) 82 | 83 | self.logger.debug('Slot(%d): perform action %d=%s' % 84 | (i, nextGoldAction, self.transition_system.actionAsString( 85 | nextGoldAction, self.state(i), self.feature_maps))) 86 | 87 | try: 88 | self.transition_system.performAction( 89 | action=nextGoldAction, 90 | state=self.state(i)) 91 | except: 92 | self.logger.debug( 93 | 'Slot(%d): invalid action at batch slot' % i) 94 | # This is probably because of a non-projective input 95 | # We could projectivize or remove it... 96 | self.transition_system.performAction( 97 | action=self.transition_system.getDefaultAction( 98 | self.state(i)), 99 | state=self.state(i)) 100 | 101 | ''' 102 | Concatenate and return feature bags for all sentence slots, grouped 103 | by feature major type 104 | 105 | Returns (None, None, None, ...) if no sentences left 106 | ''' 107 | def nextFeatureBags(self): 108 | self.performActions() 109 | for i in range(self.batch_size): 110 | if self.state(i) == None: 111 | continue 112 | 113 | while(self.transition_system.isFinalState(self.state(i))): 114 | verifyGoldSentenceIntegrity(self.state(i)) 115 | self.logger.debug('Advancing sentence ' + str(i)) 116 | self.advanceSentence(i) 117 | if self.state(i) == None: 118 | break 119 | 120 | if self.sentence_batch.size() == 0: 121 | self.num_epochs += 1 122 | if self.epoch_print: 123 | self.logger.info('Starting epoch ' + str(self.num_epochs)) 124 | self.sentence_batch.rewind() 125 | for i in range(self.batch_size): 126 | self.advanceSentence(i) 127 | 128 | # a little bit different from SyntaxNet: 129 | # we don't support feature groups 130 | # we automatically group together the similar types 131 | # features_output = [[] for i in range(self.feature_strings)] 132 | features_major_types = None 133 | features_output = None 134 | gold_actions = None 135 | 136 | statesToExtract = [] 137 | # Populate feature outputs 138 | for i in range(self.batch_size): 139 | if self.state(i) == None: 140 | continue 141 | statesToExtract.append(self.state(i)) 142 | 143 | # Populate feature outputs 144 | for i in range(self.batch_size): 145 | if self.state(i) == None: 146 | continue 147 | 148 | self.logger.debug('Slot(%d): extract features' % i) 149 | 150 | ''' 151 | If you want to enable more detailed logging, please set 152 | doLogging here. Disabled for performance. 
153 | ''' 154 | fvec = self.feature_extractor.extract(self.state(i), \ 155 | doLogging=False) 156 | assert len(fvec.types) == len(self.feature_strings) 157 | major_types, ids = fvec.concatenateSimilarTypes() 158 | 159 | if features_output == None: 160 | features_major_types = [t for t in major_types] 161 | features_output = [[] for t in major_types] 162 | else: 163 | assert len(features_major_types) == len(major_types) 164 | assert len(features_output) == len(major_types) 165 | 166 | for k in range(len(features_major_types)): 167 | features_output[k] += ids[k] 168 | 169 | # Fill in gold actions 170 | for i in range(self.batch_size): 171 | if self.state(i) != None: 172 | if gold_actions == None: 173 | gold_actions = [] 174 | 175 | try: 176 | gold_actions.append( 177 | self.transition_system.getNextGoldAction(self.state(i))) 178 | except: 179 | self.logger.info('Warning: invalid batch slot') 180 | gold_actions.append( 181 | self.transition_system.getDefaultAction(self.state(i))) 182 | 183 | return features_major_types, features_output, gold_actions, \ 184 | self.num_epochs 185 | -------------------------------------------------------------------------------- /lexicon.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Represents a lexicon, which describes all encountered tokens and frequencies, 3 | along with unknown tokens. 4 | 5 | The lexicon is typically computed during training time. 6 | ''' 7 | 8 | from utils import * 9 | from conll_utils import ParsedConllFile, ParsedConllSentence, ParsedConllToken 10 | from feature_map import IndexEncodedFeatureMap 11 | 12 | class Lexicon(object): 13 | def __init__(self, modelParams): 14 | self.modelParams = modelParams 15 | 16 | self.featureMaps = None 17 | 18 | self.tagMap = IndexEncodedFeatureMap() 19 | self.labelMap = IndexEncodedFeatureMap() 20 | self.wordMap = IndexEncodedFeatureMap() 21 | 22 | 23 | ''' 24 | Compute a lexicon (using the training data) 25 | ''' 26 | def compute(self): 27 | projectivizeTrainingSet = self.modelParams.cfg \ 28 | ['projectivizeTrainingSet'] 29 | 30 | # parameters here must match parameters during corpus feature bag 31 | # generation (such as projectivization) 32 | trainingData = ParsedConllFile(keepMalformed=False, 33 | projectivize=projectivizeTrainingSet, logStats=True) 34 | # log stats here instead of during bag-of-features generation 35 | # because lexicon computation always happens during training 36 | 37 | trainingData.read(open(self.modelParams.trainingFile, 'r', 38 | encoding='utf-8').read()) 39 | 40 | for sentence in trainingData: 41 | for token in sentence.tokens: 42 | # for SyntaxNet, 43 | # normalization ONLY happens in lexicon builder 44 | # yet numbers and up as during training 45 | # interesting... 
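                # (normalizeDigits is defined in utils.py; the intent, as in
                # SyntaxNet's lexicon builder, is presumably to collapse
                # digit characters so that different numerals share a single
                # word-map entry instead of each counting as an OOV form.)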
46 | form = normalizeDigits(token.FORM) 47 | 48 | self.wordMap.incrementTerm(form) 49 | self.tagMap.incrementTerm(token.XPOSTAG) 50 | self.labelMap.incrementTerm(token.DEPREL) 51 | 52 | self.finalizeLexicon() 53 | 54 | 55 | def read(self): 56 | self.tagMap = IndexEncodedFeatureMap().loadFrom( 57 | self.modelParams.getFilePath('tag-map')) 58 | self.labelMap = IndexEncodedFeatureMap().loadFrom( 59 | self.modelParams.getFilePath('label-map')) 60 | self.wordMap = IndexEncodedFeatureMap().loadFrom( 61 | self.modelParams.getFilePath('word-map')) 62 | 63 | # special values don't get saved, so we still need to finalize lexicon 64 | self.finalizeLexicon() 65 | 66 | 67 | def write(self): 68 | self.tagMap.writeTo(self.modelParams.getFilePath('tag-map')) 69 | self.labelMap.writeTo(self.modelParams.getFilePath('label-map')) 70 | self.wordMap.writeTo(self.modelParams.getFilePath('word-map')) 71 | 72 | ''' 73 | After done reading corpus... 74 | ''' 75 | def finalizeLexicon(self): 76 | self.wordMap.finalizeBaseValues() 77 | self.tagMap.finalizeBaseValues() 78 | self.labelMap.finalizeBaseValues() 79 | 80 | # order of special tokens matches SyntaxNet 81 | 82 | self.wordMap.appendSpecialValue("") 83 | self.tagMap.appendSpecialValue("") 84 | self.labelMap.appendSpecialValue("") 85 | 86 | self.wordMap.appendSpecialValue("") 87 | self.tagMap.appendSpecialValue("") 88 | self.labelMap.appendSpecialValue("") 89 | 90 | # FIXME: is in tag even possible? it seemed to happen in 91 | # testdata but not in UD_English 92 | # difference between stack.tag and stack.token.tag? 93 | #self.tagMap.appendSpecialValue("") 94 | self.labelMap.appendSpecialValue("") 95 | 96 | self.featureMaps = {'word': self.wordMap, 'tag': self.tagMap, 97 | 'label': self.labelMap} 98 | 99 | 100 | def getFeatureMaps(self): 101 | assert self.featureMaps != None, 'feature maps not yet created' 102 | return self.featureMaps 103 | -------------------------------------------------------------------------------- /model_parameters.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Contains all parameters for a model 3 | ''' 4 | FILE_TYPES = ['label-map', 'word-map', 'tag-map', 'parser-config', 'embeddings', 5 | 'feature-def', 'feature-bag-bin', 'training-corpus-hash', 6 | 'trained-config'] 7 | 8 | def fixPath(m): 9 | if m.endswith('/'): 10 | return m[:-1] 11 | else: 12 | return m 13 | 14 | class ModelParameters(object): 15 | def __init__(self, modelFolder): 16 | assert modelFolder != None 17 | 18 | self.modelFolder = modelFolder 19 | self.trainingFile = None 20 | #self.tuningFile = None 21 | self.testingFile = None 22 | self.cfg = None 23 | self.lexicon = None 24 | 25 | ''' 26 | Returns the filename for the requested file type 27 | Corresponds to files in SyntaxNet context 28 | e.g., word-map, label-map 29 | ''' 30 | def getFilePath(self, fileType): 31 | assert fileType in FILE_TYPES 32 | assert self.modelFolder != None 33 | return '%s/%s' % (fixPath(self.modelFolder), fileType) 34 | 35 | def isValidModel(self): 36 | return self.modelFolder != None and self.trainingFile != None and \ 37 | self.testingFile != None 38 | -------------------------------------------------------------------------------- /other/transition_system_test_framework.py: -------------------------------------------------------------------------------- 1 | from conll_utils import * 2 | from parser_state import ParserState 3 | from arc_standard_transition_system import ArcStandardTransitionState, \ 4 | ArcStandardTransitionSystem 5 | from 
arc_eager_transition_system import ArcEagerTransitionState, \ 6 | ArcEagerTransitionSystem 7 | from gold_parse_reader import GoldParseReader 8 | from decoded_parse_reader import DecodedParseReader 9 | from lexicon import Lexicon 10 | from model_parameters import ModelParameters 11 | import logging 12 | 13 | logger = logging.getLogger('TransitionSystemTest') 14 | 15 | logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', 16 | level=logging.INFO) 17 | 18 | # arc-standard, arc-eager 19 | system = 'arc-eager' 20 | trainingFile = '/home/andy/Downloads/arcs-py/sample-data/train.parses' 21 | 22 | if system == 'arc-standard': 23 | transition_system = ArcStandardTransitionSystem() 24 | transition_system_class = ArcStandardTransitionSystem 25 | transition_state_class = ArcStandardTransitionState 26 | elif system == 'arc-eager': 27 | transition_system = ArcEagerTransitionSystem() 28 | transition_system_class = ArcEagerTransitionSystem 29 | transition_state_class = ArcEagerTransitionState 30 | else: 31 | assert None, 'transition system must be arc-standard or arc-eager' 32 | 33 | 34 | def dynamicOracleTrainTest(parser_state): 35 | LUT = ['SHIFT', 'RIGHT', 'LEFT', 'REDUCE'] 36 | 37 | n = 0 38 | while not transition_system.isFinalState(parser_state): 39 | n += 1 40 | legal_transitions = transition_system_class.legal(parser_state) 41 | print('LEGAL ', ' '.join([LUT[p] for p in legal_transitions])) 42 | zero_cost = transition_system_class.dynamicOracle(parser_state, 43 | legal_transitions) 44 | print(str(n) + ' [ ' + ' '.join([LUT[z] for z in zero_cost]) + ' ]') 45 | 46 | if len(zero_cost) == 0: 47 | raise Exception('no zero cost') 48 | 49 | ## TODO: make it actually perform operation 50 | transition_system.performShift(parser_state) # FOR TESTING 51 | 52 | def __main__(): 53 | trainingCorpus = ParsedConllFile(keepMalformed=False, 54 | projectivize=True) 55 | 56 | trainingCorpus.read( \ 57 | open(trainingFile, 'r', 58 | encoding='utf-8').read()) 59 | 60 | # make fake model params, enough for lexicon builder 61 | # we still need feature_maps to use ParserState 62 | modelParams = ModelParameters('') 63 | modelParams.trainingFile = trainingFile 64 | modelParams.cfg = {'projectivizeTrainingSet': True} 65 | 66 | lexicon = Lexicon(modelParams) 67 | lexicon.compute() 68 | 69 | sentence = trainingCorpus.sentences[0] 70 | 71 | parser_state = ParserState(sentence, 72 | lexicon.getFeatureMaps()) 73 | 74 | # necessary for initializing and pushing root 75 | # (only initialize transition_state_class once!) 
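    # (Constructing the transition state is what pushes the ROOT entry onto
    # the ParserState's stack, so building it a second time for the same
    # state would leave a duplicate ROOT on the stack and confuse the
    # oracle.)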
76 | # keep arc_state in sync with parser_state 77 | arc_state = transition_state_class(parser_state) 78 | 79 | dynamicOracleTrainTest(parser_state) 80 | 81 | __main__() 82 | -------------------------------------------------------------------------------- /parser-config.sample: -------------------------------------------------------------------------------- 1 | decaySteps = 4400 2 | averagingDecay = 0.9999 3 | learningRate = 0.08 4 | momentum = 0.85 5 | 6 | transitionSystem = 'arc-standard' 7 | #transitionSystem = 'arc-eager' 8 | 9 | # For sample training set 10 | # 931 feature bags at batch size 512 11 | # 6501 feature bags at batch size 64 (performs better) 12 | batchSize = 64 13 | 14 | # How many predictions to consider 15 | topK = 10 16 | 17 | projectivizeTrainingSet = True 18 | 19 | hiddenLayerSizes = [256, 256] 20 | 21 | embeddingSizes = {'tag': 32, 'label': 32, 'word': 64} 22 | 23 | # same as Parsey McParseface 24 | # https://github.com/tensorflow/models/blob/master/syntaxnet/syntaxnet/models/parsey_mcparseface/context.pbtxt 25 | featureStrings = [ 26 | 'stack.child(1).label', 27 | 'stack.child(1).sibling(-1).label', 28 | 'stack.child(-1).label', 29 | 'stack.child(-1).sibling(1).label', 30 | 'stack.child(2).label', 31 | 'stack.child(-2).label', 32 | 'stack(1).child(1).label', 33 | 'stack(1).child(1).sibling(-1).label', 34 | 'stack(1).child(-1).label', 35 | 'stack(1).child(-1).sibling(1).label', 36 | 'stack(1).child(2).label', 37 | 'stack(1).child(-2).label', 38 | 'input.tag', 39 | 'input(1).tag', 40 | 'input(2).tag', 41 | 'input(3).tag', 42 | 'stack.tag', 43 | 'stack.child(1).tag', 44 | 'stack.child(1).sibling(-1).tag', 45 | 'stack.child(-1).tag', 46 | 'stack.child(-1).sibling(1).tag', 47 | 'stack.child(2).tag', 48 | 'stack.child(-2).tag', 49 | 'stack(1).tag', 50 | 'stack(1).child(1).tag', 51 | 'stack(1).child(1).sibling(-1).tag', 52 | 'stack(1).child(-1).tag', 53 | 'stack(1).child(-1).sibling(1).tag', 54 | 'stack(1).child(2).tag', 55 | 'stack(1).child(-2).tag', 56 | 'stack(2).tag', 57 | 'stack(3).tag', 58 | 'input.word', 59 | 'input(1).word', 60 | 'input(2).word', 61 | 'input(3).word', 62 | 'stack.word', 63 | 'stack.child(1).word', 64 | 'stack.child(1).sibling(-1).word', 65 | 'stack.child(-1).word', 66 | 'stack.child(-1).sibling(1).word', 67 | 'stack.child(2).word', 68 | 'stack.child(-2).word', 69 | 'stack(1).word', 70 | 'stack(1).child(1).word', 71 | 'stack(1).child(1).sibling(-1).word', 72 | 'stack(1).child(-1).word', 73 | 'stack(1).child(-1).sibling(1).word', 74 | 'stack(1).child(2).word', 75 | 'stack(1).child(-2).word', 76 | 'stack(2).word', 77 | 'stack(3).word' 78 | ] 79 | -------------------------------------------------------------------------------- /parser_state.py: -------------------------------------------------------------------------------- 1 | # Translation of parser_state.cc from SyntaxNet 2 | 3 | from conll_utils import ParsedConllSentence, ParsedConllToken 4 | 5 | ''' 6 | Handles an individual parsing state within a sentence 7 | ''' 8 | class ParserState(object): 9 | def __init__(self, sentence, feature_maps): 10 | self.sentence_ = sentence 11 | self.stack_ = [] 12 | self.head_ = [] 13 | self.label_ = [] 14 | self.num_tokens_ = len(sentence.tokens) 15 | self.next_ = 0 16 | self.root_label_ = -1 # always use -1 as 17 | self.feature_maps = feature_maps 18 | 19 | # keep in sync with head_ and label_ 20 | self.arcs_ = [] 21 | 22 | for i in range(self.num_tokens_): 23 | self.head_.append(-1) 24 | self.label_.append(self.rootLabel()) 25 | 26 | def sentence(self): 27 | 
return self.sentence_ 28 | 29 | def numTokens(self): 30 | return self.num_tokens_ 31 | 32 | def rootLabel(self): 33 | return self.root_label_ 34 | 35 | def next(self): 36 | assert self.next_ >= -1 37 | assert self.next_ <= self.num_tokens_ 38 | return self.next_ 39 | 40 | def input(self, offset): 41 | index = self.next_ + offset 42 | if index >= -1 and index < self.num_tokens_: 43 | return index 44 | else: 45 | return -2 46 | 47 | def advance(self): 48 | assert self.next_ < self.num_tokens_ 49 | self.next_ += 1 50 | 51 | def endOfInput(self): 52 | return self.next_ == self.num_tokens_ 53 | 54 | def push(self, index): 55 | assert len(self.stack_) <= self.num_tokens_ 56 | self.stack_.append(index) 57 | 58 | def pop(self): 59 | assert len(self.stack_) > 0 60 | return self.stack_.pop() 61 | 62 | def top(self): 63 | assert len(self.stack_) > 0 64 | return self.stack_[-1] 65 | 66 | def stack(self, position): 67 | if position < 0: 68 | return -2 69 | 70 | index = len(self.stack_) - 1 - position 71 | if index < 0: 72 | return -2 73 | else: 74 | return self.stack_[index] 75 | 76 | def stackSize(self): 77 | return len(self.stack_) 78 | 79 | def stackEmpty(self): 80 | return len(self.stack_) == 0 81 | 82 | def head(self, index): 83 | assert index >= -1 84 | assert index < self.num_tokens_ 85 | if index == -1: 86 | return -1 87 | else: 88 | return self.head_[index] 89 | 90 | def label(self, index): 91 | assert index >= -1 92 | assert index < self.num_tokens_ 93 | if index == -1: 94 | return self.rootLabel() 95 | else: 96 | return self.label_[index] 97 | 98 | def parent(self, index, n): 99 | assert index >= -1 100 | assert index < self.num_tokens_ 101 | while (n > 0): 102 | n -= 1 103 | index = self.head(index) 104 | 105 | return index 106 | 107 | def leftmostChild(self, index, n): 108 | assert index >= -1 109 | assert index < self.num_tokens_ 110 | while (n > 0): 111 | n -= 1 112 | i = -1 113 | while i < index: 114 | if self.head(i) == index: 115 | break 116 | i += 1 117 | if i == index: 118 | return -2 119 | index = i 120 | return index 121 | 122 | def rightmostChild(self, index, n): 123 | assert index >= -1 124 | assert index < self.num_tokens_ 125 | while (n > 0): 126 | n -= 1 127 | i = self.num_tokens_ - 1 128 | while i > index: 129 | if self.head(i) == index: 130 | break 131 | i -= 1 132 | if i == index: 133 | return -2 134 | index = i 135 | return index 136 | 137 | def leftSibling(self, index, n): 138 | assert index >= -1 139 | assert index < self.num_tokens_ 140 | if index == -1 and n > 0: 141 | return -2 142 | i = index 143 | while n > 0: 144 | i -= 1 145 | if i == -1: 146 | return -2 147 | if self.head(i) == self.head(index): 148 | n -= 1 149 | return i 150 | 151 | def rightSibling(self, index, n): 152 | assert index >= -1 153 | assert index < self.num_tokens_ 154 | i = index 155 | while n > 0: 156 | i += 1 157 | if i == self.num_tokens_: 158 | return -2 159 | if self.head(i) == self.head(index): 160 | n -= 1 161 | return i 162 | 163 | def addArc(self, index, head, label): 164 | assert index >= 0 165 | assert index < self.num_tokens_ 166 | self.head_[index] = head 167 | self.label_[index] = label 168 | self.arcs_.append((head, label, index)) 169 | 170 | def goldHead(self, index): 171 | assert index >= -1 172 | assert index < self.num_tokens_ 173 | if index == -1: 174 | return -1 175 | offset = 0 176 | gold_head = self.getToken(index).HEAD 177 | if gold_head == -1: 178 | return -1 179 | else: 180 | return gold_head - offset 181 | 182 | def goldLabel(self, index): 183 | assert index >= -1 184 | 
assert index < self.num_tokens_ 185 | if index == -1: 186 | return self.rootLabel() 187 | 188 | try: 189 | gold_label = self.feature_maps['label'] \ 190 | .valueToIndex(self.getToken(index).DEPREL) 191 | 192 | assert gold_label >= 0 193 | 194 | if gold_label > self.feature_maps['label'].lastBaseValue: 195 | # if strange value, match SyntaxNet behavior 196 | return self.rootLabel() 197 | else: 198 | return gold_label 199 | except: 200 | # if strange value, match SyntaxNet behavior 201 | return self.rootLabel() 202 | 203 | def getToken(self, index): 204 | if index == -1: 205 | return ParsedConllToken() 206 | return self.sentence_.tokens[index] 207 | 208 | ## TODO: what about index==-1?? 209 | def hasHead(self, index): 210 | assert index >= 0 211 | assert index < self.num_tokens_ 212 | return self.head_[index] != -1 213 | 214 | # TODO: cache or something? 215 | # finds dependents 216 | def goldDeps(self, index): 217 | deps = {} 218 | for dep in range(0, self.num_tokens): 219 | head = self.goldHead(dep) 220 | #if dep != head: # in case of root 221 | if head not in deps: 222 | deps[head] = [] 223 | deps[head].append(dep) 224 | return deps 225 | -------------------------------------------------------------------------------- /projectivize_filter.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Projectivize Filter 3 | (document_filters.cc) 4 | 5 | Check whether the given sentence is projective or not. 6 | 7 | Return value is whether the sentence was originally projective or not. 8 | - If it wasn't, return False 9 | - If it was, return True 10 | 11 | projectivize parameter: whether or not to fix sentence to be projective 12 | (does not affect return value) 13 | ''' 14 | 15 | import copy 16 | import logging 17 | logger = logging.getLogger('ProjectivizeFilter') 18 | 19 | def checkProjectivity(sentence, projectivize=False): 20 | if projectivize: 21 | oldsentence = copy.deepcopy(sentence) 22 | 23 | wasProjective = True 24 | num_tokens = len(sentence.tokens) 25 | 26 | # Left and right boundaries for arcs. The left and right ends of an arc are 27 | # bounded by the arcs that pass over it. If an arc exceeds these bounds it 28 | # will cross an arc passing over it, making it a non-projective arc. 29 | 30 | left = [None for i in range(num_tokens)] 31 | right = [None for i in range(num_tokens)] 32 | 33 | # Lift the shortest non-projective arc until the document is projective. 34 | while True: 35 | # Initialize boundaries to the whole document for all arcs. 36 | for i in range(num_tokens): 37 | left[i] = -1 38 | right[i] = num_tokens - 1 39 | 40 | # Find left and right bounds for each token. 41 | for i in range(num_tokens): 42 | head_index = sentence.tokens[i].HEAD 43 | 44 | # Find left and right end of arc 45 | l = min(i, head_index) 46 | r = max(i, head_index) 47 | 48 | # Bound all tokens under the arc. 49 | for j in range(l+1, r): 50 | if left[j] < l: 51 | left[j] = l 52 | if right[j] > r: 53 | right[j] = r 54 | 55 | # Find deepest non-projective arc. 56 | deepest_arc = -1 57 | max_depth = -1 58 | 59 | # The non-projective arcs are those that exceed their bounds. 60 | for i in range(num_tokens): 61 | head_index = sentence.tokens[i].HEAD 62 | 63 | if head_index == -1: 64 | # any crossing arc must be deeper 65 | continue 66 | 67 | l = min(i, head_index) 68 | r = max(i, head_index) 69 | 70 | left_bound = max(left[l], left[r]) 71 | right_bound = min(right[l], right[r]) 72 | 73 | if (l < left_bound) or (r > right_bound): 74 | # Found non-projective arc. 
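                # Hypothetical example: with HEAD indices [2, 3, -1, 2] the
                # arcs (0, 2) and (1, 3) cross, so both exceed their bounds
                # and are caught here; with projectivize=True the deeper arc
                # (1, 3) is later lifted to (1, 2), making the tree projective.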
75 | logger.debug('Found non-projective arc') 76 | wasProjective = False 77 | if not projectivize: 78 | return wasProjective 79 | 80 | # Pick the deepest as the best candidate for lifting. 81 | depth = 0 82 | j = i 83 | while j != -1: 84 | depth += 1 85 | j = sentence.tokens[j].HEAD 86 | 87 | if depth > max_depth: 88 | deepest_arc = i 89 | max_depth = depth 90 | 91 | # If there are no more non-projective arcs we are done. 92 | if deepest_arc == -1: 93 | if not wasProjective: 94 | logger.debug('Projectivized non-projective arc') 95 | logger.debug('Before\n' + oldsentence.toFileOutput()) 96 | logger.debug('After\n' + sentence.toFileOutput()) 97 | return wasProjective 98 | 99 | # Lift non-projective arc. 100 | lifted_head = sentence.tokens[sentence.tokens[deepest_arc].HEAD].HEAD 101 | sentence.tokens[deepest_arc].HEAD = lifted_head 102 | 103 | assert None 104 | -------------------------------------------------------------------------------- /sentence_batch.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Provide a batch of sentences to the trainer 3 | 4 | Maintains batch_size slots of sentences, each one with its own parser state 5 | ''' 6 | 7 | from conll_utils import ParsedConllFile 8 | 9 | class SentenceBatch(object): 10 | def __init__(self, input_corpus, batch_size=50): 11 | assert type(input_corpus) is ParsedConllFile 12 | assert len(input_corpus.sentences) > 0, \ 13 | 'corpus contains no valid sentences ' \ 14 | '(or please call read() on input_corpus beforehand)' 15 | self.input_corpus = input_corpus 16 | self.batch_size = batch_size 17 | self.rewind() 18 | 19 | def rewind(self): 20 | # so that sentence can advance to 0 from the beginning! 21 | self.highest_sentence_index = -1 22 | self.sentences = [None for i in range(self.batch_size)] 23 | self.num_active = 0 24 | 25 | ''' 26 | Return current number of non-null sentences in the batch 27 | ''' 28 | def size(self): 29 | return self.num_active 30 | 31 | def sentence(self, index): 32 | assert index >= 0 and index < self.batch_size, \ 33 | 'batch index out of bounds' 34 | return self.sentences[index] 35 | 36 | def advanceSentence(self, index): 37 | assert index >= 0 and index < self.batch_size, \ 38 | 'batch index out of bounds' 39 | 40 | if self.sentences[index] == None: 41 | self.num_active += 1 42 | 43 | if (self.highest_sentence_index+1) >= len(self.input_corpus.sentences): 44 | # EOF reached 45 | self.num_active -= 1 46 | return False 47 | 48 | self.highest_sentence_index += 1 49 | self.sentences[index] = \ 50 | self.input_corpus.sentences[self.highest_sentence_index] 51 | return True 52 | -------------------------------------------------------------------------------- /training_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # the sample provided training corpus contains 6501 feature bags with this configuration 3 | mkdir -p /tmp/testmodelEN 4 | cp parser-config.sample /tmp/testmodelEN/parser-config 5 | python3 dep_parser.py /tmp/testmodelEN corpus/en-ud-train.conllu corpus/en-ud-test.conllu --train --epochs 5 --restart 6 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.ops import control_flow_ops as cf 3 | 4 | ''' 5 | Replace all digits with 9s like SyntaxNet 6 | ''' 7 | def normalizeDigits(form): 8 | newform = '' 9 | for i in range(len(form)): 10 
| if ord(form[i]) >= ord('0') and ord(form[i]) <= ord('9'): 11 | newform += '9' 12 | else: 13 | newform += form[i] 14 | return newform 15 | 16 | ''' 17 | Gets array shape of dynamically shaped tensors 18 | 19 | Ex. 20 | dense_golden = tensorPrintShape(dense_golden, [dense_golden], 21 | 'dense_golden shape') 22 | ''' 23 | def tensorPrintShape(inp, data, comment): 24 | def np_print(*args): 25 | for x in args: 26 | print(comment, x.shape) 27 | return cf.with_dependencies([tf.py_func(np_print, data, [])], inp) 28 | 29 | ''' 30 | Ex. 31 | dense_golden = tensorPrint(dense_golden, [dense_golden], 'dense_golden data') 32 | ''' 33 | def tensorPrint(inp, data, comment): 34 | def np_print(*args): 35 | for x in args: 36 | print(comment, x) 37 | return cf.with_dependencies([tf.py_func(np_print, data, [])], inp) 38 | 39 | tensorDumpValsCallCount = {} 40 | 41 | ''' 42 | Ex. 43 | dense_golden = tensorDumpVals(dense_golden, [dense_golden], 44 | '/tmp/ash_dense_golden_1', 1) 45 | ''' 46 | # print only the desired_iter'th time the function is called (1-based) 47 | # for this particular filename 48 | def tensorDumpVals(inp, data, fname, desired_iter): 49 | global tensorDumpValsCallCount 50 | 51 | def np_print(*args): 52 | global tensorDumpValsCallCount 53 | 54 | if fname not in tensorDumpValsCallCount: 55 | tensorDumpValsCallCount[fname] = 0 56 | tensorDumpValsCallCount[fname] += 1 57 | 58 | # only execute for the iteration # desired 59 | if tensorDumpValsCallCount[fname] == desired_iter: 60 | fd = open(fname, 'w') 61 | 62 | for x in args: 63 | for elem in x.flatten(): 64 | fd.write('%.8f\n' % elem) 65 | 66 | fd.close() 67 | 68 | return cf.with_dependencies([tf.py_func(np_print, data, [])], inp) 69 | 70 | ''' 71 | Ex. 72 | dense_golden = tensorDumpValsAllIter(dense_golden, [dense_golden], 73 | '/tmp/ash_dense_golden') 74 | ''' 75 | 76 | def tensorDumpValsAllIter(inp, data, fname): 77 | global tensorDumpValsCallCount 78 | 79 | def np_print(*args): 80 | global tensorDumpValsCallCount 81 | 82 | if fname not in tensorDumpValsCallCount: 83 | tensorDumpValsCallCount[fname] = 0 84 | tensorDumpValsCallCount[fname] += 1 85 | 86 | fd = open('%s_%04d' % (fname, tensorDumpValsCallCount[fname]), 'w') 87 | 88 | for x in args: 89 | for elem in x.flatten(): 90 | fd.write('%.8f\n' % elem) 91 | 92 | fd.close() 93 | 94 | return cf.with_dependencies([tf.py_func(np_print, data, [])], inp) 95 | 96 | ''' 97 | See SyntaxNet utils.h 98 | ''' 99 | kPunctuation = [ 100 | (33, 35), (37, 42), (44, 47), (58, 59), 101 | (63, 64), (91, 93), (95, 95), (123, 123), 102 | (125, 125), (161, 161), (171, 171), (183, 183), 103 | (187, 187), (191, 191), (894, 894), (903, 903), 104 | (1370, 1375), (1417, 1418), (1470, 1470), (1472, 1472), 105 | (1475, 1475), (1478, 1478), (1523, 1524), (1548, 1549), 106 | (1563, 1563), (1566, 1567), (1642, 1645), (1748, 1748), 107 | (1792, 1805), (2404, 2405), (2416, 2416), (3572, 3572), 108 | (3663, 3663), (3674, 3675), (3844, 3858), (3898, 3901), 109 | (3973, 3973), (4048, 4049), (4170, 4175), (4347, 4347), 110 | (4961, 4968), (5741, 5742), (5787, 5788), (5867, 5869), 111 | (5941, 5942), (6100, 6102), (6104, 6106), (6144, 6154), 112 | (6468, 6469), (6622, 6623), (6686, 6687), (8208, 8231), 113 | (8240, 8259), (8261, 8273), (8275, 8286), (8317, 8318), 114 | (8333, 8334), (9001, 9002), (9140, 9142), (10088, 10101), 115 | (10181, 10182), (10214, 10219), (10627, 10648), (10712, 10715), 116 | (10748, 10749), (11513, 11516), (11518, 11519), (11776, 11799), 117 | (11804, 11805), (12289, 12291), (12296, 12305), (12308, 
12319), 118 | (12336, 12336), (12349, 12349), (12448, 12448), (12539, 12539), 119 | (64830, 64831), (65040, 65049), (65072, 65106), (65108, 65121), 120 | (65123, 65123), (65128, 65128), (65130, 65131), (65281, 65283), 121 | (65285, 65290), (65292, 65295), (65306, 65307), (65311, 65312), 122 | (65339, 65341), (65343, 65343), (65371, 65371), (65373, 65373), 123 | (65375, 65381), (65792, 65793), (66463, 66463), (68176, 68184) 124 | ] 125 | 126 | ''' 127 | Determines if the specified unicode ordinal is punctuation or not 128 | ''' 129 | def isPunctuation(uni_ord): 130 | assert type(uni_ord) is int 131 | i = 0 132 | while kPunctuation[i][0] > 0: 133 | if uni_ord < kPunctuation[i][0]: 134 | return False 135 | if uni_ord <= kPunctuation[i][1]: 136 | return True 137 | i += 1 138 | return False 139 | 140 | ''' 141 | Returns true if word consists of punctuation characters. 142 | ''' 143 | def isPunctuationToken(word): 144 | for c in word: 145 | if not isPunctuation(ord(c)): 146 | return False 147 | return True 148 | 149 | ''' 150 | Determine if tag is a punctuation tag. 151 | ''' 152 | def isPunctuationTag(tag): 153 | # match SyntaxNet behavior 154 | #if len(tag) == 0: 155 | # return False 156 | for c in tag: 157 | if (c != ',' and c != ':' and c != '.' and c != '\'' and c != '`'): 158 | return False 159 | return True 160 | 161 | ''' 162 | Returns true if tag is non-empty and has only punctuation or parens 163 | symbols. 164 | ''' 165 | def isPunctuationTagOrParens(tag): 166 | if len(tag) == 0: 167 | return False 168 | for c in tag: 169 | if (c != '(' and c != ')' and c != ',' and c != ':' and c != '.' and \ 170 | c != '\'' and c != '`'): 171 | return False 172 | return True 173 | 174 | # FIXME: empty tags might show as '_' in CoNLL. '_' behavior is not 175 | # well-defined in the specification. 176 | 177 | ''' 178 | Return whether or not we should score a token based on the current 179 | scoring strategy 180 | ''' 181 | def shouldScoreToken(word, tag, scoring_strategy): 182 | if scoring_strategy == 'default': 183 | return len(tag) == 0 or not isPunctuationTag(tag) 184 | elif scoring_strategy == 'conllx': 185 | return not isPunctuationToken(word) 186 | elif scoring_strategy == 'ignore_parens': 187 | return not isPunctuationTagOrParens(tag) 188 | assert None, 'unknown scoring strategy: ' + scoring_strategy 189 | -------------------------------------------------------------------------------- /well_formed_filter.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Well-Formed Filter 3 | (document_filters.cc) 4 | 5 | Check that input is single-root, connected, acyclic, and projective. 
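
The checks are applied in order: allHeadsExist (every HEAD is either -1 for
root or a valid token index), isSingleRooted (exactly one token has HEAD == -1),
hasCycle (no loops in the HEAD chain), and checkProjectivity (a non-projective
sentence is rejected unless projectivize=True, in which case it is lifted to a
projective tree and kept).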
6 | ''' 7 | import logging 8 | logger = logging.getLogger('WellFormedFilter') 9 | from projectivize_filter import checkProjectivity 10 | 11 | ''' 12 | Determine whether all HEADs are within the bounds of the sentence 13 | ''' 14 | def allHeadsExist(sentence): 15 | minIndex = -1 # root token 16 | maxIndex = len(sentence.tokens)-1 17 | 18 | for t in sentence.tokens: 19 | if t.HEAD < minIndex or t.HEAD > maxIndex: 20 | return False 21 | 22 | return True 23 | 24 | ''' 25 | Determine whether the sentence is single rooted 26 | ''' 27 | def isSingleRooted(sentence): 28 | allHeads = [] 29 | for t in sentence.tokens: 30 | if t.HEAD == -1: 31 | allHeads.append(t) 32 | return len(allHeads) == 1 33 | 34 | ''' 35 | Determine whether or not the sentence has a cycle (in HEADs) 36 | ''' 37 | def hasCycle(sentence): 38 | visited = [-1 for t in sentence.tokens] 39 | 40 | for i in range(len(sentence.tokens)): 41 | # Already visited node 42 | if visited[i] != -1: 43 | continue 44 | 45 | t = i 46 | while t != -1: 47 | if visited[t] == -1: 48 | # If it is not visited yet, mark it. 49 | visited[t] = i 50 | elif visited[t] < i: 51 | # If the index number is smaller than index and not -1, the 52 | # token has already been visited. 53 | break 54 | else: 55 | # Loop detected 56 | return True 57 | t = sentence.tokens[t].HEAD 58 | 59 | return False 60 | 61 | class WellFormedFilter(object): 62 | def __init__(self): 63 | self.nonProjectiveCount = 0 64 | self.projectivizedCount = 0 65 | self.nonWellFormedCount = 0 66 | 67 | ''' 68 | Determine whether the sentence can be parsed by arc-standard and arc-eager 69 | or not 70 | 71 | projectivize: whether to make non-projective sentences projective 72 | ''' 73 | def isWellFormed(self, sentence, projectivize=False): 74 | if len(sentence.tokens) == 0: 75 | logger.debug('Not well-formed: token length is zero') 76 | logger.debug('"'+sentence.toSimpleRepresentation()+'"') 77 | self.nonWellFormedCount += 1 78 | return False 79 | 80 | if not allHeadsExist(sentence): 81 | logger.debug('Not well-formed: not all HEADs exist as tokens') 82 | logger.debug('"'+sentence.toSimpleRepresentation()+'"') 83 | self.nonWellFormedCount += 1 84 | return False 85 | 86 | if not isSingleRooted(sentence): 87 | logger.debug('Not well-formed: tree doesn\'t have single ROOT') 88 | logger.debug('"'+sentence.toSimpleRepresentation()+'"') 89 | self.nonWellFormedCount += 1 90 | return False 91 | 92 | if hasCycle(sentence): 93 | logger.debug('Not well-formed: tree has a cycle') 94 | logger.debug('"'+sentence.toSimpleRepresentation()+'"') 95 | self.nonWellFormedCount += 1 96 | return False 97 | 98 | if not checkProjectivity(sentence, projectivize=projectivize): 99 | self.nonProjectiveCount += 1 100 | 101 | # if it wasn't projective 102 | if not projectivize: 103 | # ... and we didn't projectivize it... then it's invalid 104 | logger.debug('Not well-formed: non-projective and' \ 105 | ' projectivize disabled') 106 | logger.debug('"'+sentence.toSimpleRepresentation()+'"') 107 | 108 | # only count them as non-well-formed when projectivize is off 109 | self.nonWellFormedCount += 1 110 | return False 111 | else: 112 | # we succesfully projectivized a non-projective sentence 113 | # consider well-formed 114 | self.projectivizedCount += 1 115 | 116 | # if we did projectivize it, then we can keep going 117 | 118 | return True 119 | --------------------------------------------------------------------------------