# arcs-py
# ├── README.md  ├── fileio.py  ├── fx.py  ├── parse.py
# └── sample-data  ├── test.parses  └── train.parses
#
# README.md:
# arcs-py
# =======
#
# Python implementation of Arc-Eager and Arc-Hybrid Greedy Dependency Parsing
# trained with a dynamic oracle described in:
#
# Goldberg, Yoav, and Joakim Nivre. "Training Deterministic Parsers with
# Non-Deterministic Oracles." (2013)
#
# Goldberg, Yoav, and Joakim Nivre. "A Dynamic Oracle for Arc-Eager Dependency
# Parsing" (2012)
#
# The sample data comes from Question Bank and a sample of PTB provided by NLTK
# in the corpora section, which I converted to a labeled dependency CONLL file
# using LTH converter and David Vadas' patches. The sample data also contains
# non-projective parses, which are ignored by the sample driver program.
#
# If you are looking for a modern flexible dependency parser that supports
# non-projective parses and is near SoTA, you probably want to use a neural
# parser architecture like the one in mead
# (https://github.com/dpressel/mead-baseline)

# -- fileio.py -----------------------------------------------------------------
import csv

# Indices of the fields inside a token tuple: (word, pos, head, label)
WORD = 0
POS = 1
HEAD = 2
LABEL = 3


def read_conll_deps(f):
    """Read a tab-separated CONLL dependency file into sentences.

    Each sentence is a list of (word, pos, head, label) tuples where head is
    the 0-based index of the head token; a token whose CONLL head is 0 (the
    root) gets head == len(sentence), the index of the virtual root slot.

    :param f: path to the CONLL file
    :return: list of sentences
    """

    def finish(sentence):
        # Re-point root tokens (head == -1 after the 1-based -> 0-based shift)
        # at the virtual-root index, one past the last real token.
        # BUG FIX: the original used 'tok[HEAD] is not -1', an identity
        # comparison that only works by accident for CPython small ints.
        return [tok if tok[HEAD] != -1
                else (tok[WORD], tok[POS], len(sentence), tok[LABEL])
                for tok in sentence]

    sentences = []

    with open(f) as csvfile:
        reader = csv.reader(csvfile, delimiter='\t', quoting=csv.QUOTE_NONE)

        sentence = []

        for row in reader:
            if len(row) == 0:
                # Blank line terminates the current sentence
                sentences.append(finish(sentence))
                sentence = []
                continue
            # row[1]=form, row[3]=POS, row[6]=1-based head (0 => root), row[7]=label
            sentence.append((row[1].lower(), row[3], int(row[6]) - 1, row[7]))

        # BUG FIX: flush the last sentence when the file has no trailing blank
        # line; the original silently dropped it.
        if sentence:
            sentences.append(finish(sentence))

    return sentences
# -- fx.py ---------------------------------------------------------------------

# Placeholder (word, POS) pair used when a context position has no token.
NULL_TUPLE = ('NULL', 'NULL')

# Indices of the fields inside a token tuple: (word, pos, head, label)
WORD = 0
POS = 1
HEAD = 2
LABEL = 3


def head_children(arcs, h, sent):
    """Return the (left-most, right-most) child tokens of head h.

    :param arcs: (head, dependent) index pairs built so far
    :param h: index of the head token
    :param sent: the sentence, a list of token tuples
    :return: two token tuples; NULL_TUPLE placeholders when h has no children
    """
    children = [d for (hd, d) in arcs if hd == h]
    if children:
        return sent[min(children)], sent[max(children)]
    return NULL_TUPLE, NULL_TUPLE


def _word_pos(sentence, idx):
    """Word/POS of the token at idx; the virtual root (one past the last real
    token) maps to ("ROOT", "ROOT")."""
    if idx < len(sentence):
        return sentence[idx][WORD], sentence[idx][POS]
    return "ROOT", "ROOT"


def baseline(conf):
    """Baseline feature extractor.

    Builds indicator features (value 1) from the buffer front (b0, b1, b2),
    the stack top (s0), the POS of b0's left-most child (b10p), s0's left-most
    and right-most children (s10p, sr0p) and s0's head (sh0p), plus pairwise
    and triple conjunctions of those atoms.

    :param conf: configuration with .sentence, .buffer, .stack and .arcs
    :return: dict mapping feature name -> 1
    """
    sentence = conf.sentence
    fv = {}

    b0w = b0p = b10p = "NULL"
    b1w = b1p = "NULL"
    b2w = b2p = "NULL"
    s0w = s0p = "NULL"
    s10p = sr0p = sh0p = "NULL"

    # dependent -> head for the arcs built so far
    heads = dict((arc[1], arc[0]) for arc in conf.arcs)

    if len(conf.buffer) > 0:
        b0_pos = conf.buffer[0]
        b0w, b0p = _word_pos(sentence, b0_pos)
        left_most, _ = head_children(conf.arcs, b0_pos, sentence)
        b10p = left_most[POS]
        if len(conf.buffer) > 1:
            b1w, b1p = _word_pos(sentence, conf.buffer[1])
            if len(conf.buffer) > 2:
                b2w, b2p = _word_pos(sentence, conf.buffer[2])

    if len(conf.stack) > 0:
        s0_pos = conf.stack[-1]
        s0w, s0p = _word_pos(sentence, s0_pos)
        left_most, right_most = head_children(conf.arcs, s0_pos, sentence)
        s10p = left_most[POS]
        sr0p = right_most[POS]
        if s0_pos in heads:
            sh0p = sentence[heads[s0_pos]][POS]

    b0wp = b0w + "/" + b0p
    b1wp = b1w + "/" + b1p
    s0wp = s0w + "/" + s0p
    b2wp = b2w + "/" + b2p

    fv["s0wp=" + s0wp] = 1
    fv["s0w=" + s0w] = 1
    fv["s0p=" + s0p] = 1
    fv["b0wp=" + b0wp] = 1
    fv["b0w=" + b0w] = 1
    fv["b0p=" + b0p] = 1
    fv["b1wp=" + b1wp] = 1
    fv["b1w=" + b1w] = 1
    fv["b1p=" + b1p] = 1
    fv["b2wp=" + b2wp] = 1
    fv["b2w=" + b2w] = 1
    fv["b2p=" + b2p] = 1

    s0wp_b0wp = s0wp + ";" + b0wp
    s0wp_b0w = s0wp + ";" + b0w
    s0w_b0wp = s0w + ";" + b0wp
    s0wp_b0p = s0wp + ";" + b0p
    s0p_b0wp = s0p + ";" + b0wp
    s0w_b0w = s0w + ";" + b0w
    s0p_b0p = s0p + ";" + b0p
    b0p_b1p = b0p + ";" + b1p

    fv["s0wp_b0wp=" + s0wp_b0wp] = 1
    fv["s0wp_b0w=" + s0wp_b0w] = 1
    fv["s0w_b0wp=" + s0w_b0wp] = 1
    fv["s0wp_b0p=" + s0wp_b0p] = 1
    fv["s0p_b0wp=" + s0p_b0wp] = 1
    fv["s0w_b0w=" + s0w_b0w] = 1
    fv["s0p_b0p=" + s0p_b0p] = 1
    # BUG FIX: key was missing the "=" separator, fusing name and value.
    fv["b0p_b1p=" + b0p_b1p] = 1

    b0p_b1p_b2p = b0p + ";" + b1p + ";" + b2p
    s0p_b0p_b1p = s0p + ";" + b0p + ";" + b1p
    sh0p_s0p_b0p = sh0p + ";" + s0p + ";" + b0p
    s0p_s10p_b0p = s0p + ";" + s10p + ";" + b0p
    s0p_sr0p_b0p = s0p + ";" + sr0p + ";" + b0p
    s0p_b0p_b10p = s0p + ";" + b0p + ";" + b10p
    fv["b0p_b1p_b2p=" + b0p_b1p_b2p] = 1
    fv["s0p_b0p_b1p=" + s0p_b0p_b1p] = 1
    fv["sh0p_s0p_b0p=" + sh0p_s0p_b0p] = 1
    fv["s0p_s10p_b0p=" + s0p_s10p_b0p] = 1
    fv["s0p_sr0p_b0p=" + s0p_sr0p_b0p] = 1
    # BUG FIX: key was missing the "=" separator, fusing name and value.
    fv["s0p_b0p_b10p=" + s0p_b0p_b10p] = 1

    return fv

# This was intended to be a lot more impressive than the baseline, but so far
# its only adding 2 features -- soon, though!
def ex(conf):
    """Extended feature extractor: the baseline feature set plus the word/POS
    of the second stack item (s1).

    :param conf: configuration with .sentence, .buffer, .stack and .arcs
    :return: dict mapping feature name -> 1
    """
    sentence = conf.sentence
    fv = {}

    b0w = b0p = b10p = "NULL"
    b1w = b1p = "NULL"
    b2w = b2p = "NULL"
    s0w = s0p = "NULL"
    s1w = s1p = "NULL"
    s10p = sr0p = sh0p = "NULL"

    # dependent -> head for the arcs built so far
    heads = dict((arc[1], arc[0]) for arc in conf.arcs)

    if len(conf.buffer) > 0:
        b0_pos = conf.buffer[0]
        if b0_pos < len(sentence):
            b0w = sentence[b0_pos][WORD]
            b0p = sentence[b0_pos][POS]
        else:
            # The virtual root lives one past the last real token
            b0w = "ROOT"
            b0p = "ROOT"
        left_most, _ = head_children(conf.arcs, b0_pos, sentence)
        b10p = left_most[POS]
        if len(conf.buffer) > 1:
            b1_pos = conf.buffer[1]
            if b1_pos < len(sentence):
                b1w = sentence[b1_pos][WORD]
                b1p = sentence[b1_pos][POS]
            else:
                b1w = "ROOT"
                b1p = "ROOT"
            if len(conf.buffer) > 2:
                b2_pos = conf.buffer[2]
                if b2_pos < len(sentence):
                    b2w = sentence[b2_pos][WORD]
                    b2p = sentence[b2_pos][POS]
                else:
                    b2w = "ROOT"
                    b2p = "ROOT"

    if len(conf.stack) > 0:
        s0_pos = conf.stack[-1]
        if s0_pos < len(sentence):
            s0w = sentence[s0_pos][WORD]
            s0p = sentence[s0_pos][POS]
        else:
            s0w = "ROOT"
            s0p = "ROOT"

        if len(conf.stack) > 1:
            s1_pos = conf.stack[-2]
            # BUG FIX: the original indexed the sentence unconditionally here,
            # which raises IndexError when the virtual root is second on the
            # stack; guard it like every other position.
            if s1_pos < len(sentence):
                s1w = sentence[s1_pos][WORD]
                s1p = sentence[s1_pos][POS]
            else:
                s1w = "ROOT"
                s1p = "ROOT"

        left_most, right_most = head_children(conf.arcs, s0_pos, sentence)
        s10p = left_most[POS]
        sr0p = right_most[POS]

        if s0_pos in heads:
            sh0p = sentence[heads[s0_pos]][POS]

    b0wp = b0w + "/" + b0p
    b1wp = b1w + "/" + b1p
    s0wp = s0w + "/" + s0p
    s1wp = s1w + "/" + s1p
    b2wp = b2w + "/" + b2p

    fv["s0wp=" + s0wp] = 1
    fv["s0w=" + s0w] = 1
    fv["s0p=" + s0p] = 1
    fv["s1wp=" + s1wp] = 1
    fv["s1w=" + s1w] = 1
    fv["s1p=" + s1p] = 1

    fv["b0wp=" + b0wp] = 1
    fv["b0w=" + b0w] = 1
    fv["b0p=" + b0p] = 1
    fv["b1wp=" + b1wp] = 1
    fv["b1w=" + b1w] = 1
    fv["b1p=" + b1p] = 1
    fv["b2wp=" + b2wp] = 1
    fv["b2w=" + b2w] = 1
    fv["b2p=" + b2p] = 1

    s0wp_b0wp = s0wp + ";" + b0wp
    s0wp_b0w = s0wp + ";" + b0w
    s0w_b0wp = s0w + ";" + b0wp
    s0wp_b0p = s0wp + ";" + b0p
    s0p_b0wp = s0p + ";" + b0wp
    s0w_b0w = s0w + ";" + b0w
    s0p_b0p = s0p + ";" + b0p
    b0p_b1p = b0p + ";" + b1p

    fv["s0wp_b0wp=" + s0wp_b0wp] = 1
    fv["s0wp_b0w=" + s0wp_b0w] = 1
    fv["s0w_b0wp=" + s0w_b0wp] = 1
    fv["s0wp_b0p=" + s0wp_b0p] = 1
    fv["s0p_b0wp=" + s0p_b0wp] = 1
    fv["s0w_b0w=" + s0w_b0w] = 1
    fv["s0p_b0p=" + s0p_b0p] = 1
    # BUG FIX: key was missing the "=" separator, fusing name and value.
    fv["b0p_b1p=" + b0p_b1p] = 1

    b0p_b1p_b2p = b0p + ";" + b1p + ";" + b2p
    s0p_b0p_b1p = s0p + ";" + b0p + ";" + b1p
    sh0p_s0p_b0p = sh0p + ";" + s0p + ";" + b0p
    s0p_s10p_b0p = s0p + ";" + s10p + ";" + b0p
    s0p_sr0p_b0p = s0p + ";" + sr0p + ";" + b0p
    s0p_b0p_b10p = s0p + ";" + b0p + ";" + b10p
    fv["b0p_b1p_b2p=" + b0p_b1p_b2p] = 1
    fv["s0p_b0p_b1p=" + s0p_b0p_b1p] = 1
    fv["sh0p_s0p_b0p=" + sh0p_s0p_b0p] = 1
    fv["s0p_s10p_b0p=" + s0p_s10p_b0p] = 1
    fv["s0p_sr0p_b0p=" + s0p_sr0p_b0p] = 1
    # BUG FIX: key was missing the "=" separator, fusing name and value.
    fv["s0p_b0p_b10p=" + s0p_b0p_b10p] = 1

    return fv

# -- parse.py ------------------------------------------------------------------
import random
from collections import defaultdict


class Configuration:
    """A parser state: remaining buffer, stack, and the arcs built so far."""

    def __init__(self, buf, s):
        # arcs: list of (head, dependent) index pairs
        self.arcs = []
        self.buffer = buf
        self.stack = []
        self.sentence = s


class GoldConfiguration:
    """Gold-standard dependencies: head per dependent, dependents per head."""

    def __init__(self):
        self.heads = {}
        self.deps = defaultdict(list)
class GoldConfiguration:
    """Gold-standard dependencies: head per dependent, dependents per head."""

    def __init__(self):
        self.heads = {}
        self.deps = defaultdict(list)


class Classifier:
    """Multi-class linear model over sparse indicator features."""

    def __init__(self, weights, labels):
        # weights: {feature_name: {label: weight}}
        self.weights = weights
        self.labels = labels

    def score(self, fv):
        """Score every label for a feature vector.

        :param fv: dict of feature name -> value
        :return: dict of label -> dot-product score
        """
        scores = dict((label, 0) for label in self.labels)
        for k, v in fv.items():
            if v == 0:
                continue
            wv = self.weights.get(k)
            if wv is None:
                # unseen feature contributes nothing
                continue
            for label, weight in wv.items():
                scores[label] += weight * v
        return scores


class GreedyDepParser:
    """Base class for greedy transition-based dependency parsers trained with
    a dynamic oracle and an averaged perceptron (Goldberg & Nivre 2012/2013)."""

    # Transition ids
    SHIFT = 0
    RIGHT = 1
    LEFT = 2
    REDUCE = 3

    # Exploration schedule: after MAX_EX_ITER training iterations, sometimes
    # follow the model's (wrong) prediction instead of the oracle.
    MAX_EX_ITER = 5
    MAX_EX_THRESH = 0.8

    # Transition id -> printable name, for debugging
    LUT = ["SHIFT", 'RIGHT', 'LEFT', 'REDUCE']

    def __init__(self, m, feature_extractor):
        self.model = m
        self.fx = feature_extractor
        self.transition_funcs = {}
        # Averaged-perceptron bookkeeping: lazy running totals per
        # (feature, label) so averaging does not touch every weight each tick.
        self.train_tick = 0
        self.train_last_tick = defaultdict(int)
        self.train_totals = defaultdict(int)

    def initial(self, sentence):
        """Return the initial Configuration for a sentence (subclass hook)."""
        pass

    @staticmethod
    def terminal(conf):
        # Done when the stack is empty and only the virtual root remains
        # on the buffer.
        return len(conf.stack) == 0 and len(conf.buffer) == 1

    def legal(self, conf):
        """Return the transitions legal in conf (subclass hook)."""
        pass

    def update(self, truth, guess, features):
        """Averaged-perceptron update: reward `truth`, penalize `guess`.

        :param truth: the oracle's (zero-cost) transition id
        :param guess: the model's predicted transition id
        :param features: feature vector the prediction was made from
        """
        def update_feature_label(label, fj, v):
            wv = 0
            try:
                wv = self.model.weights[fj][label]
            except KeyError:
                if fj not in self.model.weights:
                    self.model.weights[fj] = {}
                self.model.weights[fj][label] = 0

            # Fold in the weight's contribution since its last change
            t_delt = self.train_tick - self.train_last_tick[(fj, label)]
            self.train_totals[(fj, label)] += t_delt * wv
            self.model.weights[fj][label] += v
            self.train_last_tick[(fj, label)] = self.train_tick

        self.train_tick += 1
        # NOTE: the +/-1.0 step ignores the feature value; features here are
        # indicators with value 1, so this matches a standard perceptron step.
        for fj in features:
            update_feature_label(truth, fj, 1.0)
            update_feature_label(guess, fj, -1.0)

    def dyn_oracle(self, gold_conf, conf, legal_transitions):
        """Return the zero-cost subset of legal transitions (subclass hook)."""
        pass

    def avg_weights(self):
        """Replace each weight with its (rounded) average over training ticks."""
        if self.train_tick == 0:
            # Never trained; nothing to average (avoids division by zero).
            return
        for fj in self.model.weights:
            for label in self.model.weights[fj]:
                total = self.train_totals[(fj, label)]
                t_delt = self.train_tick - self.train_last_tick[(fj, label)]
                total += t_delt * self.model.weights[fj][label]
                avg = round(total / float(self.train_tick))
                if avg:
                    self.model.weights[fj][label] = avg

    @staticmethod
    def get_gold_conf(sentence):
        """Build a GoldConfiguration from (word, pos, head, label) tuples."""
        gold_conf = GoldConfiguration()
        for dep, tok in enumerate(sentence):
            head = tok[2]
            gold_conf.heads[dep] = head
            gold_conf.deps[head].append(dep)
        return gold_conf

    def run(self, sentence):
        """Greedy decode: take the best-scoring legal transition until terminal.

        :return: the list of (head, dependent) arcs found
        """
        conf = self.initial(sentence)
        while not GreedyDepParser.terminal(conf):
            legal_transitions = self.legal(conf)
            features = self.fx(conf)
            scores = self.model.score(features)
            t_p = max(legal_transitions, key=lambda p: scores[p])
            conf = self.transition(t_p, conf)
        return conf.arcs

    # We need to have arcs that are dominated with no crossing lines,
    # excluding the root
    @staticmethod
    def non_projective(conf):
        """True if the gold tree in conf contains crossing arcs."""
        for dep1 in conf.heads.keys():
            head1 = conf.heads[dep1]
            for dep2 in conf.heads.keys():
                head2 = conf.heads[dep2]
                if head1 < 0 or head2 < 0:
                    continue
                if (dep1 > head2 and dep1 < dep2 and head1 < head2) or \
                        (dep1 < head2 and dep1 > dep2 and head1 < dep2):
                    return True
                # BUG FIX: 'head1 is not head2' compared ints by identity,
                # which is only reliable for CPython's small-int cache.
                if dep1 < head1 and head1 != head2:
                    if (head1 > head2 and head1 < dep2 and dep1 < head2) or \
                            (head1 < head2 and head1 > dep2 and dep1 < dep2):
                        return True
        return False

    def train(self, sentence, iter_num):
        """One training pass over a sentence.

        :param sentence: gold sentence of (word, pos, head, label) tuples
        :param iter_num: current training iteration (drives exploration)
        :return: (correct transitions, total transitions)
        """
        conf = self.initial(sentence)
        gold_conf = GreedyDepParser.get_gold_conf(sentence)
        train_correct = train_all = 0

        while not GreedyDepParser.terminal(conf):
            legal_transitions = self.legal(conf)
            features = self.fx(conf)
            scores = self.model.score(features)
            # Model prediction among legal transitions
            t_p = max(legal_transitions, key=lambda p: scores[p])
            # Transitions the dynamic oracle deems zero-cost
            zero_cost = self.dyn_oracle(gold_conf, conf, legal_transitions)

            if len(zero_cost) == 0:
                raise Exception('no zero cost')

            if t_p not in zero_cost:
                # Model was wrong: update toward the best zero-cost transition,
                # then continue from whichever state explore() picks.
                t_o = max(zero_cost, key=lambda p: scores[p])
                self.update(t_o, t_p, features)
                conf = self.explore(t_o, t_p, conf, iter_num)
            else:
                train_correct += 1
                conf = self.transition(t_p, conf)

            train_all += 1
        return train_correct, train_all

    def explore(self, t_o, t_p, conf, iter_i):
        """Pick the next state after a model error.

        NOTE(review): keeps the model's prediction t_p only when
        iter_i > MAX_EX_ITER and random.random() > MAX_EX_THRESH; confirm this
        matches the intended exploration probability from Goldberg & Nivre.
        """
        if iter_i > GreedyDepParser.MAX_EX_ITER and random.random() > GreedyDepParser.MAX_EX_THRESH:
            return self.transition(t_p, conf)
        return self.transition(t_o, conf)

    def transition(self, t_p, conf):
        # Dispatch through the table installed by the subclass constructor
        return self.transition_funcs[t_p](conf)


class ArcEagerDepParser(GreedyDepParser):
    """Greedy arc-eager parser: SHIFT, RIGHT=(s, b), LEFT=(b, s), REDUCE."""

    def __init__(self, m, f):
        GreedyDepParser.__init__(self, m, f)
        self.transition_funcs[ArcEagerDepParser.SHIFT] = ArcEagerDepParser.shift
        self.transition_funcs[ArcEagerDepParser.RIGHT] = ArcEagerDepParser.arc_right
        self.transition_funcs[ArcEagerDepParser.LEFT] = ArcEagerDepParser.arc_left
        self.transition_funcs[ArcEagerDepParser.REDUCE] = ArcEagerDepParser.reduce

    def initial(self, sentence):
        # Buffer holds every token index plus a trailing virtual-root index.
        return Configuration(list(range(len(sentence))) + [len(sentence)], sentence)

    def legal(self, conf):
        """
        Legal transitions for arc-eager dependency parsing
        :param conf: The current state
        :return: any legal transitions
        """
        transitions = [
            GreedyDepParser.SHIFT,
            GreedyDepParser.RIGHT,
            GreedyDepParser.LEFT,
            GreedyDepParser.REDUCE
        ]
        shift_ok = True
        right_ok = True
        left_ok = True
        reduce_ok = True

        # Only the virtual root remains on the buffer
        if len(conf.buffer) == 1:
            right_ok = shift_ok = False

        if len(conf.stack) == 0:
            left_ok = right_ok = reduce_ok = False
        else:
            s = conf.stack[-1]
            # if s is already a dependent we cannot left-arc;
            # if it has no head yet we cannot reduce
            if any(s == dep for (_, dep) in conf.arcs):
                left_ok = False
            else:
                reduce_ok = False

        ok = [shift_ok, right_ok, left_ok, reduce_ok]
        return [t for t, good in zip(transitions, ok) if good]

    def dyn_oracle(self, gold_conf, conf, legal_transitions):
        """Zero-cost subset of the legal transitions under the dynamic oracle."""
        options = []
        if GreedyDepParser.SHIFT in legal_transitions and ArcEagerDepParser.zero_cost_shift(conf, gold_conf):
            options.append(GreedyDepParser.SHIFT)
        if GreedyDepParser.RIGHT in legal_transitions and ArcEagerDepParser.zero_cost_right(conf, gold_conf):
            options.append(GreedyDepParser.RIGHT)
        if GreedyDepParser.LEFT in legal_transitions and ArcEagerDepParser.zero_cost_left(conf, gold_conf):
            options.append(GreedyDepParser.LEFT)
        if GreedyDepParser.REDUCE in legal_transitions and ArcEagerDepParser.zero_cost_reduce(conf, gold_conf):
            options.append(GreedyDepParser.REDUCE)
        return options

    @staticmethod
    def zero_cost_shift(conf, gold_conf):
        """
        Is a shift zero cost?
        Moving b onto the stack means that b will not be able to acquire any
        head or dependents in S. Cost is the number of gold arcs of form
        (k, b) or (b, k) such that k in S.

        :param conf: Working config
        :param gold_conf: Gold config
        :return: Is the cost zero
        """
        if len(conf.buffer) <= 1:
            return False
        b = conf.buffer[0]
        for si in conf.stack:
            if gold_conf.heads.get(si) == b or gold_conf.heads.get(b) == si:
                return False
        return True

    @staticmethod
    def zero_cost_right(conf, gold_conf):
        """
        Adding the arc (s, b) and pushing b onto the stack means that b will
        not be able to acquire any head in S or B, nor any dependents in S.
        The cost is the number of gold arcs (k, b) such that k in S or B, and
        (b, k) such that k in S with no arc (x, k) already built. Cost is zero
        for (s, b) in the gold arcs, and also when s is not the gold head of b
        but the real head is unreachable (not in S or B) and no gold
        dependents of b remain in S.

        :param conf: working configuration (A_c)
        :param gold_conf: gold configuration
        :return: True if zero-cost, false otherwise
        """
        if len(conf.stack) == 0 or len(conf.buffer) == 0:
            return False

        # Stack top / buffer front
        s = conf.stack[-1]
        b = conf.buffer[0]

        # Gold head k of b.
        # BUG FIX: the original 'b in heads and heads[b] or -1' collapsed a
        # gold head of index 0 (falsy) to -1, corrupting the checks below.
        k = gold_conf.heads.get(b, -1)

        # (s, b) in gold
        if k == s:
            return True

        # (k, b) with k in S or B
        k_b_costs = k in conf.stack or k in conf.buffer

        # (h, d) => k_heads[d] = h for arcs already built
        k_heads = dict((arc[1], arc[0]) for arc in conf.arcs)

        # gold dependents of b
        b_deps = gold_conf.deps[b]

        # (b, k) with k in S, and those still headless
        b_k_in_stack = [dep for dep in b_deps if dep in conf.stack]
        b_k_final = [dep for dep in b_k_in_stack if dep not in k_heads]

        # s is not the gold head, but the real head is unreachable
        # and no gold dependents of b remain on the stack
        if k not in conf.buffer and k not in conf.stack and len(b_k_in_stack) == 0:
            return True

        if k_b_costs:
            return False

        return len(b_k_final) == 0

    @staticmethod
    def zero_cost_left(conf, gold_conf):
        """
        Is the cost of a left arc going to be zero? Adding the arc (b, s) and
        popping s from the stack means that s will not be able to acquire any
        head or dependents in B. The cost is the number of gold arcs (k, s) or
        (s, k) where k in B. Cost is zero when the arc is in the gold arcs, and
        also when b is not the gold head but the real head is not in B.

        :param conf: The working configuration
        :param gold_conf: The gold arcs
        :return: True if a left-arc would be zero-cost, False otherwise
        """
        if len(conf.stack) == 0 or len(conf.buffer) == 0:
            return False

        s = conf.stack[-1]
        b = conf.buffer[0]

        for bi in range(b, len(conf.sentence) + 1):
            # s would lose a gold dependent still in the buffer
            if gold_conf.heads.get(bi) == s:
                return False
            # s would lose its gold head (unless that head is b itself,
            # which is exactly the arc being built).
            # BUG FIX: 'b is not bi' compared ints by identity.
            if b != bi and gold_conf.heads.get(s) == bi:
                return False
        return True

    @staticmethod
    def zero_cost_reduce(conf, gold_conf):
        """Reduce is zero cost iff s has no gold dependents left in the buffer."""
        if len(conf.stack) == 0 or len(conf.buffer) == 0:
            return False

        s = conf.stack[-1]
        b = conf.buffer[0]
        for bi in range(b, len(conf.sentence) + 1):
            if gold_conf.heads.get(bi) == s:
                return False
        return True

    @staticmethod
    def shift(conf):
        # Move the buffer front onto the stack.
        conf.stack.append(conf.buffer.pop(0))
        return conf

    @staticmethod
    def arc_right(conf):
        # Attach (s, b), then push b onto the stack.
        s = conf.stack[-1]
        b = conf.buffer.pop(0)
        conf.stack.append(b)
        conf.arcs.append((s, b))
        return conf

    @staticmethod
    def arc_left(conf):
        # Attach (b, s) and pop s; b stays on the buffer.
        s = conf.stack.pop()
        conf.arcs.append((conf.buffer[0], s))
        return conf

    @staticmethod
    def reduce(conf):
        # Pop a token that already has its head.
        conf.stack.pop()
        return conf
class ArcHybridDepParser(GreedyDepParser):
    """Greedy arc-hybrid parser: SHIFT, RIGHT=(s1, s0), LEFT=(b, s0)."""

    def __init__(self, m, f):
        GreedyDepParser.__init__(self, m, f)
        self.transition_funcs[ArcHybridDepParser.SHIFT] = ArcHybridDepParser.shift
        self.transition_funcs[ArcHybridDepParser.RIGHT] = ArcHybridDepParser.arc_right
        self.transition_funcs[ArcHybridDepParser.LEFT] = ArcHybridDepParser.arc_left
        self.root = None

    def initial(self, sentence):
        # Buffer holds every token index plus a trailing virtual-root index.
        self.root = len(sentence)
        return Configuration(list(range(len(sentence))) + [len(sentence)], sentence)

    def legal(self, conf):
        """Legal transitions for arc-hybrid parsing."""
        transitions = []
        left_ok = right_ok = shift_ok = True

        # BUG FIX: SHIFT used to be unconditionally legal, so the virtual root
        # could be shifted off the buffer, after which a terminal state
        # (empty stack, root alone on the buffer) was unreachable.
        if len(conf.buffer) == 1:
            shift_ok = False
        if len(conf.stack) < 2:
            right_ok = False
        if len(conf.stack) == 0 or conf.stack[-1] == self.root:
            left_ok = False

        if shift_ok is True:
            transitions.append(GreedyDepParser.SHIFT)
        if right_ok is True:
            transitions.append(GreedyDepParser.RIGHT)
        if left_ok is True:
            transitions.append(GreedyDepParser.LEFT)
        return transitions

    @staticmethod
    def zero_cost_right(conf, gold_conf):
        """
        Adding the arc (s1, s0) and popping s0 from the stack means that s0
        will not be able to acquire heads or deps from B. The cost is the
        number of gold arcs of the form (s0, d) and (h, s0) with h, d in B;
        for a non-zero cost we simply look for (s0, b) or (b, s0) for b in B.

        :param conf: working configuration
        :param gold_conf: gold configuration
        :return: True if zero-cost, False otherwise
        """
        s0 = conf.stack[-1]
        for b in conf.buffer:
            # BUG FIX: the original compared indices with 'is', which is only
            # reliable for CPython small ints; also guard heads[s0] for the
            # virtual root, which has no gold head entry.
            if gold_conf.heads.get(b) == s0 or gold_conf.heads.get(s0) == b:
                return False
        return True

    @staticmethod
    def zero_cost_left(conf, gold_conf):
        """
        Adding the arc (b, s0) and popping s0 from the stack means that s0
        will not be able to acquire heads from H = {s1} U B\\b and will not be
        able to acquire dependents from B, so the cost is the number of gold
        arcs (s0, d) or (h, s0) with h in H, d in B.

        :param conf: working configuration
        :param gold_conf: gold configuration
        :return: True if zero-cost, False otherwise
        """
        s0 = conf.stack[-1]
        # BUG FIX: the original 'len(conf.stack) > 2 and conf.stack[-2] or None'
        # yielded None both when the stack had exactly two items (off by one)
        # and when s1 was index 0 (falsy), so the s1-as-gold-head cost was
        # silently missed.
        s1 = conf.stack[-2] if len(conf.stack) > 1 else None

        # s0 would lose a gold dependent still in the buffer
        if any(dep in conf.buffer for dep in gold_conf.deps[s0]):
            return False

        # s0 would lose its gold head (b itself is excluded: (b, s0) is the
        # arc being built)
        H = conf.buffer[1:] + [s1]
        if gold_conf.heads[s0] in H:
            return False
        return True

    @staticmethod
    def zero_cost_shift(conf, gold_conf):
        """
        Pushing b onto the stack means that b will not be able to acquire
        heads from H = {s1} U S and will not be able to acquire deps from
        D = {s0, s1} U S.

        :param conf: working configuration
        :param gold_conf: gold configuration
        :return: True if zero-cost, False otherwise
        """
        if len(conf.buffer) < 1:
            return False
        if len(conf.stack) == 0:
            return True

        b = conf.buffer[0]
        # b's gold head is buried in the stack (unreachable after a shift)
        if b in gold_conf.heads and gold_conf.heads[b] in conf.stack[0:-1]:
            return False
        # any gold dependent of b still on the stack would be lost
        return not any(dep in conf.stack for dep in gold_conf.deps[b])

    @staticmethod
    def shift(conf):
        # Move the buffer front onto the stack.
        conf.stack.append(conf.buffer.pop(0))
        return conf

    @staticmethod
    def arc_right(conf):
        # Attach (s1, s0), popping s0.
        s0 = conf.stack.pop()
        s1 = conf.stack[-1]
        conf.arcs.append((s1, s0))
        return conf

    @staticmethod
    def arc_left(conf):
        # Attach (b, s0), popping s0; b stays on the buffer.
        s0 = conf.stack.pop()
        conf.arcs.append((conf.buffer[0], s0))
        return conf

    def dyn_oracle(self, gold_conf, conf, legal_transitions):
        """Zero-cost subset of the legal transitions under the dynamic oracle."""
        options = []
        if GreedyDepParser.SHIFT in legal_transitions and ArcHybridDepParser.zero_cost_shift(conf, gold_conf):
            options.append(GreedyDepParser.SHIFT)
        if GreedyDepParser.RIGHT in legal_transitions and ArcHybridDepParser.zero_cost_right(conf, gold_conf):
            options.append(GreedyDepParser.RIGHT)
        if GreedyDepParser.LEFT in legal_transitions and ArcHybridDepParser.zero_cost_left(conf, gold_conf):
            options.append(GreedyDepParser.LEFT)
        return options


if __name__ == '__main__':

    import argparse
    import fileio
    import fx

    parser = argparse.ArgumentParser(description="Sample program showing training and testing dependency parsers")
    parser.add_argument('--parser', help='Parser type (eager|hybrid) (default: eager)', default='eager')
    parser.add_argument('--train', help='CONLL training file', required=True)
    parser.add_argument('--test', help='CONLL testing file', required=True)
    parser.add_argument('--fx', help='Feature extractor', default='ex')
    parser.add_argument('--n', help='Number of passes over training data', default=15, type=int)
    parser.add_argument('-v', action='store_true')
    opts = parser.parse_args()

    def filter_non_projective(gold):
        """Drop sentences whose gold tree is non-projective."""
        gold_proj = []
        for s in gold:
            gold_conf = GreedyDepParser.get_gold_conf(s)
            if GreedyDepParser.non_projective(gold_conf) is False:
                gold_proj.append(s)
            elif opts.v is True:
                print('Skipping non-projective sentence', s)
        return gold_proj

    # Defaults
    feature_extractor = fx.ex
    Parser = ArcEagerDepParser

    if opts.fx == 'baseline':
        print('Selecting baseline feature extractor')
        # BUG FIX: this previously assigned fx.ex, so --fx baseline was a no-op.
        feature_extractor = fx.baseline

    if opts.parser == 'hybrid':
        print('Using arc-hybrid parser')
        Parser = ArcHybridDepParser

    gold = filter_non_projective(fileio.read_conll_deps(opts.train))
    model = Classifier({}, [0, 1, 2, 3])

    parser = Parser(model, feature_extractor)
    print('performing %d iterations' % opts.n)
    for i in range(opts.n):
        correct_iter = 0
        all_iter = 0
        random.shuffle(gold)
        for gold_sent in gold:
            correct_s, all_s = parser.train(gold_sent, i)
            correct_iter += correct_s
            all_iter += all_s

        print('fraction of correct transitions iteration %d: %d/%d = %f' %
              (i, correct_iter, all_iter, correct_iter / float(all_iter)))
    parser.avg_weights()
    test = filter_non_projective(fileio.read_conll_deps(opts.test))

    all_arcs = 0
    correct_arcs = 0

    for gold_test_sent in test:
        gold_arcs = set((gold_test_sent[i][2], i) for i in range(len(gold_test_sent)))
        arcs = set(parser.run(gold_test_sent))
        correct_arcs += len(gold_arcs & arcs)
        all_arcs += len(gold_arcs)

    print('accuracy %d/%d = %f' % (correct_arcs, all_arcs, float(correct_arcs) / float(all_arcs)))