├── LICENSE ├── README.md ├── data ├── 02-21.10way.clean ├── 22.auto.clean ├── 23.auto.clean ├── model └── vocabulary.json └── src ├── features.py ├── main.py ├── network.py ├── parser.py └── phrase_tree.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 jhcross 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Span-Based Constituency Parser 2 | 3 | This is an implementation of the span-based natural language constituency parser described in the paper [Span-Based Constituency Parsing with a Structure-Label System and Provably Optimal Dynamic Oracles](http://people.oregonstate.edu/~crossj/emnlp_2016.pdf), which will appear in *EMNLP* (2016). 
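As a rough illustration of the structure-label idea named in the paper title, the following minimal sketch replays a sequence of shift/combine/label actions on a toy sentence. `ToyParser` and `run` are hypothetical names used only here. The action strings (`sh`, `comb`, `none`, `label-X`) follow `src/parser.py`, but the sketch builds plain tuples rather than `PhraseTree` objects and assumes a single nonterminal per label action.

```python
# Illustrative sketch only: a stripped-down shift/combine/label state machine.
# ToyParser and run are hypothetical names, not part of this repository.

class ToyParser(object):
    def __init__(self, words):
        self.words = words
        self.i = 0        # index of the next word to shift
        self.stack = []   # entries: (left, right, list of subtrees)

    def shift(self):
        # Push the next word as a one-word span.
        self.stack.append((self.i, self.i, [self.words[self.i]]))
        self.i += 1

    def combine(self):
        # Merge the top two spans into one unlabeled span.
        (_, right, trees0) = self.stack.pop()
        (left, _, trees1) = self.stack.pop()
        self.stack.append((left, right, trees1 + trees0))

    def label(self, symbol):
        # Wrap the subtrees of the top span under a single nonterminal.
        (left, right, trees) = self.stack.pop()
        self.stack.append((left, right, [(symbol, trees)]))


def run(words, actions):
    """Replay structural and label actions; return the final stack."""
    parser = ToyParser(words)
    for action in actions:
        if action == 'sh':
            parser.shift()
        elif action == 'comb':
            parser.combine()
        elif action.startswith('label-'):
            parser.label(action[6:])
        elif action != 'none':
            raise ValueError('Invalid action: {}'.format(action))
    return parser.stack


# Structural actions (sh/comb) alternate with label actions (label-X/none).
stack = run(
    ['the', 'cat', 'sleeps'],
    ['sh', 'none', 'sh', 'none', 'comb', 'label-NP',
     'sh', 'label-VP', 'comb', 'label-S'],
)
print(stack)
```

Each structural action is followed by exactly one label action for the span on top of the stack, with `none` leaving that span unlabeled.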
4 | 5 | #### Required Dependencies 6 | 7 | * Python 2.7 8 | * NumPy 9 | * [DyNet](http://dynet.readthedocs.io/en/latest/python.html) (To ensure compatibility, check out and compile commit 71fc893eda8e3f3fccc77b9a4ae942dce77ba368) 10 | 11 | 12 | 13 | 14 | #### Vocabulary Files 15 | 16 | Vocabulary may be loaded every time from a training tree file, or it may be stored (separately from a trained model) in a JSON file, which is much faster and recommended. To learn the vocabulary from a file with training trees and write a JSON file, use a command such as the following: 17 | 18 | ``` 19 | python src/main.py --train data/02-21.10way.clean --write-vocab data/vocab.json 20 | ``` 21 | 22 | #### Training 23 | 24 | Training requires a file containing training trees (`--train`) and a file containing validation trees (`--dev`), which are parsed four times per training epoch to determine which model to keep. A file name must also be provided to store the saved model (`--model`). The following is an example of a command to train a model with all of the default settings: 25 | 26 | ``` 27 | python src/main.py --train data/02-21.10way.clean --dev data/22.auto.clean --vocab data/vocab.json --model data/my_model 28 | ``` 29 | 30 | The following table provides an overview of additional training options: 31 | 32 | Argument | Description | Default 33 | --- | --- | --- 34 | --dynet-mem | Memory (MB) to allocate for DyNet | 2000 35 | --dynet-l2 | L2 regularization factor | 0 36 | --dynet-seed | Seed for random parameter initialization | random 37 | --word-dims | Word embedding dimensions | 50 38 | --tag-dims | POS embedding dimensions | 20 39 | --lstm-units | LSTM units (per direction, for each of 2 layers) | 200 40 | --hidden-units | Units for ReLU FC layer (each of 2 action types) | 200 41 | --epochs | Number of training epochs | 10 42 | --batch-size | Number of sentences per training update | 10 43 | --droprate | Dropout probability | 0.5 44 | --unk-param | Parameter z for random UNKing
| 0.8375 45 | --alpha | Softmax weight for exploration | 1.0 46 | --beta | Oracle action override probability | 0.0 47 | --np-seed | Seed for shuffling and softmax sampling | random 48 | 49 | 50 | #### Test Evaluation 51 | 52 | A model can also be evaluated directly against a reference corpus by supplying the `--test` argument: 53 | 54 | ``` 55 | python src/main.py --test data/23.auto.clean --vocab data/vocab.json --model data/my_model 56 | ``` 57 | 58 | #### Citation 59 | 60 | If you use this software for research, we would appreciate a citation to our paper: 61 | 62 | ``` 63 | @inproceedings{cross2016span, 64 | title={Span-Based Constituency Parsing with a Structure-Label System and Provably Optimal Dynamic Oracles}, 65 | author={Cross, James and Huang, Liang}, 66 | booktitle={Empirical Methods in Natural Language Processing (EMNLP)}, 67 | year={2016} 68 | } 69 | ``` -------------------------------------------------------------------------------- /src/features.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import json 7 | from collections import defaultdict, OrderedDict 8 | 9 | import numpy as np 10 | 11 | from phrase_tree import PhraseTree 12 | from parser import Parser 13 | 14 | 15 | 16 | 17 | class FeatureMapper(object): 18 | """ 19 | Maps words, tags, and label actions to indices. 20 | """ 21 | 22 | UNK = '<UNK>' 23 | START = '<START>' 24 | STOP = '<STOP>' 25 | 26 | 27 | @staticmethod 28 | def vocab_init(fname, verbose=True): 29 | """ 30 | Learn vocabulary from file of strings.
31 | """ 32 | word_freq = defaultdict(int) 33 | tag_freq = defaultdict(int) 34 | label_freq = defaultdict(int) 35 | 36 | trees = PhraseTree.load_treefile(fname) 37 | 38 | for i, tree in enumerate(trees): 39 | for (word, tag) in tree.sentence: 40 | word_freq[word] += 1 41 | tag_freq[tag] += 1 42 | 43 | for action in Parser.gold_actions(tree): 44 | if action.startswith('label-'): 45 | label = action[6:] 46 | label_freq[label] += 1 47 | 48 | if verbose: 49 | print('\rTree {}'.format(i), end='') 50 | sys.stdout.flush() 51 | 52 | if verbose: 53 | print('\r', end='') 54 | 55 | 56 | words = [ 57 | FeatureMapper.UNK, 58 | FeatureMapper.START, 59 | FeatureMapper.STOP, 60 | ] + sorted(word_freq) 61 | wdict = OrderedDict((w,i) for (i,w) in enumerate(words)) 62 | 63 | tags = [ 64 | FeatureMapper.UNK, 65 | FeatureMapper.START, 66 | FeatureMapper.STOP, 67 | ] + sorted(tag_freq) 68 | tdict = OrderedDict((t,i) for (i,t) in enumerate(tags)) 69 | 70 | labels = sorted(label_freq) 71 | ldict = OrderedDict((l,i) for (i,l) in enumerate(labels)) 72 | 73 | if verbose: 74 | print('Loading features from {}'.format(fname)) 75 | print('({} words, {} tags, {} nonterminal-chains)'.format( 76 | len(wdict), 77 | len(tdict), 78 | len(ldict), 79 | )) 80 | 81 | return { 82 | 'wdict': wdict, 83 | 'word_freq': word_freq, 84 | 'tdict': tdict, 85 | 'ldict': ldict, 86 | } 87 | 88 | 89 | def __init__(self, vocabfile, verbose=True): 90 | 91 | if vocabfile is not None: 92 | data = FeatureMapper.vocab_init( 93 | fname=vocabfile, 94 | verbose=verbose, 95 | ) 96 | self.wdict = data['wdict'] 97 | self.word_freq = data['word_freq'] 98 | self.tdict = data['tdict'] 99 | self.ldict = data['ldict'] 100 | 101 | self.word_freq_list = [] 102 | for word in self.wdict.keys(): 103 | if word in self.word_freq: 104 | self.word_freq_list.append(self.word_freq[word]) 105 | else: 106 | self.word_freq_list.append(0) 107 | 108 | 109 | @staticmethod 110 | def from_dict(data): 111 | new = FeatureMapper(None) 112 | new.wdict = 
data['wdict'] 113 | new.word_freq = data['word_freq'] 114 | new.tdict = data['tdict'] 115 | new.ldict = data['ldict'] 116 | new.word_freq_list = data['word_freq_list'] 117 | return new 118 | 119 | 120 | def as_dict(self): 121 | return { 122 | 'wdict': self.wdict, 123 | 'word_freq': self.word_freq, 124 | 'tdict': self.tdict, 125 | 'ldict': self.ldict, 126 | 'word_freq_list': self.word_freq_list 127 | } 128 | 129 | 130 | def save_json(self, filename): 131 | with open(filename, 'w') as fh: 132 | json.dump(self.as_dict(), fh) 133 | 134 | 135 | @staticmethod 136 | def load_json(filename): 137 | with open(filename) as fh: 138 | data = json.load(fh, object_pairs_hook=OrderedDict) 139 | return FeatureMapper.from_dict(data) 140 | 141 | 142 | def total_words(self): 143 | return len(self.wdict) 144 | 145 | 146 | def total_tags(self): 147 | return len(self.tdict) 148 | 149 | 150 | def total_label_actions(self): 151 | return 1 + len(self.ldict) 152 | 153 | 154 | def s_action_index(self, action): 155 | if action == 'sh': 156 | return 0 157 | elif action == 'comb': 158 | return 1 159 | else: 160 | raise ValueError('Not s-action: {}'.format(action)) 161 | 162 | 163 | def l_action_index(self, action): 164 | if action == 'none': 165 | return 0 166 | elif action.startswith('label-'): 167 | label = action[6:] 168 | label_index = self.ldict.get(label, None) 169 | if label_index is not None: 170 | return 1 + label_index 171 | else: 172 | return 0 173 | else: 174 | raise ValueError('Not l-action: {}'.format(action)) 175 | 176 | 177 | def s_action(self, index): 178 | return ('sh', 'comb')[index] 179 | 180 | 181 | def l_action(self, index): 182 | if index == 0: 183 | return 'none' 184 | else: 185 | return 'label-' + self.ldict.keys()[index - 1] 186 | 187 | 188 | def sentence_sequences(self, sentence): 189 | """ 190 | Array of indices for words and tags. 
191 | """ 192 | sentence = ( 193 | [(FeatureMapper.START, FeatureMapper.START)] + 194 | sentence + 195 | [(FeatureMapper.STOP, FeatureMapper.STOP)] 196 | ) 197 | 198 | words = [ 199 | self.wdict[w] 200 | if w in self.wdict else self.wdict[FeatureMapper.UNK] 201 | for (w, t) in sentence 202 | ] 203 | tags = [ 204 | self.tdict[t] 205 | if t in self.tdict else self.tdict[FeatureMapper.UNK] 206 | for (w, t) in sentence 207 | ] 208 | 209 | w = np.array(words).astype('int32') 210 | t = np.array(tags).astype('int32') 211 | 212 | return w, t 213 | 214 | 215 | def gold_data(self, reftree): 216 | """ 217 | Static oracle for tree. 218 | """ 219 | 220 | w, t = self.sentence_sequences(reftree.sentence) 221 | 222 | (s_features, l_features) = Parser.training_data(reftree) 223 | 224 | struct_data = {} 225 | for (features, action) in s_features: 226 | struct_data[features] = self.s_action_index(action) 227 | 228 | label_data = {} 229 | for (features, action) in l_features: 230 | label_data[features] = self.l_action_index(action) 231 | 232 | return { 233 | 'tree': reftree, 234 | 'w': w, 235 | 't': t, 236 | 'struct_data': struct_data, 237 | 'label_data': label_data, 238 | } 239 | 240 | 241 | def gold_data_from_file(self, fname): 242 | """ 243 | Static oracle for file. 244 | """ 245 | trees = PhraseTree.load_treefile(fname) 246 | result = [] 247 | for tree in trees: 248 | sentence_data = self.gold_data(tree) 249 | result.append(sentence_data) 250 | return result 251 | 252 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Command-line interface for Span-Based Constituency Parser. 
3 | """ 4 | 5 | 6 | from __future__ import print_function 7 | from __future__ import division 8 | 9 | import sys 10 | import argparse 11 | 12 | 13 | if __name__ == '__main__': 14 | 15 | parser = argparse.ArgumentParser(prog='Span-Based Constituency Parser') 16 | parser.add_argument( 17 | '--dynet-mem', 18 | dest='dynet_mem', 19 | help='Memory allocation for DyNet. (DEFAULT=2000)', 20 | default=2000, 21 | ) 22 | parser.add_argument( 23 | '--dynet-l2', 24 | dest='dynet_l2', 25 | help='L2 regularization parameter. (DEFAULT=0)', 26 | default=0, 27 | ) 28 | parser.add_argument( 29 | '--dynet-seed', 30 | dest='dynet_seed', 31 | help='Seed for PRNG. (DEFAULT=0 : generate)', 32 | default=0, 33 | ) 34 | parser.add_argument( 35 | '--model', 36 | dest='model', 37 | help='File to save or load model.', 38 | ) 39 | parser.add_argument( 40 | '--train', 41 | dest='train', 42 | help='Training trees. PTB (parenthetical) format.', 43 | ) 44 | parser.add_argument( 45 | '--test', 46 | dest='test', 47 | help=( 48 | 'Evaluation trees. PTB (parenthetical) format.' 49 | ' Omit for training.' 50 | ), 51 | ) 52 | parser.add_argument( 53 | '--dev', 54 | dest='dev', 55 | help=( 56 | 'Validation trees. PTB (parenthetical) format.' 57 | ' Required for training.' 58 | ), 59 | ) 60 | parser.add_argument( 61 | '--vocab', 62 | dest='vocab', 63 | help='JSON file from which to load vocabulary.', 64 | ) 65 | parser.add_argument( 66 | '--write-vocab', 67 | dest='vocab_output', 68 | help='Destination to save vocabulary from training data.', 69 | ) 70 | parser.add_argument( 71 | '--word-dims', 72 | dest='word_dims', 73 | type=int, 74 | default=50, 75 | help='Embedding dimensions for word forms. (DEFAULT=50)', 76 | ) 77 | parser.add_argument( 78 | '--tag-dims', 79 | dest='tag_dims', 80 | type=int, 81 | default=20, 82 | help='Embedding dimensions for POS tags. 
(DEFAULT=20)', 83 | ) 84 | parser.add_argument( 85 | '--lstm-units', 86 | dest='lstm_units', 87 | type=int, 88 | default=200, 89 | help='Number of LSTM units in each layer/direction. (DEFAULT=200)', 90 | ) 91 | parser.add_argument( 92 | '--hidden-units', 93 | dest='hidden_units', 94 | type=int, 95 | default=200, 96 | help='Number of hidden units for each FC ReLU layer. (DEFAULT=200)', 97 | ) 98 | parser.add_argument( 99 | '--epochs', 100 | dest='epochs', 101 | type=int, 102 | default=10, 103 | help='Number of training epochs. (DEFAULT=10)', 104 | ) 105 | parser.add_argument( 106 | '--batch-size', 107 | dest='batch_size', 108 | type=int, 109 | default=10, 110 | help='Number of sentences per training update. (DEFAULT=10)', 111 | ) 112 | parser.add_argument( 113 | '--droprate', 114 | dest='droprate', 115 | type=float, 116 | default=0.5, 117 | help='Dropout probability. (DEFAULT=0.5)', 118 | ) 119 | parser.add_argument( 120 | '--unk-param', 121 | dest='unk_param', 122 | type=float, 123 | default=0.8375, 124 | help='Parameter z for random UNKing. (DEFAULT=0.8375)', 125 | ) 126 | parser.add_argument( 127 | '--alpha', 128 | dest='alpha', 129 | type=float, 130 | default=1.0, 131 | help='Softmax distribution weighting for exploration. (DEFAULT=1.0)', 132 | ) 133 | parser.add_argument( 134 | '--beta', 135 | dest='beta', 136 | type=float, 137 | default=0, 138 | help='Probability of using oracle action in exploration. 
(DEFAULT=0)', 139 | ) 140 | parser.add_argument('--np-seed', type=int, dest='np_seed') 141 | 142 | args = parser.parse_args() 143 | 144 | # Overriding DyNet defaults 145 | sys.argv.insert(1, str(args.dynet_mem)) 146 | sys.argv.insert(1, '--dynet-mem') 147 | sys.argv.insert(1, str(args.dynet_l2)) 148 | sys.argv.insert(1, '--dynet-l2') 149 | sys.argv.insert(1, str(args.dynet_seed)) 150 | sys.argv.insert(1, '--dynet-seed') 151 | 152 | if args.vocab is not None: 153 | from features import FeatureMapper 154 | fm = FeatureMapper.load_json(args.vocab) 155 | elif args.train is not None: 156 | from features import FeatureMapper 157 | fm = FeatureMapper(args.train) 158 | if args.vocab_output is not None: 159 | fm.save_json(args.vocab_output) 160 | print('Wrote vocabulary file {}'.format(args.vocab_output)) 161 | sys.exit() 162 | else: 163 | print('Must specify either --vocab or --train.') 164 | print(' (Use -h or --help flag for full option list.)') 165 | sys.exit() 166 | 167 | if args.model is None: 168 | print('Must specify --model (or --write-vocab) parameter.') 169 | print(' (Use -h or --help flag for full option list.)') 170 | sys.exit() 171 | 172 | if args.test is not None: 173 | from phrase_tree import PhraseTree 174 | from network import Network 175 | from parser import Parser 176 | 177 | test_trees = PhraseTree.load_treefile(args.test) 178 | print('Loaded test trees from {}'.format(args.test)) 179 | network = Network.load(args.model) 180 | print('Loaded model from: {}'.format(args.model)) 181 | accuracy = Parser.evaluate_corpus(test_trees, fm, network) 182 | print('Accuracy: {}'.format(accuracy)) 183 | elif args.train is not None: 184 | from network import Network 185 | 186 | if args.np_seed is not None: 187 | import numpy as np 188 | np.random.seed(args.np_seed) 189 | 190 | print('L2 regularization: {}'.format(args.dynet_l2)) 191 | 192 | Network.train( 193 | feature_mapper=fm, 194 | word_dims=args.word_dims, 195 | tag_dims=args.tag_dims, 196 | 
lstm_units=args.lstm_units, 197 | hidden_units=args.hidden_units, 198 | epochs=args.epochs, 199 | batch_size=args.batch_size, 200 | train_data_file=args.train, 201 | dev_data_file=args.dev, 202 | model_save_file=args.model, 203 | droprate=args.droprate, 204 | unk_param=args.unk_param, 205 | alpha=args.alpha, 206 | beta=args.beta, 207 | ) 208 | 209 | 210 | -------------------------------------------------------------------------------- /src/network.py: -------------------------------------------------------------------------------- 1 | """ 2 | Bi-LSTM network for span-based constituency parsing. 3 | """ 4 | 5 | from __future__ import print_function 6 | from __future__ import division 7 | 8 | import time 9 | import random 10 | import sys 11 | 12 | import dynet 13 | import numpy as np 14 | 15 | from phrase_tree import PhraseTree, FScore 16 | from features import FeatureMapper 17 | from parser import Parser 18 | 19 | class LSTM(object): 20 | """ 21 | LSTM class with initial state as parameter, and all parameters 22 | initialized in [-0.01, 0.01]. 
23 | """ 24 | 25 | number = 0 26 | 27 | def __init__(self, input_dims, output_dims, model): 28 | self.input_dims = input_dims 29 | self.output_dims = output_dims 30 | self.model = model 31 | 32 | self.W_i = model.add_parameters( 33 | (output_dims, input_dims + output_dims), 34 | init=dynet.UniformInitializer(0.01), 35 | ) 36 | self.b_i = model.add_parameters( 37 | (output_dims,), 38 | init=dynet.ConstInitializer(0), 39 | ) 40 | self.W_f = model.add_parameters( 41 | (output_dims, input_dims + output_dims), 42 | init=dynet.UniformInitializer(0.01), 43 | ) 44 | self.b_f = model.add_parameters( 45 | (output_dims,), 46 | init=dynet.ConstInitializer(0), 47 | ) 48 | self.W_c = model.add_parameters( 49 | (output_dims, input_dims + output_dims), 50 | init=dynet.UniformInitializer(0.01), 51 | ) 52 | self.b_c = model.add_parameters( 53 | (output_dims,), 54 | init=dynet.ConstInitializer(0), 55 | ) 56 | self.W_o = model.add_parameters( 57 | (output_dims, input_dims + output_dims), 58 | init=dynet.UniformInitializer(0.01), 59 | ) 60 | self.b_o = model.add_parameters( 61 | (output_dims,), 62 | init=dynet.ConstInitializer(0), 63 | ) 64 | self.c0 = model.add_parameters( 65 | (output_dims,), 66 | init=dynet.ConstInitializer(0), 67 | ) 68 | 69 | self.W_params = [self.W_i, self.W_f, self.W_c, self.W_o] 70 | self.b_params = [self.b_i, self.b_f, self.b_c, self.b_o] 71 | self.params = self.W_params + self.b_params + [self.c0] 72 | 73 | class State(object): 74 | 75 | def __init__(self, lstm): 76 | self.lstm = lstm 77 | 78 | self.outputs = [] 79 | 80 | self.c = dynet.parameter(self.lstm.c0) 81 | self.h = dynet.tanh(self.c) 82 | 83 | self.W_i = dynet.parameter(self.lstm.W_i) 84 | self.b_i = dynet.parameter(self.lstm.b_i) 85 | 86 | self.W_f = dynet.parameter(self.lstm.W_f) 87 | self.b_f = dynet.parameter(self.lstm.b_f) 88 | 89 | self.W_c = dynet.parameter(self.lstm.W_c) 90 | self.b_c = dynet.parameter(self.lstm.b_c) 91 | 92 | self.W_o = dynet.parameter(self.lstm.W_o) 93 | self.b_o = 
dynet.parameter(self.lstm.b_o) 94 | 95 | 96 | def add_input(self, input_vec): 97 | """ 98 | Note that this function updates the existing State object! 99 | """ 100 | x = dynet.concatenate([input_vec, self.h]) 101 | 102 | i = dynet.logistic(self.W_i * x + self.b_i) 103 | f = dynet.logistic(self.W_f * x + self.b_f) 104 | g = dynet.tanh(self.W_c * x + self.b_c) 105 | o = dynet.logistic(self.W_o * x + self.b_o) 106 | 107 | c = dynet.cmult(f, self.c) + dynet.cmult(i, g) 108 | h = dynet.cmult(o, dynet.tanh(c)) 109 | 110 | self.c = c 111 | self.h = h 112 | self.outputs.append(h) 113 | 114 | return self 115 | 116 | 117 | def output(self): 118 | return self.outputs[-1] 119 | 120 | 121 | def initial_state(self): 122 | return LSTM.State(self) 123 | 124 | 125 | 126 | 127 | class Network(object): 128 | 129 | def __init__( 130 | self, 131 | word_count, 132 | tag_count, 133 | word_dims, 134 | tag_dims, 135 | lstm_units, 136 | hidden_units, 137 | struct_out, 138 | label_out, 139 | droprate=0, 140 | struct_spans=4, 141 | label_spans=3, 142 | ): 143 | 144 | self.word_count = word_count 145 | self.tag_count = tag_count 146 | self.word_dims = word_dims 147 | self.tag_dims = tag_dims 148 | self.lstm_units = lstm_units 149 | self.hidden_units = hidden_units 150 | self.struct_out = struct_out 151 | self.label_out = label_out 152 | 153 | self.droprate = droprate 154 | 155 | self.model = dynet.Model() 156 | 157 | self.trainer = dynet.AdadeltaTrainer(self.model, eps=1e-7, rho=0.99) 158 | random.seed(1) 159 | 160 | self.activation = dynet.rectify 161 | 162 | self.word_embed = self.model.add_lookup_parameters( 163 | (word_count, word_dims), 164 | ) 165 | self.tag_embed = self.model.add_lookup_parameters( 166 | (tag_count, tag_dims), 167 | ) 168 | 169 | self.fwd_lstm1 = LSTM(word_dims + tag_dims, lstm_units, self.model) 170 | self.back_lstm1 = LSTM(word_dims + tag_dims, lstm_units, self.model) 171 | 172 | self.fwd_lstm2 = LSTM(2 * lstm_units, lstm_units, self.model) 173 | self.back_lstm2 = 
LSTM(2 * lstm_units, lstm_units, self.model) 174 | 175 | 176 | self.struct_hidden_W = self.model.add_parameters( 177 | (hidden_units, 4 * struct_spans * lstm_units), 178 | dynet.UniformInitializer(0.01), 179 | ) 180 | self.struct_hidden_b = self.model.add_parameters( 181 | (hidden_units,), 182 | dynet.ConstInitializer(0), 183 | ) 184 | self.struct_output_W = self.model.add_parameters( 185 | (struct_out, hidden_units), 186 | dynet.ConstInitializer(0), 187 | ) 188 | self.struct_output_b = self.model.add_parameters( 189 | (struct_out,), 190 | dynet.ConstInitializer(0), 191 | ) 192 | 193 | self.label_hidden_W = self.model.add_parameters( 194 | (hidden_units, 4 * label_spans * lstm_units), 195 | dynet.UniformInitializer(0.01), 196 | ) 197 | self.label_hidden_b = self.model.add_parameters( 198 | (hidden_units,), 199 | dynet.ConstInitializer(0), 200 | ) 201 | self.label_output_W = self.model.add_parameters( 202 | (label_out, hidden_units), 203 | dynet.ConstInitializer(0), 204 | ) 205 | self.label_output_b = self.model.add_parameters( 206 | (label_out,), 207 | dynet.ConstInitializer(0), 208 | ) 209 | 210 | 211 | def init_params(self): 212 | 213 | self.word_embed.init_from_array( 214 | np.random.uniform(-0.01, 0.01, self.word_embed.shape()), 215 | ) 216 | self.tag_embed.init_from_array( 217 | np.random.uniform(-0.01, 0.01, self.tag_embed.shape()), 218 | ) 219 | 220 | 221 | def prep_params(self): 222 | 223 | self.W1_struct = dynet.parameter(self.struct_hidden_W) 224 | self.b1_struct = dynet.parameter(self.struct_hidden_b) 225 | 226 | self.W2_struct = dynet.parameter(self.struct_output_W) 227 | self.b2_struct = dynet.parameter(self.struct_output_b) 228 | 229 | self.W1_label = dynet.parameter(self.label_hidden_W) 230 | self.b1_label = dynet.parameter(self.label_hidden_b) 231 | 232 | self.W2_label = dynet.parameter(self.label_output_W) 233 | self.b2_label = dynet.parameter(self.label_output_b) 234 | 235 | 236 | def evaluate_recurrent(self, word_inds, tag_inds, test=False): 237 
| 238 | fwd1 = self.fwd_lstm1.initial_state() 239 | back1 = self.back_lstm1.initial_state() 240 | 241 | fwd2 = self.fwd_lstm2.initial_state() 242 | back2 = self.back_lstm2.initial_state() 243 | 244 | sentence = [] 245 | 246 | for (w, t) in zip(word_inds, tag_inds): 247 | wordvec = dynet.lookup(self.word_embed, w) 248 | tagvec = dynet.lookup(self.tag_embed, t) 249 | vec = dynet.concatenate([wordvec, tagvec]) 250 | sentence.append(vec) 251 | 252 | fwd1_out = [] 253 | for vec in sentence: 254 | fwd1 = fwd1.add_input(vec) 255 | fwd_vec = fwd1.output() 256 | fwd1_out.append(fwd_vec) 257 | 258 | back1_out = [] 259 | for vec in reversed(sentence): 260 | back1 = back1.add_input(vec) 261 | back_vec = back1.output() 262 | back1_out.append(back_vec) 263 | 264 | lstm2_input = [] 265 | for (f, b) in zip(fwd1_out, reversed(back1_out)): 266 | lstm2_input.append(dynet.concatenate([f, b])) 267 | 268 | fwd2_out = [] 269 | for vec in lstm2_input: 270 | if self.droprate > 0 and not test: 271 | vec = dynet.dropout(vec, self.droprate) 272 | fwd2 = fwd2.add_input(vec) 273 | fwd_vec = fwd2.output() 274 | fwd2_out.append(fwd_vec) 275 | 276 | back2_out = [] 277 | for vec in reversed(lstm2_input): 278 | if self.droprate > 0 and not test: 279 | vec = dynet.dropout(vec, self.droprate) 280 | back2 = back2.add_input(vec) 281 | back_vec = back2.output() 282 | back2_out.append(back_vec) 283 | 284 | fwd_out = [dynet.concatenate([f1, f2]) for (f1, f2) in zip(fwd1_out, fwd2_out)] 285 | back_out = [dynet.concatenate([b1, b2]) for (b1, b2) in zip(back1_out, back2_out)] 286 | 287 | return fwd_out, back_out[::-1] 288 | 289 | 290 | def evaluate_struct(self, fwd_out, back_out, lefts, rights, test=False): 291 | 292 | fwd_span_out = [] 293 | for left_index, right_index in zip(lefts, rights): 294 | fwd_span_out.append(fwd_out[right_index] - fwd_out[left_index - 1]) 295 | fwd_span_vec = dynet.concatenate(fwd_span_out) 296 | 297 | back_span_out = [] 298 | for left_index, right_index in zip(lefts, rights): 299 | 
back_span_out.append(back_out[left_index] - back_out[right_index + 1]) 300 | back_span_vec = dynet.concatenate(back_span_out) 301 | 302 | hidden_input = dynet.concatenate([fwd_span_vec, back_span_vec]) 303 | 304 | if self.droprate > 0 and not test: 305 | hidden_input = dynet.dropout(hidden_input, self.droprate) 306 | 307 | hidden_output = self.activation(self.W1_struct * hidden_input + self.b1_struct) 308 | 309 | scores = (self.W2_struct * hidden_output + self.b2_struct) 310 | 311 | return scores 312 | 313 | 314 | 315 | def evaluate_label(self, fwd_out, back_out, lefts, rights, test=False): 316 | 317 | fwd_span_out = [] 318 | for left_index, right_index in zip(lefts, rights): 319 | fwd_span_out.append(fwd_out[right_index] - fwd_out[left_index - 1]) 320 | fwd_span_vec = dynet.concatenate(fwd_span_out) 321 | 322 | back_span_out = [] 323 | for left_index, right_index in zip(lefts, rights): 324 | back_span_out.append(back_out[left_index] - back_out[right_index + 1]) 325 | back_span_vec = dynet.concatenate(back_span_out) 326 | 327 | hidden_input = dynet.concatenate([fwd_span_vec, back_span_vec]) 328 | 329 | if self.droprate > 0 and not test: 330 | hidden_input = dynet.dropout(hidden_input, self.droprate) 331 | 332 | hidden_output = self.activation(self.W1_label * hidden_input + self.b1_label) 333 | 334 | scores = (self.W2_label * hidden_output + self.b2_label) 335 | 336 | return scores 337 | 338 | 339 | def save(self, filename): 340 | """ 341 | Appends architecture hyperparameters to end of dynet model file. 
342 | """ 343 | self.model.save(filename) 344 | 345 | with open(filename, 'a') as f: 346 | f.write('\n') 347 | f.write('word_count = {}\n'.format(self.word_count)) 348 | f.write('tag_count = {}\n'.format(self.tag_count)) 349 | f.write('word_dims = {}\n'.format(self.word_dims)) 350 | f.write('tag_dims = {}\n'.format(self.tag_dims)) 351 | f.write('lstm_units = {}\n'.format(self.lstm_units)) 352 | f.write('hidden_units = {}\n'.format(self.hidden_units)) 353 | f.write('struct_out = {}\n'.format(self.struct_out)) 354 | f.write('label_out = {}\n'.format(self.label_out)) 355 | 356 | 357 | @staticmethod 358 | def load(filename): 359 | """ 360 | Loads file created by save() method. 361 | """ 362 | with open(filename) as f: 363 | f.readline() 364 | f.readline() 365 | word_count = int(f.readline().split()[-1]) 366 | tag_count = int(f.readline().split()[-1]) 367 | word_dims = int(f.readline().split()[-1]) 368 | tag_dims = int(f.readline().split()[-1]) 369 | lstm_units = int(f.readline().split()[-1]) 370 | hidden_units = int(f.readline().split()[-1]) 371 | struct_out = int(f.readline().split()[-1]) 372 | label_out = int(f.readline().split()[-1]) 373 | 374 | network = Network( 375 | word_count=word_count, 376 | tag_count=tag_count, 377 | word_dims=word_dims, 378 | tag_dims=tag_dims, 379 | lstm_units=lstm_units, 380 | hidden_units=hidden_units, 381 | struct_out=struct_out, 382 | label_out=label_out, 383 | ) 384 | network.model.load(filename) 385 | 386 | return network 387 | 388 | 389 | @staticmethod 390 | def train( 391 | feature_mapper, 392 | word_dims, 393 | tag_dims, 394 | lstm_units, 395 | hidden_units, 396 | epochs, 397 | batch_size, 398 | train_data_file, 399 | dev_data_file, 400 | model_save_file, 401 | droprate, 402 | unk_param, 403 | alpha=1.0, 404 | beta=0.0, 405 | ): 406 | 407 | start_time = time.time() 408 | 409 | fm = feature_mapper 410 | word_count = fm.total_words() 411 | tag_count = fm.total_tags() 412 | 413 | network = Network( 414 | word_count=word_count, 415 | 
tag_count=tag_count, 416 | word_dims=word_dims, 417 | tag_dims=tag_dims, 418 | lstm_units=lstm_units, 419 | hidden_units=hidden_units, 420 | struct_out=2, 421 | label_out=fm.total_label_actions(), 422 | droprate=droprate, 423 | ) 424 | network.init_params() 425 | 426 | print('Hidden units: {}, per-LSTM units: {}'.format( 427 | hidden_units, 428 | lstm_units, 429 | )) 430 | print('Embeddings: word={} tag={}'.format( 431 | (word_count, word_dims), 432 | (tag_count, tag_dims), 433 | )) 434 | print('Dropout rate: {}'.format(droprate)) 435 | print('Parameters initialized in [-0.01, 0.01]') 436 | print('Random UNKing parameter z = {}'.format(unk_param)) 437 | print('Exploration: alpha={} beta={}'.format(alpha, beta)) 438 | 439 | training_data = fm.gold_data_from_file(train_data_file) 440 | num_batches = -(-len(training_data) // batch_size) 441 | print('Loaded {} training sentences ({} batches of size {})!'.format( 442 | len(training_data), 443 | num_batches, 444 | batch_size, 445 | )) 446 | parse_every = -(-num_batches // 4) 447 | 448 | dev_trees = PhraseTree.load_treefile(dev_data_file) 449 | print('Loaded {} validation trees!'.format(len(dev_trees))) 450 | 451 | best_acc = FScore() 452 | 453 | for epoch in xrange(1, epochs + 1): 454 | print('........... 
epoch {} ...........'.format(epoch)) 455 | 456 | total_cost = 0.0 457 | total_states = 0 458 | training_acc = FScore() 459 | 460 | np.random.shuffle(training_data) 461 | 462 | for b in xrange(num_batches): 463 | batch = training_data[(b * batch_size) : ((b + 1) * batch_size)] 464 | 465 | explore = [ 466 | Parser.exploration( 467 | example, 468 | fm, 469 | network, 470 | alpha=alpha, 471 | beta=beta, 472 | ) for example in batch 473 | ] 474 | for (_, acc) in explore: 475 | training_acc += acc 476 | 477 | batch = [example for (example, _) in explore] 478 | 479 | dynet.renew_cg() 480 | network.prep_params() 481 | 482 | errors = [] 483 | 484 | for example in batch: 485 | 486 | ## random UNKing ## 487 | for (i, w) in enumerate(example['w']): 488 | if w <= 2: 489 | continue 490 | 491 | freq = fm.word_freq_list[w] 492 | drop_prob = unk_param / (unk_param + freq) 493 | r = np.random.random() 494 | if r < drop_prob: 495 | example['w'][i] = 0 496 | 497 | fwd, back = network.evaluate_recurrent( 498 | example['w'], 499 | example['t'], 500 | ) 501 | 502 | for (left, right), correct in example['struct_data'].items(): 503 | scores = network.evaluate_struct(fwd, back, left, right) 504 | 505 | probs = dynet.softmax(scores) 506 | loss = -dynet.log(dynet.pick(probs, correct)) 507 | errors.append(loss) 508 | total_states += len(example['struct_data']) 509 | 510 | for (left, right), correct in example['label_data'].items(): 511 | scores = network.evaluate_label(fwd, back, left, right) 512 | 513 | probs = dynet.softmax(scores) 514 | loss = -dynet.log(dynet.pick(probs, correct)) 515 | errors.append(loss) 516 | total_states += len(example['label_data']) 517 | 518 | batch_error = dynet.esum(errors) 519 | total_cost += batch_error.scalar_value() 520 | batch_error.backward() 521 | network.trainer.update() 522 | 523 | mean_cost = total_cost / total_states 524 | 525 | print( 526 | '\rBatch {} Mean Cost {:.4f} [Train: {}]'.format( 527 | b, 528 | mean_cost, 529 | training_acc, 530 | ), 531 | 
end='', 532 | ) 533 | sys.stdout.flush() 534 | 535 | if ((b + 1) % parse_every) == 0 or b == (num_batches - 1): 536 | dev_acc = Parser.evaluate_corpus( 537 | dev_trees, 538 | fm, 539 | network, 540 | ) 541 | print(' [Val: {}]'.format(dev_acc)) 542 | 543 | if dev_acc > best_acc: 544 | best_acc = dev_acc 545 | network.save(model_save_file) 546 | print(' [saved model: {}]'.format(model_save_file)) 547 | 548 | current_time = time.time() 549 | runmins = (current_time - start_time)/60. 550 | print(' Elapsed time: {:.2f}m'.format(runmins)) 551 | 552 | -------------------------------------------------------------------------------- /src/parser.py: -------------------------------------------------------------------------------- 1 | """ 2 | Shift-Combine-Label parser. 3 | """ 4 | 5 | from __future__ import print_function 6 | from __future__ import division 7 | 8 | import numpy as np 9 | import dynet 10 | from collections import defaultdict 11 | 12 | from phrase_tree import PhraseTree, FScore 13 | 14 | 15 | class Parser(object): 16 | 17 | def __init__(self, n): 18 | """ 19 | Initial state for parsing an n-word sentence. 
20 | """ 21 | self.n = n 22 | self.i = 0 23 | self.stack = [] 24 | 25 | 26 | def can_shift(self): 27 | return (self.i < self.n) 28 | 29 | 30 | def can_combine(self): 31 | return (len(self.stack) > 1) 32 | 33 | 34 | def shift(self): 35 | j = self.i # (index of shifted word) 36 | treelet = PhraseTree(leaf=j) 37 | self.stack.append((j, j, [treelet])) 38 | self.i += 1 39 | 40 | 41 | def combine(self): 42 | (_, right, treelist0) = self.stack.pop() 43 | (left, _, treelist1) = self.stack.pop() 44 | self.stack.append((left, right, treelist1 + treelist0)) 45 | 46 | 47 | def label(self, nonterminals=[]): 48 | 49 | for nt in nonterminals: 50 | (left, right, trees) = self.stack.pop() 51 | tree = PhraseTree(symbol=nt, children=trees) 52 | self.stack.append((left, right, [tree])) 53 | 54 | 55 | def take_action(self, action): 56 | if action == 'sh': 57 | self.shift() 58 | elif action == 'comb': 59 | self.combine() 60 | elif action == 'none': 61 | return 62 | elif action.startswith('label-'): 63 | self.label(action[6:].split('-')) 64 | else: 65 | raise RuntimeError('Invalid Action: {}'.format(action)) 66 | 67 | 68 | def finished(self): 69 | return ( 70 | (self.i == self.n) and 71 | (len(self.stack) == 1) and 72 | (len(self.stack[0][2]) == 1) 73 | ) 74 | 75 | 76 | def tree(self): 77 | if not self.finished(): 78 | raise RuntimeError('Not finished.') 79 | return self.stack[0][2][0] 80 | 81 | 82 | def s_features(self): 83 | """ 84 | Features for predicting structural action (shift, combine): 85 | (pre-s1-span, s1-span, s0-span, post-s0-span) 86 | Note features use 1-based indexing: 87 | ... a span of (1, 1) means the first word of sentence 88 | ... 
(x, x-1) means no span 89 | """ 90 | lefts = [] 91 | rights = [] 92 | 93 | # pre-s1-span 94 | lefts.append(1) 95 | if len(self.stack) < 2: 96 | rights.append(0) 97 | else: 98 | s1_left = self.stack[-2][0] + 1 99 | rights.append(s1_left - 1) 100 | 101 | # s1-span 102 | if len(self.stack) < 2: 103 | lefts.append(1) 104 | rights.append(0) 105 | else: 106 | s1_left = self.stack[-2][0] + 1 107 | lefts.append(s1_left) 108 | s1_right = self.stack[-2][1] + 1 109 | rights.append(s1_right) 110 | 111 | # s0-span 112 | if len(self.stack) < 1: 113 | lefts.append(1) 114 | rights.append(0) 115 | else: 116 | s0_left = self.stack[-1][0] + 1 117 | lefts.append(s0_left) 118 | s0_right = self.stack[-1][1] + 1 119 | rights.append(s0_right) 120 | 121 | # post-s0-span 122 | lefts.append(self.i + 1) 123 | rights.append(self.n) 124 | 125 | return tuple(lefts), tuple(rights) 126 | 127 | 128 | 129 | def l_features(self): 130 | """ 131 | Features for predicting label action: 132 | (pre-s0-span, s0-span, post-s0-span) 133 | """ 134 | lefts = [] 135 | rights = [] 136 | 137 | # pre-s0-span 138 | lefts.append(1) 139 | if len(self.stack) < 1: 140 | rights.append(0) 141 | else: 142 | s0_left = self.stack[-1][0] + 1 143 | rights.append(s0_left - 1) 144 | 145 | 146 | # s0-span 147 | if len(self.stack) < 1: 148 | lefts.append(1) 149 | rights.append(0) 150 | else: 151 | s0_left = self.stack[-1][0] + 1 152 | lefts.append(s0_left) 153 | s0_right = self.stack[-1][1] + 1 154 | rights.append(s0_right) 155 | 156 | # post-s0-span 157 | lefts.append(self.i + 1) 158 | rights.append(self.n) 159 | 160 | return tuple(lefts), tuple(rights) 161 | 162 | 163 | def s_oracle(self, tree): 164 | """ 165 | Returns correct structural action in current (arbitrary) state, 166 | given gold tree. 167 | Deterministic (prefer combine). 
168 | """ 169 | if not self.can_shift(): 170 | return 'comb' 171 | elif not self.can_combine(): 172 | return 'sh' 173 | else: 174 | (left0, right0, _) = self.stack[-1] 175 | a, _ = tree.enclosing(left0, right0) 176 | if a == left0: 177 | return 'sh' 178 | else: 179 | return 'comb' 180 | 181 | 182 | def l_oracle(self, tree): 183 | (left0, right0, _) = self.stack[-1] 184 | labels = tree.span_labels(left0, right0)[::-1] 185 | if len(labels) == 0: 186 | return 'none' 187 | else: 188 | return 'label-' + '-'.join(labels) 189 | 190 | 191 | @staticmethod 192 | def gold_actions(tree): 193 | n = len(tree.sentence) 194 | state = Parser(n) 195 | result = [] 196 | 197 | for step in range(2 * n - 1): 198 | 199 | if state.can_combine(): 200 | (left0, right0, _) = state.stack[-1] 201 | (left1, _, _) = state.stack[-2] 202 | a, b = tree.enclosing(left0, right0) 203 | if left1 >= a: 204 | result.append('comb') 205 | state.combine() 206 | else: 207 | result.append('sh') 208 | state.shift() 209 | else: 210 | result.append('sh') 211 | state.shift() 212 | 213 | (left0, right0, _) = state.stack[-1] 214 | labels = tree.span_labels(left0, right0)[::-1] 215 | if len(labels) == 0: 216 | result.append('none') 217 | else: 218 | result.append('label-' + '-'.join(labels)) 219 | state.label(labels) 220 | 221 | return result 222 | 223 | 224 | @staticmethod 225 | def training_data(tree): 226 | """ 227 | Using oracle (for gold sequence), omitting mandatory S-actions 228 | """ 229 | s_features = [] 230 | l_features = [] 231 | 232 | n = len(tree.sentence) 233 | state = Parser(n) 234 | result = [] 235 | 236 | for step in range(2 * n - 1): 237 | 238 | if not state.can_combine(): 239 | action = 'sh' 240 | elif not state.can_shift(): 241 | action = 'comb' 242 | else: 243 | action = state.s_oracle(tree) 244 | features = state.s_features() 245 | s_features.append((features, action)) 246 | state.take_action(action) 247 | 248 | 249 | action = state.l_oracle(tree) 250 | features = state.l_features() 251 | 
l_features.append((features, action)) 252 | state.take_action(action) 253 | 254 | return (s_features, l_features) 255 | 256 | 257 | 258 | @staticmethod 259 | def exploration(data, fm, network, alpha=1.0, beta=0): 260 | """ 261 | Only data from this parse, including mandatory S-actions. 262 | Follow softmax distribution for structural data. 263 | """ 264 | 265 | dynet.renew_cg() 266 | network.prep_params() 267 | 268 | struct_data = {} 269 | label_data = {} 270 | 271 | tree = data['tree'] 272 | sentence = tree.sentence 273 | 274 | n = len(sentence) 275 | state = Parser(n) 276 | 277 | w = data['w'] 278 | t = data['t'] 279 | fwd, back = network.evaluate_recurrent(w, t, test=True) 280 | 281 | for step in xrange(2 * n - 1): 282 | 283 | features = state.s_features() 284 | if not state.can_combine(): 285 | action = 'sh' 286 | correct_action = 'sh' 287 | elif not state.can_shift(): 288 | action = 'comb' 289 | correct_action = 'comb' 290 | else: 291 | 292 | correct_action = state.s_oracle(tree) 293 | 294 | r = np.random.random() 295 | if r < beta: 296 | action = correct_action 297 | else: 298 | left, right = features 299 | scores = network.evaluate_struct( 300 | fwd, 301 | back, 302 | left, 303 | right, 304 | test=True, 305 | ).npvalue() 306 | 307 | # sample from distribution 308 | exp = np.exp(scores * alpha) 309 | softmax = exp / (exp.sum()) 310 | r = np.random.random() 311 | 312 | if r <= softmax[0]: 313 | action = 'sh' 314 | else: 315 | action = 'comb' 316 | 317 | struct_data[features] = fm.s_action_index(correct_action) 318 | state.take_action(action) 319 | 320 | features = state.l_features() 321 | correct_action = state.l_oracle(tree) 322 | label_data[features] = fm.l_action_index(correct_action) 323 | 324 | r = np.random.random() 325 | if r < beta: 326 | action = correct_action 327 | else: 328 | left, right = features 329 | scores = network.evaluate_label( 330 | fwd, 331 | back, 332 | left, 333 | right, 334 | test=True, 335 | ).npvalue() 336 | if step < (2 * n - 2): 
337 | action_index = np.argmax(scores) 338 | else: 339 | action_index = 1 + np.argmax(scores[1:]) 340 | action = fm.l_action(action_index) 341 | state.take_action(action) 342 | 343 | predicted = state.stack[0][2][0] 344 | predicted.propagate_sentence(sentence) 345 | accuracy = predicted.compare(tree) 346 | 347 | example = { 348 | 'w': w, 349 | 't': t, 350 | 'struct_data': struct_data, 351 | 'label_data': label_data, 352 | } 353 | 354 | return example, accuracy 355 | 356 | 357 | @staticmethod 358 | def parse(sentence, fm, network): 359 | 360 | dynet.renew_cg() 361 | network.prep_params() 362 | 363 | n = len(sentence) 364 | state = Parser(n) 365 | 366 | w, t = fm.sentence_sequences(sentence) 367 | 368 | fwd, back = network.evaluate_recurrent(w, t, test=True) 369 | 370 | for step in xrange(2 * n - 1): 371 | 372 | if not state.can_combine(): 373 | action = 'sh' 374 | elif not state.can_shift(): 375 | action = 'comb' 376 | else: 377 | left, right = state.s_features() 378 | scores = network.evaluate_struct( 379 | fwd, 380 | back, 381 | left, 382 | right, 383 | test=True, 384 | ).npvalue() 385 | action_index = np.argmax(scores) 386 | action = fm.s_action(action_index) 387 | state.take_action(action) 388 | 389 | 390 | left, right = state.l_features() 391 | scores = network.evaluate_label( 392 | fwd, 393 | back, 394 | left, 395 | right, 396 | test=True, 397 | ).npvalue() 398 | if step < (2 * n - 2): 399 | action_index = np.argmax(scores) 400 | else: 401 | action_index = 1 + np.argmax(scores[1:]) 402 | action = fm.l_action(action_index) 403 | state.take_action(action) 404 | 405 | if not state.finished(): 406 | raise RuntimeError('Bad ending state!') 407 | 408 | 409 | tree = state.stack[0][2][0] 410 | tree.propagate_sentence(sentence) 411 | return tree 412 | 413 | 414 | @staticmethod 415 | def evaluate_corpus(trees, fm, network): 416 | accuracy = FScore() 417 | for tree in trees: 418 | predicted = Parser.parse(tree.sentence, fm, network) 419 | local_accuracy = 
predicted.compare(tree) 420 | accuracy += local_accuracy 421 | return accuracy 422 | 423 | 424 | @staticmethod 425 | def write_predicted(fname, test_trees, fm, network): 426 | """ 427 | Input trees being used only to carry sentences. 428 | """ 429 | f = open(fname, 'w') 430 | for tree in test_trees: 431 | predicted = Parser.parse(tree.sentence, fm, network) 432 | topped = PhraseTree( 433 | symbol='TOP', 434 | children=[predicted], 435 | sentence=predicted.sentence, 436 | ) 437 | f.write(str(topped)) 438 | f.write('\n') 439 | f.close() 440 | 441 | 442 | -------------------------------------------------------------------------------- /src/phrase_tree.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Recursive representation of a phrase-structure parse tree 5 | for natural language sentences. 6 | """ 7 | 8 | from __future__ import print_function 9 | from __future__ import division 10 | 11 | from collections import defaultdict 12 | 13 | class PhraseTree(object): 14 | 15 | puncs = [",", ".", ":", "``", "''", "PU"] ## (COLLINS.prm) 16 | 17 | 18 | def __init__( 19 | self, 20 | symbol=None, 21 | children=[], 22 | sentence=[], 23 | leaf=None, 24 | ): 25 | self.symbol = symbol # label at top node 26 | self.children = children # list of PhraseTree objects 27 | self.sentence = sentence 28 | self.leaf = leaf # word at bottom level else None 29 | 30 | self._str = None 31 | 32 | 33 | def __str__(self): 34 | if self._str is None: 35 | if len(self.children) != 0: 36 | childstr = ' '.join(str(c) for c in self.children) 37 | self._str = '({} {})'.format(self.symbol, childstr) 38 | else: 39 | self._str = '({} {})'.format( 40 | self.sentence[self.leaf][1], 41 | self.sentence[self.leaf][0], 42 | ) 43 | return self._str 44 | 45 | 46 | def propagate_sentence(self, sentence): 47 | """ 48 | Recursively assigns sentence (list of (word, POS) pairs) 49 | to all nodes of a tree. 
50 | """ 51 | self.sentence = sentence 52 | for child in self.children: 53 | child.propagate_sentence(sentence) 54 | 55 | 56 | def pretty(self, level=0, marker=' '): 57 | pad = marker * level 58 | 59 | if self.leaf is not None: 60 | leaf_string = '({} {})'.format( 61 | self.symbol, 62 | self.sentence[self.leaf][0], 63 | ) 64 | return pad + leaf_string 65 | 66 | else: 67 | result = pad + '(' + self.symbol 68 | for child in self.children: 69 | result += '\n' + child.pretty(level + 1, marker) 70 | result += ')' 71 | return result 72 | 73 | 74 | @staticmethod 75 | def parse(line): 76 | """ 77 | Loads a tree from a line in PTB parenthetical format. 78 | """ 79 | line += " " 80 | sentence = [] 81 | _, t = PhraseTree._parse(line, 0, sentence) 82 | 83 | if t.symbol == 'TOP' and len(t.children) == 1: 84 | t = t.children[0] 85 | 86 | return t 87 | 88 | 89 | @staticmethod 90 | def _parse(line, index, sentence): 91 | "((...) (...) w/t (...)). returns pos and tree, and carries sent out." 92 | 93 | assert line[index] == '(', "Invalid tree string {} at {}".format(line, index) 94 | index += 1 95 | symbol = None 96 | children = [] 97 | leaf = None 98 | while line[index] != ')': 99 | if line[index] == '(': 100 | index, t = PhraseTree._parse(line, index, sentence) 101 | children.append(t) 102 | 103 | else: 104 | if symbol is None: 105 | # symbol is here! 106 | rpos = min(line.find(' ', index), line.find(')', index)) 107 | # see above N.B. 
(find could return -1) 108 | 109 | symbol = line[index:rpos] # (word, tag) string pair 110 | 111 | index = rpos 112 | else: 113 | rpos = line.find(')', index) 114 | word = line[index:rpos] 115 | sentence.append((word, symbol)) 116 | leaf = len(sentence) - 1 117 | index = rpos 118 | 119 | if line[index] == " ": 120 | index += 1 121 | 122 | assert line[index] == ')', "Invalid tree string %s at %d" % (line, index) 123 | 124 | t = PhraseTree( 125 | symbol=symbol, 126 | children=children, 127 | sentence=sentence, 128 | leaf=leaf, 129 | ) 130 | 131 | return (index + 1), t 132 | 133 | 134 | def left_span(self): 135 | try: 136 | return self._left_span 137 | except AttributeError: 138 | if self.leaf is not None: 139 | self._left_span = self.leaf 140 | else: 141 | self._left_span = self.children[0].left_span() 142 | return self._left_span 143 | 144 | 145 | def right_span(self): 146 | try: 147 | return self._right_span 148 | except AttributeError: 149 | if self.leaf is not None: 150 | self._right_span = self.leaf 151 | else: 152 | self._right_span = self.children[-1].right_span() 153 | return self._right_span 154 | 155 | 156 | def brackets(self, advp_prt=True, counts=None): 157 | 158 | if counts is None: 159 | counts = defaultdict(int) 160 | 161 | if self.leaf is not None: 162 | return {} 163 | 164 | nonterm = self.symbol 165 | if advp_prt and nonterm=='PRT': 166 | nonterm = 'ADVP' 167 | 168 | left = self.left_span() 169 | right = self.right_span() 170 | 171 | # ignore punctuation 172 | while ( 173 | left < len(self.sentence) and 174 | self.sentence[left][1] in PhraseTree.puncs 175 | ): 176 | left += 1 177 | while ( 178 | right > 0 and self.sentence[right][1] in PhraseTree.puncs 179 | ): 180 | right -= 1 181 | 182 | if left <= right and nonterm != 'TOP': 183 | counts[(nonterm, left, right)] += 1 184 | 185 | for child in self.children: 186 | child.brackets(advp_prt=advp_prt, counts=counts) 187 | 188 | return counts 189 | 190 | 191 | def phrase(self): 192 | if self.leaf is not 
None: 193 | return [(self.leaf, self.symbol)] 194 | else: 195 | result = [] 196 | for child in self.children: 197 | result.extend(child.phrase()) 198 | return result 199 | 200 | 201 | @staticmethod 202 | def load_treefile(fname): 203 | trees = [] 204 | for line in open(fname): 205 | t = PhraseTree.parse(line) 206 | trees.append(t) 207 | return trees 208 | 209 | 210 | def compare(self, gold, advp_prt=True): 211 | """ 212 | returns (Precision, Recall, F-measure) 213 | """ 214 | predbracks = self.brackets(advp_prt) 215 | goldbracks = gold.brackets(advp_prt) 216 | 217 | correct = 0 218 | for gb in goldbracks: 219 | if gb in predbracks: 220 | correct += min(goldbracks[gb], predbracks[gb]) 221 | 222 | pred_total = sum(predbracks.values()) 223 | gold_total = sum(goldbracks.values()) 224 | 225 | return FScore(correct, pred_total, gold_total) 226 | 227 | 228 | def enclosing(self, i, j): 229 | """ 230 | Returns the left and right indices of the labeled span in the tree 231 | which is next-larger than (i, j) 232 | (whether or not (i, j) is itself a labeled span) 233 | """ 234 | for child in self.children: 235 | left = child.left_span() 236 | right = child.right_span() 237 | if (left <= i) and (right >= j): 238 | if (left == i) and (right == j): 239 | break 240 | return child.enclosing(i, j) 241 | 242 | return (self.left_span(), self.right_span()) 243 | 244 | 245 | def span_labels(self, i, j): 246 | """ 247 | Returns a list of span labels (if any) for (i, j) 248 | """ 249 | if self.leaf is not None: 250 | return [] 251 | 252 | if (self.left_span() == i) and (self.right_span() == j): 253 | result = [self.symbol] 254 | else: 255 | result = [] 256 | 257 | for child in self.children: 258 | left = child.left_span() 259 | right = child.right_span() 260 | if (left <= i) and (right >= j): 261 | result.extend(child.span_labels(i, j)) 262 | break 263 | 264 | return result 265 | 266 | 267 | 268 | 269 | class FScore(object): 270 | 271 | def __init__(self, correct=0, predcount=0, 
goldcount=0): 272 | self.correct = correct # correct brackets 273 | self.predcount = predcount # total predicted brackets 274 | self.goldcount = goldcount # total gold brackets 275 | 276 | 277 | def precision(self): 278 | if self.predcount > 0: 279 | return (100.0 * self.correct) / self.predcount 280 | else: 281 | return 0.0 282 | 283 | 284 | def recall(self): 285 | if self.goldcount > 0: 286 | return (100.0 * self.correct) / self.goldcount 287 | else: 288 | return 0.0 289 | 290 | 291 | def fscore(self): 292 | precision = self.precision() 293 | recall = self.recall() 294 | if (precision + recall) > 0: 295 | return (2 * precision * recall) / (precision + recall) 296 | else: 297 | return 0.0 298 | 299 | 300 | def __str__(self): 301 | precision = self.precision() 302 | recall = self.recall() 303 | fscore = self.fscore() 304 | return '(P= {:0.2f}, R= {:0.2f}, F= {:0.2f})'.format( 305 | precision, 306 | recall, 307 | fscore, 308 | ) 309 | 310 | 311 | def __iadd__(self, other): 312 | self.correct += other.correct 313 | self.predcount += other.predcount 314 | self.goldcount += other.goldcount 315 | return self 316 | 317 | 318 | def __add__(self, other): 319 | return FScore(self.correct + other.correct, 320 | self.predcount + other.predcount, 321 | self.goldcount + other.goldcount) 322 | 323 | 324 | def __cmp__(self, other): 325 | return cmp(self.fscore(), other.fscore()) 326 | 327 | 328 | @staticmethod 329 | def parseval(gold_file, test_file): 330 | gold_trees = PhraseTree.load_treefile(gold_file) 331 | test_trees = PhraseTree.load_treefile(test_file) 332 | cumulative = FScore() 333 | 334 | for gold, test in zip(gold_trees, test_trees): 335 | acc = test.compare(gold, advp_prt=True) 336 | cumulative += acc 337 | 338 | return cumulative 339 | 340 | 341 | 342 | 343 | --------------------------------------------------------------------------------
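The bracket scoring done by `PhraseTree.compare` and `FScore` above can be illustrated without the rest of the repository. The following is a minimal standalone sketch, not the repository's code: the function name `bracket_fscore` and the toy bracket lists are hypothetical, and it uses `collections.Counter` where the original uses a `defaultdict(int)`. Each bracket is a `(label, left, right)` tuple, and a bracket that occurs several times in both trees is credited the smaller of the two counts, as in `compare`.

```python
from collections import Counter


def bracket_fscore(pred_brackets, gold_brackets):
    """Return (precision, recall, F1) in percent for two bracket lists.

    Hypothetical helper for illustration; brackets are (label, left, right).
    """
    pred = Counter(pred_brackets)
    gold = Counter(gold_brackets)
    # Multiset intersection keeps min(pred count, gold count) per bracket.
    correct = sum((pred & gold).values())
    pred_total = sum(pred.values())
    gold_total = sum(gold.values())
    # Empty totals give 0.0, mirroring FScore.precision() and FScore.recall().
    precision = 100.0 * correct / pred_total if pred_total else 0.0
    recall = 100.0 * correct / gold_total if gold_total else 0.0
    total = precision + recall
    f1 = 2 * precision * recall / total if total else 0.0
    return precision, recall, f1


# Toy example: the prediction mislabels one of four spans.
gold = [('S', 0, 4), ('NP', 0, 1), ('VP', 2, 4), ('NP', 3, 4)]
pred = [('S', 0, 4), ('NP', 0, 1), ('VP', 2, 4), ('PP', 3, 4)]
precision, recall, f1 = bracket_fscore(pred, gold)
print('(P= {:0.2f}, R= {:0.2f}, F= {:0.2f})'.format(precision, recall, f1))
```

Corpus-level scores in the repository are obtained by summing the raw counts across sentences (via `FScore.__iadd__`) before computing the ratios, rather than by averaging per-sentence F-scores; the sketch above covers only the single-sentence case.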