├── src
│   ├── __init__.py
│   ├── embed_regularize.py
│   ├── utility.py
│   ├── demo.py
│   ├── weight_drop.py
│   ├── loss.py
│   ├── ctb.py
│   ├── model.py
│   ├── dataloader.py
│   ├── datacreate_ctb.py
│   ├── datacreate_ptb.py
│   ├── helpers.py
│   └── dp.py
├── EVALB
│   ├── evalb
│   ├── Makefile
│   ├── tgrep_proc.prl
│   ├── LICENSE
│   ├── sample
│   │   ├── sample.tst
│   │   ├── sample.gld
│   │   ├── sample.prm
│   │   └── sample.rsl
│   ├── COLLINS.prm
│   ├── new.prm
│   ├── README
│   └── evalb.c
├── .gitignore
└── README.md

/src/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/EVALB/evalb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hantek/distance-parser/HEAD/EVALB/evalb
--------------------------------------------------------------------------------
/EVALB/Makefile:
--------------------------------------------------------------------------------
1 | all: evalb
2 | 
3 | evalb: evalb.c
4 | 	gcc -Wall -g -o evalb evalb.c
5 | 
--------------------------------------------------------------------------------
/EVALB/tgrep_proc.prl:
--------------------------------------------------------------------------------
1 | #!/usr/local/bin/perl
2 | 
3 | while(<>)
4 | {
5 |     if(m/TOP/)    # only print lines that contain TOP
6 |     {
7 |         print;
8 |     }
9 | }
10 | 
--------------------------------------------------------------------------------
/src/embed_regularize.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.autograd import Variable
4 | import torch.nn.functional as F
5 | 
6 | 
7 | def embedded_dropout(embed, words, dropout=0.1, scale=None):
8 |     if dropout:  # drop entire rows (whole word types) of the embedding matrix and rescale the rest
9 |         mask = embed.weight.data.new().resize_((embed.weight.size(0), 1)).bernoulli_(1 - dropout).expand_as(embed.weight) / (1 - dropout)
10 |         mask = Variable(mask)
11 |         masked_embed_weight = mask * embed.weight
12 |     else:
13 |         masked_embed_weight = embed.weight
14 |     if scale:
15 |         masked_embed_weight = scale.expand_as(masked_embed_weight) * masked_embed_weight
16 | 
17 |     padding_idx = embed.padding_idx
18 |     if padding_idx is None:
19 |         padding_idx = -1
20 |     return F.embedding(words, masked_embed_weight,
21 |                        padding_idx, embed.max_norm, embed.norm_type,
22 |                        embed.scale_grad_by_freq, embed.sparse
23 |                        )
24 | 
25 | 
26 | 
--------------------------------------------------------------------------------
/src/utility.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # Filename: utility.py
3 | # Author: hankcs
4 | # Date: 2017-11-03 22:05
5 | import errno
6 | from os import makedirs
7 | 
8 | import sys
9 | 
10 | 
11 | def make_sure_path_exists(path):
12 |     try:
13 |         makedirs(path)
14 |     except OSError as exception:
15 |         if exception.errno != errno.EEXIST:
16 |             raise
17 | 
18 | 
19 | def eprint(*args, **kwargs):
20 |     print(*args, file=sys.stderr, **kwargs)
21 | 
22 | 
23 | 
24 | def combine_files(fids, out, tb):
25 |     print('%d files...' % len(fids))
26 |     total_sentence = 0
27 |     for n, file in enumerate(fids):
28 |         if n % 10 == 0 or n == len(fids) - 1:
29 |             print("%c%.2f%%\r" % (13, (n + 1) / float(len(fids)) * 100), end='')
30 |         sents = tb.parsed_sents(file)
31 |         for s in sents:
32 |             out.write(s.pformat(margin=sys.maxsize))
33 |             out.write(u'\n')
34 |             total_sentence += 1
35 |     print()
36 |     print('%d sentences.'
% total_sentence) 37 | print() 38 | 39 | -------------------------------------------------------------------------------- /EVALB/LICENSE: -------------------------------------------------------------------------------- 1 | This is free and unencumbered software released into the public domain. 2 | 3 | Anyone is free to copy, modify, publish, use, compile, sell, or 4 | distribute this software, either in source code form or as a compiled 5 | binary, for any purpose, commercial or non-commercial, and by any 6 | means. 7 | 8 | In jurisdictions that recognize copyright laws, the author or authors 9 | of this software dedicate any and all copyright interest in the 10 | software to the public domain. We make this dedication for the benefit 11 | of the public at large and to the detriment of our heirs and 12 | successors. We intend this dedication to be an overt act of 13 | relinquishment in perpetuity of all present and future rights to this 14 | software under copyright law. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 23 | 24 | For more information, please refer to 25 | -------------------------------------------------------------------------------- /EVALB/sample/sample.tst: -------------------------------------------------------------------------------- 1 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 2 | (S (A (P this)) (B (Q is) (C (R a) (T test)))) 3 | (S (A (P this)) (B (Q is) (A (R a) (U test)))) 4 | (S (C (P this)) (B (Q is) (A (R a) (U test)))) 5 | (S (A (P this)) (B (Q is) (R a) (A (T test)))) 6 | (S (A (P this) (Q is)) (A (R a) (T test))) 7 | (S (P this) (Q is) (R a) (T test)) 8 | (P this) (Q is) (R a) (T test) 9 | (S (A (P this)) (B (Q is) (A (A (R a) (T test))))) 10 | (S (A (P this)) (B (Q is) (A (A (A (A (A (R a) (T test)))))))) 11 | 12 | (S (A (P this)) (B (Q was) (A (A (R a) (T test))))) 13 | (S (A (P this)) (B (Q is) (U not) (A (A (R a) (T test))))) 14 | 15 | (TOP (S (A (P this)) (B (Q is) (A (R a) (T test))))) 16 | (S (A (P this)) (NONE *) (B (Q is) (A (R a) (T test)))) 17 | (S (A (P this)) (S (NONE abc) (A (NONE *))) (B (Q is) (A (R a) (T test)))) 18 | (S (A (P this)) (B (Q is) (A (R a) (TT test)))) 19 | (S (A (P This)) (B (Q is) (A (R a) (T test)))) 20 | (S (A (P That)) (B (Q is) (A (R a) (T test)))) 21 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 22 | (S (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test)))) 23 | (S (A (P this)) (B (Q is) (A (R a) (T test))) (-NONE- *)) 24 | (S (A (P this)) (B (Q is) (A (R a) (T test))) (: *)) 25 | -------------------------------------------------------------------------------- /EVALB/sample/sample.gld: -------------------------------------------------------------------------------- 1 | (S (A (P 
this)) (B (Q is) (A (R a) (T test)))) 2 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 3 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 4 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 5 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 6 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 7 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 8 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 9 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 10 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 11 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 12 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 13 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 14 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 15 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 16 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 17 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 18 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 19 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 20 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 21 | (S (A-SBJ-1 (P this)) (B-WHATEVER (Q is) (A (R a) (T test)))) 22 | (S (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test)))) 23 | (S (A (P this)) (B (Q is) (A (R a) (T test))) (-NONE- *)) 24 | (S (A (P this)) (B (Q is) (A (R a) (T test))) (: *)) 25 | -------------------------------------------------------------------------------- /src/demo.py: -------------------------------------------------------------------------------- 1 | from dp import * 2 | 3 | if __name__ == '__main__': 4 | print("building model...") 5 | model = distance_parser(vocab_size=args.vocab_size, 6 | embed_size=args.embedsz, 7 | hid_size=args.hidsz, 8 | arc_size=len(ptb_parsed.arc_dictionary), 9 | stag_size=len(ptb_parsed.stag_dictionary), 10 | window_size=args.window_size, 11 | dropout=args.dpout, 12 | dropoute=args.dpoute, 13 | dropoutr=args.dpoutr) 14 | if args.cuda: 15 | model.cuda() 16 | 17 | if os.path.isfile(parameter_filepath): 18 | print("Resuming from file: {}".format(parameter_filepath)) 19 | checkpoint = torch.load(parameter_filepath) 20 | start_epoch = checkpoint['epoch'] 21 | valid_precision = checkpoint['valid_precision'] 22 | valid_recall = checkpoint['valid_recall'] 23 | best_valid_f1 = checkpoint['valid_f1'] 24 | model.load_state_dict(checkpoint['model_state_dict']) 25 | print("loaded model: epoch {}, valid_loss {}, " 26 | "valid_precision {}, valid_recall {}, valid_f1 {}".format( 27 | start_epoch, checkpoint['valid_loss'], valid_precision, \ 28 | valid_recall, best_valid_f1)) 29 | 30 | print("Evaluating valid... ") 31 | valid_loss, valid_arc_prec, valid_tag_prec, \ 32 | valid_precision, valid_recall, valid_f1 = evaluate(model, ptb_parsed, 'valid') 33 | print("Evaluating test... 
") 34 | test_loss, test_arc_prec, test_tag_prec, \ 35 | test_precision, test_recall, test_f1= evaluate(model, ptb_parsed, 'test') 36 | print(valid_log_template.format( 37 | start_epoch, 38 | ' ', valid_loss, valid_arc_prec, valid_tag_prec, 39 | valid_precision, valid_recall, valid_f1, 40 | ' ', test_loss, test_arc_prec, test_tag_prec, 41 | test_precision, test_recall, test_f1)) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # vim and gedit cache: 2 | *.swp 3 | *.swo 4 | *~ 5 | 6 | # cluster logs 7 | SMART_DISPATCH_LOGS/* 8 | 9 | # model params 10 | model/* 11 | params/* 12 | tmptrees/* 13 | 14 | # logs 15 | tblogs/* 16 | logs/* 17 | 18 | # data 19 | data/* 20 | 21 | # Byte-compiled / optimized / DLL files 22 | __pycache__/ 23 | *.py[cod] 24 | *$py.class 25 | *.pyc 26 | 27 | # C extensions 28 | *.so 29 | 30 | # Distribution / packaging 31 | .Python 32 | env/ 33 | build/ 34 | develop-eggs/ 35 | dist/ 36 | downloads/ 37 | eggs/ 38 | .eggs/ 39 | lib/ 40 | lib64/ 41 | parts/ 42 | sdist/ 43 | var/ 44 | wheels/ 45 | *.egg-info/ 46 | .installed.cfg 47 | *.egg 48 | 49 | # PyInstaller 50 | # Usually these files are written by a python script from a template 51 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 52 | *.manifest 53 | *.spec 54 | 55 | # Installer logs 56 | pip-log.txt 57 | pip-delete-this-directory.txt 58 | 59 | # Unit test / coverage reports 60 | htmlcov/ 61 | .tox/ 62 | .coverage 63 | .coverage.* 64 | .cache 65 | nosetests.xml 66 | coverage.xml 67 | *.cover 68 | .hypothesis/ 69 | 70 | # Translations 71 | *.mo 72 | *.pot 73 | 74 | # Django stuff: 75 | *.log 76 | local_settings.py 77 | 78 | # Flask stuff: 79 | instance/ 80 | .webassets-cache 81 | 82 | # Scrapy stuff: 83 | .scrapy 84 | 85 | # Sphinx documentation 86 | docs/_build/ 87 | 88 | # PyBuilder 89 | target/ 90 | 91 | # Jupyter Notebook 92 | .ipynb_checkpoints 93 | 94 | # pyenv 95 | .python-version 96 | 97 | # celery beat schedule file 98 | celerybeat-schedule 99 | 100 | # SageMath parsed files 101 | *.sage.py 102 | 103 | # dotenv 104 | .env 105 | 106 | # virtualenv 107 | .venv 108 | venv/ 109 | ENV/ 110 | 111 | # Spyder project settings 112 | .spyderproject 113 | .spyproject 114 | 115 | # Rope project settings 116 | .ropeproject 117 | 118 | # mkdocs documentation 119 | /site 120 | 121 | # mypy 122 | .mypy_cache/ 123 | 124 | #pycharm 125 | .idea/ 126 | 127 | #pytorch 128 | *.pt 129 | -------------------------------------------------------------------------------- /src/weight_drop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Parameter 3 | 4 | class WeightDrop(torch.nn.Module): 5 | def __init__(self, module, weights, dropout=0, variational=False): 6 | super(WeightDrop, self).__init__() 7 | self.module = module 8 | self.weights = weights 9 | self.dropout = dropout 10 | self.variational = variational 11 | self._setup() 12 | 13 | def widget_demagnetizer_y2k_edition(*args, **kwargs): 14 | # We need to replace flatten_parameters with a nothing function 15 | # It must be a function rather than a lambda as otherwise pickling explodes 16 | # We can't write boring code though, so ... WIDGET DEMAGNETIZER Y2K EDITION! 
17 | return 18 | 19 | def _setup(self): 20 | # Terrible temporary solution to an issue regarding compacting weights re: CUDNN RNN 21 | if issubclass(type(self.module), torch.nn.RNNBase): 22 | self.module.flatten_parameters = self.widget_demagnetizer_y2k_edition 23 | 24 | for name_w in self.weights: 25 | print('Applying weight drop of {} to {}'.format(self.dropout, name_w)) 26 | w = getattr(self.module, name_w) 27 | del self.module._parameters[name_w] 28 | self.module.register_parameter(name_w + '_raw', Parameter(w.data)) 29 | 30 | def _setweights(self): 31 | for name_w in self.weights: 32 | raw_w = getattr(self.module, name_w + '_raw') 33 | w = None 34 | if self.variational: 35 | mask = torch.autograd.Variable(torch.ones(raw_w.size(0), 1)) 36 | if raw_w.is_cuda: mask = mask.cuda() 37 | mask = torch.nn.functional.dropout(mask, p=self.dropout, training=True) 38 | w = mask.expand_as(raw_w) * raw_w 39 | else: 40 | w = torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training) 41 | setattr(self.module, name_w, w) 42 | 43 | def forward(self, *args): 44 | self._setweights() 45 | return self.module.forward(*args) -------------------------------------------------------------------------------- /EVALB/sample/sample.prm: -------------------------------------------------------------------------------- 1 | ##------------------------------------------## 2 | ## Debug mode ## 3 | ## print out data for individual sentence ## 4 | ##------------------------------------------## 5 | DEBUG 0 6 | 7 | ##------------------------------------------## 8 | ## MAX error ## 9 | ## Number of error to stop the process. ## 10 | ## This is useful if there could be ## 11 | ## tokanization error. ## 12 | ## The process will stop when this number## 13 | ## of errors are accumulated. ## 14 | ##------------------------------------------## 15 | MAX_ERROR 10 16 | 17 | ##------------------------------------------## 18 | ## Cut-off length for statistics ## 19 | ## At the end of evaluation, the ## 20 | ## statistics for the senetnces of length## 21 | ## less than or equal to this number will## 22 | ## be shown, on top of the statistics ## 23 | ## for all the sentences ## 24 | ##------------------------------------------## 25 | CUTOFF_LEN 40 26 | 27 | ##------------------------------------------## 28 | ## unlabeled or labeled bracketing ## 29 | ## 0: unlabeled bracketing ## 30 | ## 1: labeled bracketing ## 31 | ##------------------------------------------## 32 | LABELED 1 33 | 34 | ##------------------------------------------## 35 | ## Delete labels ## 36 | ## list of labels to be ignored. ## 37 | ## If it is a pre-terminal label, delete ## 38 | ## the word along with the brackets. ## 39 | ## If it is a non-terminal label, just ## 40 | ## delete the brackets (don't delete ## 41 | ## deildrens). ## 42 | ##------------------------------------------## 43 | DELETE_LABEL TOP 44 | DELETE_LABEL -NONE- 45 | DELETE_LABEL , 46 | DELETE_LABEL : 47 | DELETE_LABEL `` 48 | DELETE_LABEL '' 49 | 50 | ##------------------------------------------## 51 | ## Delete labels for length calculation ## 52 | ## list of labels to be ignored for ## 53 | ## length calculation purpose ## 54 | ##------------------------------------------## 55 | DELETE_LABEL_FOR_LENGTH -NONE- 56 | 57 | 58 | ##------------------------------------------## 59 | ## Equivalent labels, words ## 60 | ## the pairs are considered equivalent ## 61 | ## This is non-directional. 
## 62 | ##------------------------------------------## 63 | EQ_LABEL T TT 64 | 65 | EQ_WORD This this 66 | -------------------------------------------------------------------------------- /EVALB/COLLINS.prm: -------------------------------------------------------------------------------- 1 | ##------------------------------------------## 2 | ## Debug mode ## 3 | ## 0: No debugging ## 4 | ## 1: print data for individual sentence ## 5 | ##------------------------------------------## 6 | DEBUG 0 7 | 8 | ##------------------------------------------## 9 | ## MAX error ## 10 | ## Number of error to stop the process. ## 11 | ## This is useful if there could be ## 12 | ## tokanization error. ## 13 | ## The process will stop when this number## 14 | ## of errors are accumulated. ## 15 | ##------------------------------------------## 16 | MAX_ERROR 10 17 | 18 | ##------------------------------------------## 19 | ## Cut-off length for statistics ## 20 | ## At the end of evaluation, the ## 21 | ## statistics for the senetnces of length## 22 | ## less than or equal to this number will## 23 | ## be shown, on top of the statistics ## 24 | ## for all the sentences ## 25 | ##------------------------------------------## 26 | CUTOFF_LEN 40 27 | 28 | ##------------------------------------------## 29 | ## unlabeled or labeled bracketing ## 30 | ## 0: unlabeled bracketing ## 31 | ## 1: labeled bracketing ## 32 | ##------------------------------------------## 33 | LABELED 1 34 | 35 | ##------------------------------------------## 36 | ## Delete labels ## 37 | ## list of labels to be ignored. ## 38 | ## If it is a pre-terminal label, delete ## 39 | ## the word along with the brackets. ## 40 | ## If it is a non-terminal label, just ## 41 | ## delete the brackets (don't delete ## 42 | ## deildrens). ## 43 | ##------------------------------------------## 44 | DELETE_LABEL TOP 45 | DELETE_LABEL -NONE- 46 | DELETE_LABEL , 47 | DELETE_LABEL : 48 | DELETE_LABEL `` 49 | DELETE_LABEL '' 50 | DELETE_LABEL . 51 | 52 | ##------------------------------------------## 53 | ## Delete labels for length calculation ## 54 | ## list of labels to be ignored for ## 55 | ## length calculation purpose ## 56 | ##------------------------------------------## 57 | DELETE_LABEL_FOR_LENGTH -NONE- 58 | 59 | ##------------------------------------------## 60 | ## Equivalent labels, words ## 61 | ## the pairs are considered equivalent ## 62 | ## This is non-directional. ## 63 | ##------------------------------------------## 64 | EQ_LABEL ADVP PRT 65 | 66 | # EQ_WORD Example example 67 | -------------------------------------------------------------------------------- /src/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def _assert_no_grad(variable): 6 | assert not variable.requires_grad, \ 7 | "nn criterions don't compute the gradient w.r.t. 
targets - please " \ 8 | "mark these variables as not requiring gradients" 9 | 10 | 11 | def ScaledRankLoss(input, target, mask, epsilon): 12 | """ 13 | scaled, single-sided L1 loss 14 | epsilon: parameter for scaling 15 | 16 | """ 17 | _assert_no_grad(target) 18 | 19 | diff = input[:, :, None] - input[:, None, :] 20 | target_diff_positive = ((target[:, :, None] - target[:, None, :]) > 0).float() 21 | target_diff_negative = - ((target[:, :, None] - target[:, None, :]) < 0).float() 22 | 23 | target_diff = target_diff_positive + target_diff_negative 24 | target_diff_zero = 1 - (target_diff_positive + (- target_diff_negative)) 25 | 26 | mask = mask[:, :, None] * mask[:, None, :] 27 | 28 | eepsilon = torch.exp(epsilon) 29 | loss = F.relu(eepsilon - target_diff * diff) + \ 30 | target_diff_zero * diff * diff / eepsilon ** 2 + \ 31 | 1 / eepsilon 32 | loss = (loss * mask).sum() / (mask.sum() + 1e-9) 33 | return loss 34 | 35 | 36 | # def rankloss(input, target, mask): 37 | # diff = (input[:, :, None] - input[:, None, :]) 38 | # ### eqloss: we modify the loss in the paper to account for "ties" 39 | # ### i.e. we don't train the ties. 40 | # target_sign = torch.sign(target[:, :, None] - target[:, None, :]).float() 41 | # mask = mask[:, :, None] * mask[:, None, :] 42 | # loss = F.relu(1. - target_sign * diff) 43 | # loss = (0.5 * loss * mask).sum() / (mask.sum() + 1e-9) 44 | # return loss 45 | 46 | 47 | def rankloss(input, target, mask, exp=False): 48 | diff = input[:, :, None] - input[:, None, :] 49 | target_diff = ((target[:, :, None] - target[:, None, :]) > 0).float() 50 | mask = mask[:, :, None] * mask[:, None, :] * target_diff 51 | 52 | if exp: 53 | loss = torch.exp(F.relu(target_diff - diff)) - 1 54 | else: 55 | loss = F.relu(target_diff - diff) 56 | loss = (loss * mask).sum() / (mask.sum() + 1e-9) 57 | 58 | return loss 59 | 60 | 61 | mse = torch.nn.MSELoss(reduce=False) 62 | 63 | 64 | def mseloss(input, target, mask): 65 | loss = mse(input, target) 66 | return (loss * mask).sum() / (mask.sum() + 1e-9) 67 | 68 | 69 | arcloss = torch.nn.CrossEntropyLoss(ignore_index=0) 70 | tagloss = torch.nn.CrossEntropyLoss(ignore_index=0) 71 | bce = torch.nn.BCELoss(size_average=False) 72 | 73 | 74 | def labelloss(input, target, mask): 75 | loss = bce(input * mask, target * mask) 76 | return loss / (mask.sum() + 1e-9) 77 | -------------------------------------------------------------------------------- /EVALB/sample/sample.rsl: -------------------------------------------------------------------------------- 1 | Sent. Matched Bracket Cross Correct Tag 2 | ID Len. Stat. Recal Prec. 
Bracket gold test Bracket Words Tags Accracy 3 | ============================================================================ 4 | 1 4 0 100.00 100.00 4 4 4 0 4 4 100.00 5 | 2 4 0 75.00 75.00 3 4 4 0 4 4 100.00 6 | 3 4 0 100.00 100.00 4 4 4 0 4 3 75.00 7 | 4 4 0 75.00 75.00 3 4 4 0 4 3 75.00 8 | 5 4 0 75.00 75.00 3 4 4 0 4 4 100.00 9 | 6 4 0 50.00 66.67 2 4 3 1 4 4 100.00 10 | 7 4 0 25.00 100.00 1 4 1 0 4 4 100.00 11 | 8 4 0 0.00 0.00 0 4 0 0 4 4 100.00 12 | 9 4 0 100.00 80.00 4 4 5 0 4 4 100.00 13 | 10 4 0 100.00 50.00 4 4 8 0 4 4 100.00 14 | 11 4 2 0.00 0.00 0 0 0 0 4 0 0.00 15 | 12 4 1 0.00 0.00 0 0 0 0 4 0 0.00 16 | 13 4 1 0.00 0.00 0 0 0 0 4 0 0.00 17 | 14 4 2 0.00 0.00 0 0 0 0 4 0 0.00 18 | 15 4 0 100.00 100.00 4 4 4 0 4 4 100.00 19 | 16 4 1 0.00 0.00 0 0 0 0 4 0 0.00 20 | 17 4 1 0.00 0.00 0 0 0 0 4 0 0.00 21 | 18 4 0 100.00 100.00 4 4 4 0 4 4 100.00 22 | 19 4 0 100.00 100.00 4 4 4 0 4 4 100.00 23 | 20 4 1 0.00 0.00 0 0 0 0 4 0 0.00 24 | 21 4 0 100.00 100.00 4 4 4 0 4 4 100.00 25 | 22 44 0 100.00 100.00 34 34 34 0 44 44 100.00 26 | 23 4 0 100.00 100.00 4 4 4 0 4 4 100.00 27 | 24 5 0 100.00 100.00 4 4 4 0 4 4 100.00 28 | ============================================================================ 29 | 87.76 90.53 86 98 95 16 108 106 98.15 30 | === Summary === 31 | 32 | -- All -- 33 | Number of sentence = 24 34 | Number of Error sentence = 5 35 | Number of Skip sentence = 2 36 | Number of Valid sentence = 17 37 | Bracketing Recall = 87.76 38 | Bracketing Precision = 90.53 39 | Complete match = 52.94 40 | Average crossing = 0.06 41 | No crossing = 94.12 42 | 2 or less crossing = 100.00 43 | Tagging accuracy = 98.15 44 | 45 | -- len<=40 -- 46 | Number of sentence = 23 47 | Number of Error sentence = 5 48 | Number of Skip sentence = 2 49 | Number of Valid sentence = 16 50 | Bracketing Recall = 81.25 51 | Bracketing Precision = 85.25 52 | Complete match = 50.00 53 | Average crossing = 0.06 54 | No crossing = 93.75 55 | 2 or less crossing = 100.00 56 | Tagging accuracy = 96.88 57 | -------------------------------------------------------------------------------- /EVALB/new.prm: -------------------------------------------------------------------------------- 1 | ##------------------------------------------## 2 | ## Debug mode ## 3 | ## 0: No debugging ## 4 | ## 1: print data for individual sentence ## 5 | ## 2: print detailed bracketing info ## 6 | ##------------------------------------------## 7 | DEBUG 0 8 | 9 | ##------------------------------------------## 10 | ## MAX error ## 11 | ## Number of error to stop the process. ## 12 | ## This is useful if there could be ## 13 | ## tokanization error. ## 14 | ## The process will stop when this number## 15 | ## of errors are accumulated. ## 16 | ##------------------------------------------## 17 | MAX_ERROR 10 18 | 19 | ##------------------------------------------## 20 | ## Cut-off length for statistics ## 21 | ## At the end of evaluation, the ## 22 | ## statistics for the senetnces of length## 23 | ## less than or equal to this number will## 24 | ## be shown, on top of the statistics ## 25 | ## for all the sentences ## 26 | ##------------------------------------------## 27 | CUTOFF_LEN 40 28 | 29 | ##------------------------------------------## 30 | ## unlabeled or labeled bracketing ## 31 | ## 0: unlabeled bracketing ## 32 | ## 1: labeled bracketing ## 33 | ##------------------------------------------## 34 | LABELED 1 35 | 36 | ##------------------------------------------## 37 | ## Delete labels ## 38 | ## list of labels to be ignored. 
##
39 | ## If it is a pre-terminal label, delete ##
40 | ## the word along with the brackets.     ##
41 | ## If it is a non-terminal label, just   ##
42 | ## delete the brackets (don't delete     ##
43 | ## deildrens).                            ##
44 | ##------------------------------------------##
45 | DELETE_LABEL TOP
46 | DELETE_LABEL S1
47 | DELETE_LABEL -NONE-
48 | DELETE_LABEL ,
49 | DELETE_LABEL :
50 | DELETE_LABEL ``
51 | DELETE_LABEL ''
52 | DELETE_LABEL .
53 | DELETE_LABEL ?
54 | DELETE_LABEL !
55 | 
56 | ##------------------------------------------##
57 | ## Delete labels for length calculation    ##
58 | ## list of labels to be ignored for        ##
59 | ## length calculation purpose              ##
60 | ##------------------------------------------##
61 | DELETE_LABEL_FOR_LENGTH -NONE-
62 | 
63 | ##------------------------------------------##
64 | ## Labels to be considered for misquote    ##
65 | ## (could be possesive or quote)           ##
66 | ##------------------------------------------##
67 | QUOTE_LABEL ``
68 | QUOTE_LABEL ''
69 | QUOTE_LABEL POS
70 | 
71 | ##------------------------------------------##
72 | ## These ones are less common, but         ##
73 | ## are on occasion output by parsers:      ##
74 | ##------------------------------------------##
75 | QUOTE_LABEL NN
76 | QUOTE_LABEL CD
77 | QUOTE_LABEL VBZ
78 | QUOTE_LABEL :
79 | 
80 | ##------------------------------------------##
81 | ## Equivalent labels, words                ##
82 | ## the pairs are considered equivalent     ##
83 | ## This is non-directional.                ##
84 | ##------------------------------------------##
85 | EQ_LABEL ADVP PRT
86 | 
87 | # EQ_WORD Example example
88 | 
--------------------------------------------------------------------------------
/src/ctb.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # Filename: ctb.py
3 | # Author: hantek, hankcs
4 | 
5 | import os
6 | import argparse
7 | from os import listdir
8 | from os.path import isfile, join, isdir
9 | import nltk
10 | 
11 | from utility import make_sure_path_exists, eprint, combine_files
12 | 
13 | 
14 | def convert(ctb_root, out_root):
15 |     ctb_root = join(ctb_root, 'bracketed')
16 |     fids = [f for f in listdir(ctb_root) if isfile(join(ctb_root, f)) and
17 |             (f.endswith('.nw') or
18 |              f.endswith('.mz') or
19 |              f.endswith('.wb'))]
20 |     make_sure_path_exists(out_root)
21 | 
22 |     for f in fids:
23 |         with open(join(ctb_root, f), 'r') as src, \
24 |                 open(join(out_root, f.split('.')[0] + '.fid'), 'w') as out:
25 |             # encoding='GB2312'
26 |             in_s_tag = False
27 |             try:
28 |                 for line in src:
29 |                     if line.startswith('<S'):
30 |                         in_s_tag = True
31 |                     elif line.startswith('</S'):
32 |                         in_s_tag = False
33 |                     elif line.startswith('<'):
34 |                         continue
35 |                     elif in_s_tag and len(line) > 1:
36 |                         out.write(line)
37 |             except:  # skip files that fail to parse or decode
38 |                 pass
39 | 
40 | 
41 | def combine_fids(fids, out_path):
42 |     print('Generating ' + out_path)
43 |     files = []
44 |     for fid in fids:
45 |         f = 'chtb_%04d.fid' % fid
46 |         if isfile(join(ctb_in_nltk, f)):
47 |             files.append(f)
48 |     with open(out_path, 'w') as out:
49 |         combine_files(files, out, ctb)
50 | 
51 | 
52 | if __name__ == '__main__':
53 |     parser = argparse.ArgumentParser(
54 |         description='Combine Chinese Treebank 5.1 fid files into train/dev/test set')
55 |     parser.add_argument("--ctb", required=True,
56 |                         help='The root path to Chinese Treebank 5.1')
57 |     parser.add_argument("--output", required=True,
58 |                         help='The folder where to store the output train.txt/dev.txt/test.txt')
59 | 
60 |     args = parser.parse_args()
61 | 
62 |     ctb_in_nltk = None
63 |     for root in nltk.data.path:
64 |         if isdir(root):
65 |             ctb_in_nltk = root
66 | 
67 |     if ctb_in_nltk is None:
68 |         eprint('You should run nltk.download(\'ptb\') to fetch some data first!')
69 |         exit(1)
70 | 
71 |     ctb_in_nltk = join(ctb_in_nltk, 'corpora')
72 |     ctb_in_nltk = join(ctb_in_nltk, 'ctb')
73 | 
74 |     print('Converting CTB: removing xml tags...')
75 |     convert(args.ctb, ctb_in_nltk)
76 |     print('Importing to nltk...\n')
77 |     from nltk.corpus import BracketParseCorpusReader, LazyCorpusLoader
78 | 
79 |     ctb = LazyCorpusLoader(
80 |         'ctb', BracketParseCorpusReader, r'chtb_.*\.*',
81 |         tagset='unknown')
82 | 
83 |     training = list(range(1, 270 + 1)) + list(range(440, 1151 + 1))
84 |     development = list(range(301, 325 + 1))
85 |     test = list(range(271, 300 + 1))
86 | 
87 |     root_path = args.output
88 |     if not os.path.isdir(root_path):
89 |         os.mkdir(root_path)
90 |     combine_fids(training, join(root_path, 'train.txt'))
91 |     combine_fids(development, join(root_path, 'dev.txt'))
92 |     combine_fids(test, join(root_path, 'test.txt'))
93 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Distance Parser
2 | Distance parser is a supervised constituency parser based on syntactic distance.
3 | This repo is a working implementation of the distance parser; it reproduces the results reported in the paper
4 | [Straight to the Tree: Constituency Parsing with Neural Syntactic Distance](https://arxiv.org/abs/1806.04168),
5 | published at ACL 2018. We provide models with proper configurations for the PTB and CTB datasets, as well as their preprocessing scripts.
6 | 
7 | ## Requirements
8 | - [PyTorch](https://pytorch.org/): we use PyTorch 0.4.0 with Python 3.6.
9 | - [Stanford POS tagger](https://nlp.stanford.edu/software/stanford-postagger-full-2018-02-27.zip): we use the full Stanford Tagger, version 3.9.1, build 2018-02-27.
10 | - [NLTK](http://www.nltk.org/): we use NLTK 3.2.5.
11 | - [EVALB](https://nlp.cs.nyu.edu/evalb/): we have integrated a compiled EVALB inside our repo. This compiled version is forked from the current latest version of EVALB, which can be accessed through [this link](https://nlp.cs.nyu.edu/evalb/EVALB.tgz).
12 | 
13 | ## Datasets and Preprocessing
14 | 
15 | ### Preprocessing PTB
16 | We use the same preprocessed PTB files as the [self-attentive parser](https://github.com/nikitakit/self-attentive-parser) repo. [GloVe embeddings](https://nlp.stanford.edu/projects/glove/) are optional; they are only needed for the ablation experiments.
17 | 
18 | To preprocess PTB, please follow the steps below:
19 | 
20 | 1. Download the three PTB data files from https://github.com/nikitakit/self-attentive-parser/tree/master/data, and put them in the `data/ptb` folder.
21 | 
22 | 2. Run the following command from the `src` folder to prepare the PTB data:
23 | ```
24 | python datacreate_ptb.py ../data/ptb /path/to/glove.840B.300d.txt
25 | ```
26 | 
27 | ### Preprocessing CTB
28 | We use the standard train/valid/test split specified in [Liu and Zhang, 2017](https://arxiv.org/pdf/1707.05000.pdf) for our CTB experiments.
29 | 
30 | To preprocess the CTB, please follow the steps below:
31 | 
32 | 1. Download and unzip the Chinese Treebank dataset from https://wakespace.lib.wfu.edu/handle/10339/39379
33 | 
34 | 2. If you don't have any corpus data in NLTK yet, download some to initialize your `nltk_data` folder, such as:
35 | ```
36 | python -c "import nltk; nltk.download('ptb')"
37 | ```
38 | 
39 | 3. Run the following command to link the dataset to NLTK, and generate the train/valid/test split in the repo:
40 | ```
41 | python ctb.py --ctb /path/to/your/ctb8.0/data --output data/ctb_liusplit
42 | ```
43 | 
44 | 4. Integrate the Stanford Tagger for data preprocessing. Download the Stanford tagger from https://nlp.stanford.edu/software/stanford-postagger-full-2018-02-27.zip and unzip it.
45 | 
46 | 5. Run the following command from the `src` folder to generate the preprocessed files:
47 | ```
48 | python datacreate_ctb.py ../data/ctb_liusplit /path/to/stanford/tagger/
49 | ```
50 | 
51 | ## Experiments
52 | To reproduce the PTB results in Table 1 of the paper, run
53 | ```
54 | cd src
55 | python dp.py --cuda --datapath ../data/ptb --savepath ../ptbresults --epc 200 --lr 0.001 --bthsz 20 --hidsz 1200 --embedsz 400 --window_size 2 --dpout 0.3 --dpoute 0.1 --dpoutr 0.2 --weight_decay 1e-6
56 | ```
57 | 
58 | To reproduce the CTB results in Table 2 of the paper, run
59 | ```
60 | cd src
61 | python dp.py --cuda --datapath ../data/ctb_liusplit --savepath ../ctbresults --epc 200 --lr 0.001 --bthsz 20 --hidsz 1200 --embedsz 400 --window_size 2 --dpout 0.4 --dpoute 0.1 --dpoutr 0.1 --weight_decay 1e-6
62 | ```
63 | 
64 | ## Pre-trained models
65 | We provide pre-trained models for convenience. The following commands download the two pre-trained models into the repo:
66 | ```
67 | mkdir results/
68 | cd results/
69 | wget http://lisaweb.iro.umontreal.ca/transfert/lisa/users/linzhou/distance_parser_pretrained_model/ctb.th
70 | wget http://lisaweb.iro.umontreal.ca/transfert/lisa/users/linzhou/distance_parser_pretrained_model/ptb.th
71 | ```
72 | To re-evaluate the pre-trained models, run:
73 | ```
74 | cd src/
75 | python demo.py --cuda --datapath ../data/ptb/ --filename ptb  # reproduces the 92.0 F1 score for PTB
76 | python demo.py --cuda --datapath ../data/ctb_liusplit/ --filename ctb  # reproduces the 86.5 F1 score for CTB
77 | ```
78 | Note that the model file has to be in the `results` folder in order for the `demo.py` script to load it automatically.
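
## Scoring with EVALB
Evaluation relies on the EVALB binary bundled in the `EVALB` folder. If you want to score a file of bracketed predictions against gold trees by hand, the standard EVALB invocation works with one of the bundled parameter files; the two tree file names below are illustrative:
```
cd EVALB
make  # only needed if the shipped binary does not run on your machine
./evalb -p new.prm gold_trees.txt predicted_trees.txt
```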
79 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 5 | 6 | from embed_regularize import embedded_dropout 7 | from weight_drop import WeightDrop 8 | 9 | 10 | class Shuffle(nn.Module): 11 | def __init__(self, permutation, contiguous=True): 12 | super(Shuffle, self).__init__() 13 | self.permutation = permutation 14 | self.contiguous = contiguous 15 | 16 | def forward(self, input): 17 | shuffled = input.permute(*self.permutation) 18 | if self.contiguous: 19 | return shuffled.contiguous() 20 | else: 21 | return shuffled 22 | 23 | 24 | class LayerNormalization(nn.Module): 25 | ''' Layer normalization module ''' 26 | 27 | def __init__(self, d_hid, eps=1e-3): 28 | super(LayerNormalization, self).__init__() 29 | 30 | self.eps = eps 31 | self.a_2 = nn.Parameter(torch.ones(d_hid), requires_grad=True) 32 | self.b_2 = nn.Parameter(torch.zeros(d_hid), requires_grad=True) 33 | 34 | def forward(self, z): 35 | if z.size(1) == 1: 36 | return z 37 | 38 | mu = torch.mean(z, keepdim=True, dim=-1) 39 | sigma = torch.std(z, keepdim=True, dim=-1) 40 | ln_out = (z - mu.expand_as(z)) / (sigma.expand_as(z) + self.eps) 41 | ln_out = ln_out * self.a_2.expand_as(ln_out) + self.b_2.expand_as(ln_out) 42 | 43 | return ln_out 44 | 45 | 46 | class distance_parser(nn.Module): 47 | def __init__(self, 48 | vocab_size, embed_size, hid_size, 49 | arc_size, stag_size, window_size, 50 | wordembed=None, dropout=0.2, dropoute=0.1, dropoutr=0.1): 51 | super(distance_parser, self).__init__() 52 | self.vocab_size = vocab_size 53 | self.embed_size = embed_size 54 | self.hid_size = hid_size 55 | self.arc_size = arc_size 56 | self.stag_size = stag_size 57 | self.window_size = window_size 58 | self.drop = nn.Dropout(dropout) 59 | self.dropoute = dropoute 60 | self.dropoutr = dropoutr 61 | self.encoder = nn.Embedding(vocab_size, embed_size) 62 | if wordembed is not None: 63 | self.encoder.weight.data = torch.FloatTensor(wordembed) 64 | 65 | self.tag_encoder = nn.Embedding(stag_size, embed_size) 66 | 67 | self.word_rnn = nn.LSTM(2 * embed_size, hid_size, num_layers=2, batch_first=True, dropout=dropout, 68 | bidirectional=True) 69 | self.word_rnn = WeightDrop(self.word_rnn, ['weight_hh_l0', 'weight_hh_l1'], dropout=dropoutr) 70 | 71 | self.conv1 = nn.Sequential(nn.Dropout(dropout), 72 | nn.Conv1d(hid_size * 2, 73 | hid_size, 74 | window_size), 75 | nn.ReLU()) 76 | 77 | self.arc_rnn = nn.LSTM(hid_size, hid_size, num_layers=2, batch_first=True, dropout=dropout, 78 | bidirectional=True) 79 | self.arc_rnn = WeightDrop(self.arc_rnn, ['weight_hh_l0', 'weight_hh_l1'], dropout=dropoutr) 80 | 81 | self.distance = nn.Sequential( 82 | nn.Dropout(dropout), 83 | nn.Linear(hid_size * 2, hid_size), 84 | nn.ReLU(), 85 | nn.Dropout(dropout), 86 | nn.Linear(hid_size, 1), 87 | ) 88 | 89 | self.terminal = nn.Sequential( 90 | nn.Dropout(dropout), 91 | nn.Linear(hid_size * 2, hid_size), 92 | nn.ReLU(), 93 | ) 94 | 95 | self.non_terminal = nn.Sequential( 96 | nn.Dropout(dropout), 97 | nn.Linear(hid_size * 2, hid_size), 98 | nn.ReLU(), 99 | ) 100 | 101 | self.arc = nn.Sequential( 102 | nn.Dropout(dropout), 103 | nn.Linear(hid_size, arc_size), 104 | ) 105 | 106 | def forward(self, words, stag, mask): 107 | """ 108 | tokens: Variable of LongTensor, shape (bsize, ntoken,) 109 | mock_emb: mock embedding for 
convolution overhead
110 |         """
111 | 
112 |         bsz, ntoken = words.size()
113 |         emb_words = embedded_dropout(self.encoder, words, dropout=self.dropoute if self.training else 0)
114 |         emb_words = self.drop(emb_words)
115 | 
116 |         emb_stags = embedded_dropout(self.tag_encoder, stag, dropout=self.dropoute if self.training else 0)
117 |         emb_stags = self.drop(emb_stags)
118 | 
119 | 
120 |         def run_rnn(input, rnn, lengths):
121 |             sorted_idx = numpy.argsort(lengths)[::-1].tolist()
122 |             rnn_input = pack_padded_sequence(input[sorted_idx], lengths[sorted_idx], batch_first=True)
123 |             rnn_out, _ = rnn(rnn_input)  # (bsize, ntoken, hidsize*2)
124 |             rnn_out, _ = pad_packed_sequence(rnn_out, batch_first=True)
125 |             rnn_out = rnn_out[numpy.argsort(sorted_idx).tolist()]
126 | 
127 |             return rnn_out
128 | 
129 |         sent_lengths = (mask.sum(dim=1)).data.cpu().numpy().astype('int')
130 |         dst_lengths = sent_lengths - 1
131 |         emb_plus_tag = torch.cat([emb_words, emb_stags], dim=-1)
132 | 
133 |         rnn1_out = run_rnn(emb_plus_tag, self.word_rnn, sent_lengths)
134 | 
135 |         terminal = self.terminal(rnn1_out.view(-1, self.hid_size*2))
136 |         tag = self.arc(terminal)  # (bsize, ndst, tagsize); tag labels live in the arc dictionary, so the arc classifier is shared
137 | 
138 |         conv_out = self.conv1(rnn1_out.permute(0, 2, 1)).permute(0, 2, 1)  # (bsize, ndst, hidsize)
139 |         rnn2_out = run_rnn(conv_out, self.arc_rnn, dst_lengths)
140 | 
141 |         non_terminal = self.non_terminal(rnn2_out.view(-1, self.hid_size*2))
142 |         distance = self.distance(rnn2_out.view(-1, self.hid_size*2)).squeeze(dim=-1)  # (bsize, ndst)
143 |         arc = self.arc(non_terminal)  # (bsize, ndst, arcsize)
144 |         return distance.view(bsz, ntoken - 1), arc.contiguous().view(-1, self.arc_size), tag.view(-1, self.arc_size)
145 | 
--------------------------------------------------------------------------------
/src/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | 
4 | import torch
5 | 
6 | from helpers import *
7 | 
8 | 
9 | class Dictionary(object):
10 |     def __init__(self):
11 |         self.word2idx = {'': 0}
12 |         self.idx2word = ['']
13 |         self.word2frq = {}
14 | 
15 |     def add_word(self, word):
16 |         if word not in self.word2idx:
17 |             self.idx2word.append(word)
18 |             self.word2idx[word] = len(self.idx2word) - 1
19 |         if word not in self.word2frq:
20 |             self.word2frq[word] = 1
21 |         else:
22 |             self.word2frq[word] += 1
23 |         return self.word2idx[word]
24 | 
25 |     def __len__(self):
26 |         return len(self.idx2word)
27 | 
28 |     def __getitem__(self, item):
29 |         if item in self.word2idx:
30 |             return self.word2idx[item]
31 |         else:
32 |             return self.word2idx['']
33 | 
34 |     def rebuild_by_freq(self, thd=3):
35 |         self.word2idx = {'': 0}
36 |         self.idx2word = ['']
37 | 
38 |         for k, v in self.word2frq.items():
39 |             if v >= thd and (not k in self.idx2word):
40 |                 self.idx2word.append(k)
41 |                 self.word2idx[k] = len(self.idx2word) - 1
42 | 
43 |         print('Number of words:', len(self.idx2word))
44 |         return len(self.idx2word)
45 | 
46 |     def class_weight(self):
47 |         frq = [self.word2frq[self.idx2word[i]] for i in range(len(self.idx2word))]
48 |         frq = numpy.array(frq).astype('float')
49 |         weight = numpy.sqrt(frq.max() / frq)
50 |         weight = numpy.clip(weight, a_min=0., a_max=5.)
51 | 
52 |         return weight
53 | 
54 | 
55 | class PTBLoader(object):
56 |     '''Data path is assumed to be a directory with
57 |     pkl files and a corpora subdirectory.
58 |     '''
59 |     def __init__(self, data_path=None, use_glove=False):
60 |         assert data_path is not None
61 |         # make path available for nltk
62 |         nltk.data.path.append(data_path)
63 |         dict_filepath = os.path.join(data_path, 'dict.pkl')
64 |         data_filepath = os.path.join(data_path, 'parsed.pkl')
65 | 
66 |         print("loading dictionary ...")
67 |         self.dictionary = pickle.load(open(dict_filepath, "rb"))
68 |         if use_glove:
69 |             glove_filepath = os.path.join(data_path, 'ptb_glove.npy')
70 |             print("loading preprocessed glove file ...")
71 |             f_we = open(glove_filepath, 'rb')
72 |             self.wordembed_matrix = numpy.load(f_we)
73 |             f_we.close()
74 |         else:
75 |             self.wordembed_matrix = None
76 | 
77 |         # build tree and distance
78 |         print("loading tree and distance ...")
79 |         file_data = open(data_filepath, 'rb')
80 |         self.train, self.arc_dictionary, self.stag_dictionary = pickle.load(file_data)
81 |         self.valid = pickle.load(file_data)
82 |         self.test = pickle.load(file_data)
83 |         file_data.close()
84 | 
85 |     def batchify(self, dataname, batch_size, sort=False):
86 |         sents, trees = None, None
87 |         if dataname == 'train':
88 |             idxs, tags, stags, arcs, distances, sents, trees = self.train
89 |         elif dataname == 'valid':
90 |             idxs, tags, stags, arcs, distances, _, _ = self.valid
91 |         elif dataname == 'test':
92 |             idxs, tags, stags, arcs, distances, _, _ = self.test
93 |         else:
94 |             raise ValueError('need a correct dataname')
95 | 
96 |         assert len(idxs) == len(distances)
97 |         assert len(idxs) == len(tags)
98 | 
99 |         bachified_idxs, bachified_tags, bachified_stags, \
100 |             bachified_arcs, bachified_dsts, \
101 |             = [], [], [], [], []
102 |         bachified_sents, bachified_trees = [], []
103 |         for i in range(0, len(idxs), batch_size):
104 |             if i + batch_size >= len(idxs): continue  # drop the last incomplete batch
105 | 
106 |             if sents is not None:
107 |                 bachified_sents.append(sents[i: i + batch_size])
108 |                 bachified_trees.append(trees[i: i + batch_size])
109 | 
110 |             extracted_idxs = idxs[i: i + batch_size]
111 |             extracted_tags = tags[i: i + batch_size]
112 |             extracted_stags = stags[i: i + batch_size]
113 | 
114 |             extracted_arcs = arcs[i: i + batch_size]
115 |             extracted_dsts = distances[i: i + batch_size]
116 | 
117 |             longest_idx = max([len(i) for i in extracted_idxs])
118 |             longest_arc = longest_idx - 1
119 | 
120 |             minibatch_idxs, minibatch_tags, minibatch_stags, \
121 |                 minibatch_arcs, minibatch_dsts, \
122 |                 = [], [], [], [], []
123 |             for idx, tag, stag, \
124 |                 arc, dst \
125 |                 in zip(extracted_idxs, extracted_tags, extracted_stags,
126 |                        extracted_arcs, extracted_dsts):
127 |                 padded_idx = idx + [-1] * (longest_idx - len(idx))
128 |                 padded_tag = tag + [0] * (longest_idx - len(tag))
129 |                 padded_stag = stag + [0] * (longest_idx - len(stag))
130 | 
131 |                 padded_arc = arc + [0] * (longest_arc - len(arc))
132 |                 padded_dst = dst + [0] * (longest_arc - len(dst))
133 | 
134 |                 minibatch_idxs.append(padded_idx)
135 |                 minibatch_tags.append(padded_tag)
136 |                 minibatch_stags.append(padded_stag)
137 | 
138 |                 minibatch_arcs.append(padded_arc)
139 |                 minibatch_dsts.append(padded_dst)
140 | 
141 |             minibatch_idxs = torch.LongTensor(minibatch_idxs)
142 |             minibatch_tags = torch.LongTensor(minibatch_tags)
143 |             minibatch_stags = torch.LongTensor(minibatch_stags)
144 | 
145 |             minibatch_arcs = torch.LongTensor(minibatch_arcs)
146 |             minibatch_dsts = torch.FloatTensor(minibatch_dsts)
147 | 
148 |             bachified_idxs.append(minibatch_idxs)
149 |             bachified_tags.append(minibatch_tags)
150 |             bachified_stags.append(minibatch_stags)
151 | 
152 |             bachified_arcs.append(minibatch_arcs)
153 | 
bachified_dsts.append(minibatch_dsts) 154 | 155 | if sents is not None: 156 | return bachified_idxs, bachified_tags, bachified_stags, \ 157 | bachified_arcs, bachified_dsts, \ 158 | bachified_sents, bachified_trees 159 | return bachified_idxs, bachified_tags, bachified_stags, \ 160 | bachified_arcs, bachified_dsts 161 | 162 | -------------------------------------------------------------------------------- /src/datacreate_ctb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from nltk.tree import Tree 5 | from nltk.tag import StanfordPOSTagger 6 | 7 | from helpers import * 8 | 9 | 10 | 11 | def load_trees(path, strip_top=True, strip_spmrl_features=True): 12 | trees = [] 13 | with open(path) as infile: 14 | for line in infile: 15 | trees.append(Tree.fromstring(line)) 16 | 17 | if strip_top: 18 | for i, tree in enumerate(trees): 19 | if tree.label() in ("TOP", "ROOT"): 20 | assert len(tree) == 1 21 | trees[i] = tree[0] 22 | return trees 23 | 24 | 25 | class CTBCreator(object): 26 | '''Data path is assumed to be a directory with 27 | pkl files and a corpora subdirectory. 28 | ''' 29 | def __init__(self, 30 | wordembed_dim=300, 31 | embeddingstd=0.1, 32 | data_path=None, 33 | tagger_path=None): 34 | assert data_path is not None 35 | assert tagger_path is not None 36 | dict_filepath = os.path.join(data_path, 'dict.pkl') 37 | data_filepath = os.path.join(data_path, 'parsed.pkl') 38 | train_filepath = os.path.join(data_path, "train.txt") 39 | valid_filepath = os.path.join(data_path, "dev.txt") 40 | test_filepath = os.path.join(data_path, "test.txt") 41 | 42 | self.st = StanfordPOSTagger(os.path.join(tagger_path, 'models/chinese-distsim.tagger'), 43 | os.path.join(tagger_path, 'stanford-postagger.jar')) 44 | 45 | print("building dictionary ...") 46 | f_dict = open(dict_filepath, 'wb') 47 | self.dictionary = Dictionary() 48 | 49 | print("loading trees from {}".format(train_filepath)) 50 | train_trees = load_trees(train_filepath) 51 | print("loading trees from {}".format(valid_filepath)) 52 | valid_trees = load_trees(valid_filepath) 53 | print("loading trees from {}".format(test_filepath)) 54 | test_trees = load_trees(test_filepath) 55 | 56 | self.add_words(train_trees) 57 | self.dictionary.rebuild_by_freq() 58 | self.arc_dictionary = Dictionary() 59 | self.stag_dictionary = Dictionary() 60 | self.train = self.preprocess(train_trees, is_train=True) 61 | self.valid = self.preprocess(valid_trees, is_train=False) 62 | self.test = self.preprocess(test_trees, is_train=False) 63 | with open(dict_filepath, "wb") as file_dict: 64 | pickle.dump(self.dictionary, file_dict) 65 | with open(data_filepath, "wb") as file_data: 66 | pickle.dump((self.train, self.arc_dictionary, 67 | self.stag_dictionary), file_data) 68 | pickle.dump(self.valid, file_data) 69 | pickle.dump(self.test, file_data) 70 | 71 | print(len(self.arc_dictionary.idx2word)) 72 | print(self.arc_dictionary.idx2word) 73 | 74 | def add_words(self, trees): 75 | words, tags = [], [] 76 | for tree in trees: 77 | tree = process_NONE(tree) 78 | words, tags = zip(*tree.pos()) 79 | words = [''] + list(words) + [''] 80 | for w in words: 81 | self.dictionary.add_word(w) 82 | 83 | def preprocess(self, parse_trees, is_train=False): 84 | sens_idx = [] 85 | sens_tag = [] 86 | sens_stag = [] 87 | sens_arc = [] 88 | distances = [] 89 | sens = [] 90 | trees = [] 91 | 92 | print('\nConverting trees ...') 93 | for i, tree in enumerate(parse_trees): 94 | tree = process_NONE(tree) 95 | if i % 10 
== 0: 96 | print("Done %d/%d\r" % (i, len(parse_trees)), end='') 97 | word_lexs, _ = zip(*tree.pos()) 98 | idx = [] 99 | for word in ([''] + list(word_lexs) + ['']): 100 | idx.append(self.dictionary[word]) 101 | 102 | listerized_tree, arcs, tags = tree2list(tree) 103 | tags = [''] + tags + [''] 104 | arcs = [''] + arcs + [''] 105 | 106 | if type(listerized_tree) is str: 107 | listerized_tree = [listerized_tree] 108 | distances_sent, _ = distance(listerized_tree) 109 | distances_sent = [0] + distances_sent + [0] 110 | 111 | idx_arcs = [] 112 | for arc in arcs: 113 | arc = precess_arc(arc) 114 | arc_id = self.arc_dictionary.add_word(arc) if is_train else self.arc_dictionary[arc] 115 | idx_arcs.append(arc_id) 116 | 117 | # the "tags" are the collapsed unary chains, i.e. FRAG+DT 118 | # at evaluation, we swap the word tag "DT" with the true tag in "stags" (see after) 119 | idx_tags = [] 120 | for tag in tags: 121 | tag = precess_arc(tag) 122 | tag_id = self.arc_dictionary.add_word(tag) if is_train else self.arc_dictionary[tag] 123 | idx_tags.append(tag_id) 124 | 125 | assert len(distances_sent) == len(idx) - 1 126 | assert len(arcs) == len(idx) - 1 127 | assert len(idx) == len(word_lexs) + 2 128 | 129 | sens.append(word_lexs) 130 | trees.append(tree) 131 | sens_idx.append(idx) 132 | sens_tag.append(idx_tags) 133 | sens_arc.append(idx_arcs) 134 | distances.append(distances_sent) 135 | 136 | print('\nLabelling POS tags ...') 137 | st_outputs = self.st.tag_sents(sens) 138 | for i, word_tags in enumerate(st_outputs): 139 | if i % 10 == 0: 140 | print("Done %d/%d\r" % (i, len(parse_trees)), end='') 141 | word_tags = [t[1].split('#')[1] for t in word_tags] 142 | stags = [''] + list(word_tags) + [''] 143 | 144 | # the "stags" are the original word tags included in the data files 145 | # we keep track of them so that, during evaluation, we can swap them with the original ones. 146 | idx_stags = [] 147 | for stag in stags: 148 | stag_id = self.stag_dictionary.add_word(stag) if is_train else self.stag_dictionary[stag] 149 | idx_stags.append(stag_id) 150 | 151 | sens_stag.append(idx_stags) 152 | 153 | return sens_idx, sens_tag, sens_stag, \ 154 | sens_arc, distances, sens, trees 155 | 156 | 157 | if __name__ == '__main__': 158 | import sys 159 | CTBCreator(data_path=sys.argv[1], tagger_path=sys.argv[2]) 160 | -------------------------------------------------------------------------------- /src/datacreate_ptb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from nltk.tree import Tree 5 | 6 | from helpers import * 7 | 8 | 9 | def load_trees(path, strip_top=True, strip_spmrl_features=True): 10 | trees = [] 11 | with open(path) as infile: 12 | for line in infile: 13 | trees.append(Tree.fromstring(line)) 14 | 15 | if strip_top: 16 | for i, tree in enumerate(trees): 17 | if tree.label() in ("TOP", "ROOT"): 18 | assert len(tree) == 1 19 | trees[i] = tree[0] 20 | return trees 21 | 22 | 23 | class PTBCreator(object): 24 | '''Data path is assumed to be a directory with 25 | pkl files and a corpora subdirectory. 
26 | ''' 27 | def __init__(self, 28 | wordembed_dim=300, 29 | embeddingstd=0.1, 30 | data_path=None, 31 | glove_path=None): 32 | assert data_path is not None 33 | dict_filepath = os.path.join(data_path, 'dict.pkl') 34 | data_filepath = os.path.join(data_path, 'parsed.pkl') 35 | train_filepath = os.path.join(data_path, "02-21.10way.clean") 36 | valid_filepath = os.path.join(data_path, "22.auto.clean") 37 | test_filepath = os.path.join(data_path, "23.auto.clean") 38 | embed_filepath = os.path.join(data_path, "ptb_glove.npy") #'../data/ptb/ptb_glove.npy' 39 | 40 | print("building dictionary ...") 41 | f_dict = open(dict_filepath, 'wb') 42 | self.dictionary = Dictionary() 43 | 44 | print("loading trees from {}".format(train_filepath)) 45 | train_trees = load_trees(train_filepath) 46 | print("loading trees from {}".format(valid_filepath)) 47 | valid_trees = load_trees(valid_filepath) 48 | print("loading trees from {}".format(test_filepath)) 49 | test_trees = load_trees(test_filepath) 50 | 51 | self.add_words(train_trees) 52 | self.dictionary.rebuild_by_freq() 53 | self.arc_dictionary = Dictionary() 54 | self.stag_dictionary = Dictionary() 55 | self.train = self.preprocess(train_trees, is_train=True) 56 | self.valid = self.preprocess(valid_trees, is_train=False) 57 | self.test = self.preprocess(test_trees, is_train=False) 58 | with open(dict_filepath, "wb") as file_dict: 59 | pickle.dump(self.dictionary, file_dict) 60 | with open(data_filepath, "wb") as file_data: 61 | pickle.dump((self.train, self.arc_dictionary, 62 | self.stag_dictionary), file_data) 63 | pickle.dump(self.valid, file_data) 64 | pickle.dump(self.test, file_data) 65 | 66 | 67 | if glove_path is not None: 68 | maxvocabsize = len(self.dictionary) 69 | print("loading raw GloVe file ...") 70 | wv = {} 71 | vec = open(glove_path, 'r') 72 | for line in vec.readlines(): 73 | line = line.split(' ') 74 | wv[line[0]] = numpy.asarray( 75 | [float(x) for x in line[1:]]).astype('float32') 76 | vec.close() 77 | 78 | self.wordembed_matrix = embeddingstd * \ 79 | numpy.random.randn(maxvocabsize, wordembed_dim).astype('float32') 80 | key_error = 0 81 | for key, value in enumerate(self.dictionary.idx2word): 82 | try: 83 | self.wordembed_matrix[key] = wv[value] 84 | except KeyError: 85 | key_error += 1 86 | print("Total vocab size: %d, tokens not found in glove: %d" % ( 87 | maxvocabsize, key_error)) 88 | del wv 89 | 90 | print("dumping augmented word embedding matrix ...") 91 | f_we = open(embed_filepath, 'wb') 92 | numpy.save(f_we, self.wordembed_matrix) 93 | f_we.close() 94 | 95 | print(len(self.arc_dictionary.idx2word)) 96 | print(self.arc_dictionary.idx2word) 97 | 98 | def add_words(self, trees): 99 | words, tags = [], [] 100 | for tree in trees: 101 | words, tags = zip(*tree.pos()) 102 | words = [''] + list(words) + [''] 103 | for w in words: 104 | self.dictionary.add_word(w) 105 | 106 | def preprocess(self, parse_trees, is_train=False): 107 | sens_idx = [] 108 | sens_tag = [] 109 | sens_stag = [] 110 | sens_arc = [] 111 | distances = [] 112 | sens = [] 113 | trees = [] 114 | 115 | print('\nConverting trees ...') 116 | for i, tree in enumerate(parse_trees): 117 | if i % 10 == 0: 118 | print("Done %d/%d\r" % (i, len(parse_trees)), end='') 119 | word_lexs, word_tags = zip(*tree.pos()) 120 | idx = [] 121 | for word in ([''] + list(word_lexs) + ['']): 122 | idx.append(self.dictionary[word]) 123 | 124 | listerized_tree, arcs, tags = tree2list(tree) 125 | stags = [''] + list(word_tags) + [''] 126 | tags = [''] + tags + [''] 127 | arcs = [''] + 
arcs + [''] 128 | 129 | if type(listerized_tree) is str: 130 | listerized_tree = [listerized_tree] 131 | distances_sent, _ = distance(listerized_tree) 132 | distances_sent = [0] + distances_sent + [0] 133 | 134 | idx_arcs = [] 135 | for arc in arcs: 136 | arc = precess_arc(arc) 137 | arc_id = self.arc_dictionary.add_word(arc) if is_train else self.arc_dictionary[arc] 138 | idx_arcs.append(arc_id) 139 | 140 | # the "stags" are the original word tags included in the data files 141 | # we keep track of them so that, during evaluation, we can swap them with the original ones. 142 | idx_stags = [] 143 | for stag in stags: 144 | stag_id = self.stag_dictionary.add_word(stag) if is_train else self.stag_dictionary[stag] 145 | idx_stags.append(stag_id) 146 | 147 | # the "tags" are the collapsed unary chains, i.e. FRAG+DT 148 | # at evaluation, we swap the word tag "DT" with the true tag in "stags" (see after) 149 | idx_tags = [] 150 | for tag in tags: 151 | tag = precess_arc(tag) 152 | tag_id = self.arc_dictionary.add_word(tag) if is_train else self.arc_dictionary[tag] 153 | idx_tags.append(tag_id) 154 | 155 | assert len(distances_sent) == len(idx) - 1 156 | assert len(arcs) == len(idx) - 1 157 | assert len(idx) == len(word_lexs) + 2 158 | assert len(stags) == len(tags) 159 | 160 | sens.append(word_lexs) 161 | trees.append(tree) 162 | sens_idx.append(idx) 163 | sens_tag.append(idx_tags) 164 | sens_arc.append(idx_arcs) 165 | sens_stag.append(idx_stags) 166 | distances.append(distances_sent) 167 | 168 | return sens_idx, sens_tag, sens_stag, \ 169 | sens_arc, distances, sens, trees 170 | 171 | 172 | if __name__ == '__main__': 173 | import sys 174 | PTBCreator(data_path=sys.argv[1], glove_path=sys.argv[2] if len(sys.argv) > 2 else None) 175 | -------------------------------------------------------------------------------- /src/helpers.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | import nltk 4 | import numpy 5 | 6 | 7 | word_tags = ['CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 8 | 'JJS', 'LS', 'MD', 'NN', 'NNS', 'NNP', 'NNPS', 'PDT', 9 | 'POS', 'PRP', 'PRP$', 'RB', 'RBR', 'RBS', 'RP', 10 | 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 11 | 'VBZ', 'WDT', 'WP', 'WP$', 'WRB'] 12 | currency_tags_words = ['#', '$', 'C$', 'A$'] 13 | ellipsis = ['*', '*?*', '0', '*T*', '*ICH*', '*U*', '*RNR*', '*EXP*', 14 | '*PPA*', '*NOT*'] 15 | punctuation_tags = ['.', ',', ':', '-LRB-', '-RRB-', '\'\'', '``'] 16 | punctuation_words = ['.', ',', ':', '-LRB-', '-RRB-', '\'\'', '``', 17 | '--', ';', '-', '?', '!', '...', '-LCB-', 18 | '-RCB-'] 19 | delated_tags = ['TOP', '-NONE-', ',', ':', '``', '\'\''] 20 | 21 | 22 | def precess_arc(label): 23 | labels = label.split('+') 24 | new_arc = [] 25 | for l in labels: 26 | if l == 'ADVP': 27 | l = 'PRT' 28 | # if len(new_arc) > 0 and l == new_arc[-1]: 29 | # continue 30 | new_arc.append(l) 31 | label = '+'.join(new_arc) 32 | return label 33 | 34 | 35 | def process_NONE(tree): 36 | if isinstance(tree, nltk.Tree): 37 | label = tree.label() 38 | if label == '-NONE-': 39 | return None 40 | else: 41 | tr = [] 42 | for node in tree: 43 | new_node = process_NONE(node) 44 | if new_node is not None: 45 | tr.append(new_node) 46 | if tr == []: 47 | return None 48 | else: 49 | return nltk.Tree(label, tr) 50 | else: 51 | return tree 52 | 53 | 54 | class Dictionary(object): 55 | def __init__(self): 56 | self.word2idx = {'': 0} 57 | self.idx2word = [''] 58 | self.word2frq = {} 59 | 60 | def add_word(self, word): 61 
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        if word not in self.word2frq:
            self.word2frq[word] = 1
        else:
            self.word2frq[word] += 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)

    def __getitem__(self, item):
        if item in self.word2idx:
            return self.word2idx[item]
        else:
            return self.word2idx['']

    def rebuild_by_freq(self, thd=3):
        self.word2idx = {'': 0}
        self.idx2word = ['']

        for k, v in self.word2frq.items():
            if v >= thd and k not in self.word2idx:
                self.idx2word.append(k)
                self.word2idx[k] = len(self.idx2word) - 1

        print('Number of words:', len(self.idx2word))
        return len(self.idx2word)

    def class_weight(self):
        frq = [self.word2frq[self.idx2word[i]] for i in range(len(self.idx2word))]
        frq = numpy.array(frq).astype('float')
        weight = numpy.sqrt(frq.max() / frq)
        weight = numpy.clip(weight, a_min=0., a_max=5.)
        return weight


class FScore(object):
    def __init__(self, recall, precision, fscore):
        self.recall = recall
        self.precision = precision
        self.fscore = fscore

    def __str__(self):
        return "(Recall={:.2f}, Precision={:.2f}, FScore={:.2f})".format(
            self.recall, self.precision, self.fscore)


def build_nltktree(depth, arc, tag, sen, arcdict, tagdict, stagdict, stags=None):
    """stags are the Stanford-predicted POS tags present in the train/valid/test files.
    """
    assert len(sen) > 0
    assert len(depth) == len(sen) - 1, ("%s_%s" % (len(depth), len(sen)))
    if stags:
        assert len(stags) == len(tag)

    if len(sen) == 1:
        tag_list = str(tagdict[tag[0]]).split('+')
        tag_list.reverse()
        # if stags, put the real Stanford POS tag for the word and leave the
        # unary chain on top.
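        # e.g. a collapsed tag 'FRAG' with stag 'DT' rebuilds (FRAG (DT word))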
        if stags is not None:
            assert len(stags) > 0
            tag_list.insert(0, str(stagdict[stags[0]]))
        word = str(sen[0])
        for t in tag_list:
            word = nltk.Tree(t, [word])
        assert isinstance(word, nltk.Tree)
        return word
    else:
        idx = numpy.argmax(depth)
        node0 = build_nltktree(
            depth[:idx], arc[:idx], tag[:idx + 1], sen[:idx + 1],
            arcdict, tagdict, stagdict, stags[:idx + 1] if stags else None)
        node1 = build_nltktree(
            depth[idx + 1:], arc[idx + 1:], tag[idx + 1:], sen[idx + 1:],
            arcdict, tagdict, stagdict, stags[idx + 1:] if stags else None)

        # an empty label marks a placeholder node with no real constituent;
        # its children are spliced directly into the parent
        if node0.label() != '' and node1.label() != '':
            tr = [node0, node1]
        elif node0.label() == '' and node1.label() != '':
            tr = [c for c in node0] + [node1]
        elif node0.label() != '' and node1.label() == '':
            tr = [node0] + [c for c in node1]
        elif node0.label() == '' and node1.label() == '':
            tr = [c for c in node0] + [c for c in node1]

        arc_list = str(arcdict[arc[idx]]).split('+')
        arc_list.reverse()
        for a in arc_list:
            if isinstance(tr, nltk.Tree):
                tr = [tr]
            tr = nltk.Tree(a, tr)

        return tr


def MRG(tr):
    if isinstance(tr, str):
        return '( %s )' % tr
        # return tr + ' '
    else:
        s = '('
        for subtr in tr:
            s += MRG(subtr) + ' '
        s += ')'
        return s


def get_brackets(tree, start_idx=0, root=False):
    assert isinstance(tree, nltk.Tree)
    label = tree.label()
    label = label.replace('ADVP', 'PRT')

    brackets = set()
    if isinstance(tree[0], nltk.Tree):
        end_idx = start_idx
        for node in tree:
            node_brac, next_idx = get_brackets(node, end_idx)
            brackets.update(node_brac)
            end_idx = next_idx
        if not root:
            brackets.add((start_idx, end_idx, label))
    else:
        end_idx = start_idx + 1

    return brackets, end_idx


def normalize(x):
    return x / (sum(x) + 1e-8)


def tree2list(tree, parent_arc=None):
    # avoid a mutable default argument: this function mutates parent_arc below
    if parent_arc is None:
        parent_arc = []
    if isinstance(tree, nltk.Tree):
        label = tree.label()
        if isinstance(tree[0], nltk.Tree):
            label = re.split('-|=', tree.label())[0]
        root_arc_list = parent_arc + [label]
        root_arc = '+'.join(root_arc_list)
        if len(tree) == 1:
            root, arc, tag = tree2list(tree[0], parent_arc=root_arc_list)
        elif len(tree) == 2:
            c0, arc0, tag0 = tree2list(tree[0])
            c1, arc1, tag1 = tree2list(tree[1])
            root = [c0, c1]
            arc = arc0 + [root_arc] + arc1
            tag = tag0 + tag1
        else:
            c0, arc0, tag0 = tree2list(tree[0])
            c1, arc1, tag1 = tree2list(nltk.Tree('', tree[1:]))
            # the original guard here was `if bin == 0:`, which compares the
            # Python builtin `bin` against 0 and is therefore always False,
            # so the node is always kept binary
            root = [c0, c1]
            arc = arc0 + [root_arc] + arc1
            tag = tag0 + tag1
        return root, arc, tag
    else:
        if len(parent_arc) == 1:
            parent_arc.insert(0, '')
            # parent_arc[-1] = ''
        del parent_arc[-1]
        return str(tree), [], ['+'.join(parent_arc)]


def distance(root):
    if isinstance(root, list):
        dist_list = []
        depth_list = []
        for child in root:
            dist, depth = distance(child)
            dist_list.append(dist)
            depth_list.append(depth)

        max_depth = max(depth_list)

        out = dist_list[0]
        for dist in dist_list[1:]:
            out.append(max_depth)
            out.extend(dist)
        return out, max_depth + 1
    else:
        return [], 1
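
# A minimal usage sketch of how the helpers above compose; the toy tree is a
# hypothetical example, not project data. tree2list() flattens an nltk.Tree
# into nested lists plus collapsed arc/tag label chains, distance() reads the
# syntactic distances off the nested lists, and get_brackets() extracts the
# labeled spans used for evaluation.
if __name__ == '__main__':
    toy = nltk.Tree.fromstring('(S (NP (DT the) (NN dog)) (VP (VBZ barks)))')
    nested, arcs, tags = tree2list(toy)
    dists, _ = distance(nested)
    print(dists)           # [1, 2] -- one distance per adjacent word pair
    spans, _ = get_brackets(toy)
    print(sorted(spans))   # [(0, 2, 'NP'), (0, 3, 'S'), (2, 3, 'VP')]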
--------------------------------------------------------------------------------
/src/dp.py:
--------------------------------------------------------------------------------
import argparse
import math
import os
import random

import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

from dataloader import PTBLoader
from helpers import *
from loss import *
from model import distance_parser


def get_args():
    parser = argparse.ArgumentParser(
        description='Syntactic distance based neural parser')
    parser.add_argument('--epc', type=int, default=100)
    parser.add_argument('--lr', type=float, default=.001)
    parser.add_argument('--bthsz', type=int, default=20)
    parser.add_argument('--hidsz', type=int, default=1200)
    parser.add_argument('--embedsz', type=int, default=400)
    parser.add_argument('--window_size', type=int, default=2)
    parser.add_argument('--dpout', type=float, default=0.3)
    parser.add_argument('--dpoute', type=float, default=0.1)
    parser.add_argument('--dpoutr', type=float, default=0.2)
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('--weight_decay', type=float, default=1e-6)
    parser.add_argument('--use_glove', action='store_true')
    parser.add_argument('--logfre', type=int, default=200)
    parser.add_argument('--devfre', type=int, default=-1)
    parser.add_argument('--cuda', action='store_true', dest='cuda')
    parser.add_argument('--datapath', type=str, default='../data/ptb')
    parser.add_argument('--savepath', type=str, default='../results')
    parser.add_argument('--filename', type=str, default=None)

    args = parser.parse_args()
    # set seed and return args
    random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    if args.cuda and torch.cuda.is_available():
        torch.cuda.random.manual_seed(args.seed)
    return args


def evaluate(model, data, mode='valid'):
    import tempfile
    model.eval()
    if mode == 'valid':
        idxs, tags, stags, arcs, dsts = data.batchify(mode, 1)
        _, _, _, _, _, sents, trees = data.valid
    elif mode == 'test':
        idxs, tags, stags, arcs, dsts = data.batchify(mode, 1)
        _, _, _, _, _, sents, trees = data.test

    temp_path = tempfile.TemporaryDirectory(prefix="evalb-")
    temp_file_path = os.path.join(temp_path.name, "pred_trees.txt")
    temp_targ_path = os.path.join(temp_path.name, "true_trees.txt")
    temp_eval_path = os.path.join(temp_path.name, "evals.txt")

    print("Temp: {}, {}".format(temp_file_path, temp_targ_path))
    temp_tree_file = open(temp_file_path, "w")
    temp_targ_file = open(temp_targ_path, "w")

    set_loss = 0.0
    set_counter = 0
    set_arc_prec = 0.0
    arc_counter = 0
    set_tag_prec = 0.0
    tag_counter = 0
    for _, (idx, tag, stag, arc, dst, sent, tree) in enumerate(
            zip(idxs, tags, stags, arcs, dsts, sents, trees)):

        if args.cuda:
            idx = idx.cuda()
            tag = tag.cuda()
            stag = stag.cuda()
            arc = arc.cuda()
            dst = dst.cuda()

        mask = (idx >= 0).float()
        idx = idx * mask.long()
        dstmask = (dst > 0).float()
        pred_dst, pred_arc, pred_tag = model(idx, stag, mask)

        loss = rankloss(pred_dst, dst, dstmask)
        set_loss += loss.item()
        set_counter += 1

        _, pred_arc_idx = torch.max(pred_arc, dim=-1)
        arc_prec = ((arc == pred_arc_idx).float() * dstmask).sum()
        set_arc_prec += arc_prec.item()
        arc_counter += dstmask.sum().item()

        _, pred_tag_idx = torch.max(pred_tag, dim=-1)
        pred_tag_idx[0], pred_tag_idx[-1] = -1, -1
        tag_prec = (tag == pred_tag_idx).float().sum()
        set_tag_prec += tag_prec.item()
        tag_counter += (tag != 0).float().sum().item()

        pred_tree = build_nltktree(
            pred_dst.data.squeeze().cpu().numpy().tolist()[1:-1],
            pred_arc_idx.data.squeeze().cpu().numpy().tolist()[1:-1],
            pred_tag_idx.data.squeeze().cpu().numpy().tolist()[1:-1],
            sent,
            ptb_parsed.arc_dictionary.idx2word,
            ptb_parsed.arc_dictionary.idx2word,
            ptb_parsed.stag_dictionary.idx2word,
            stags=stag.data.squeeze().cpu().numpy().tolist()[1:-1]
        )

        def process_str_tree(str_tree):
            # collapse runs of spaces and newlines into single spaces
            return re.sub('[ \n]+', ' ', str_tree)

        temp_tree_file.write(process_str_tree(str(pred_tree)) + '\n')
        temp_targ_file.write(process_str_tree(str(tree)) + '\n')

    # execute the evalb command:
    temp_tree_file.close()
    temp_targ_file.close()

    evalb_dir = os.path.join(os.getcwd(), "../EVALB")
    evalb_param_path = os.path.join(evalb_dir, "COLLINS.prm")
    evalb_program_path = os.path.join(evalb_dir, "evalb")
    command = "{} -p {} {} {} > {}".format(
        evalb_program_path,
        evalb_param_path,
        temp_targ_path,
        temp_file_path,
        temp_eval_path)

    import subprocess
    subprocess.run(command, shell=True)
    fscore = FScore(math.nan, math.nan, math.nan)

    with open(temp_eval_path) as infile:
        for line in infile:
            match = re.match(r"Bracketing Recall\s+=\s+(\d+\.\d+)", line)
            if match:
                fscore.recall = float(match.group(1))
            match = re.match(r"Bracketing Precision\s+=\s+(\d+\.\d+)", line)
            if match:
                fscore.precision = float(match.group(1))
            match = re.match(r"Bracketing FMeasure\s+=\s+(\d+\.\d+)", line)
            if match:
                fscore.fscore = float(match.group(1))
                break

    temp_path.cleanup()

    set_loss /= set_counter
    set_arc_prec /= arc_counter
    set_tag_prec /= tag_counter

    model.train()

    return (set_loss, set_arc_prec, set_tag_prec,
            fscore.precision, fscore.recall, fscore.fscore)


args = get_args()

if args.filename is None:
    filename = sorted(str(args)[10:-1].split(', '))
    filename = [i for i in filename if ('dir' not in i) and
                ('tblog' not in i) and
                ('fre' not in i) and
                ('cuda' not in i) and
                ('nlookback' not in i)]
    filename = __file__.split('.')[0].split('/')[-1] + '_' + \
        '_'.join(filename).replace('=', '') \
        .replace('/', '') \
        .replace('\'', '') \
        .replace('..', '') \
        .replace('\"', '')
else:
    filename = args.filename

if not os.path.isdir(args.savepath):
    os.mkdir(args.savepath)
parameter_filepath = os.path.join(args.savepath, filename + '.th')
print('model path:', parameter_filepath)

print(args)
print("loading data ...")
ptb_parsed = PTBLoader(data_path=args.datapath, use_glove=args.use_glove)

wordembed = ptb_parsed.wordembed_matrix
args.vocab_size = len(ptb_parsed.dictionary)

train_log_template = 'epoch {:<3d} batch {:<4d} loss {:<.6f} rank {:<.6f} arc {:<.6f} tag {:<.6f}'
valid_log_template = \
    '*** epoch {:<3d} \tloss \tarc prec \ttag prec \tprecision\trecall \tlf1 \n' \
    '{:10}DEV\t{:<.6f}\t{:<.6f}\t{:<.6f}\t{:<.6f}\t{:<.6f}\t{:<.6f}\n' \
    '{:10}TEST\t{:<.6f}\t{:<.6f}\t{:<.6f}\t{:<.6f}\t{:<.6f}\t{:<.6f}'

if __name__ == '__main__':
    print("building model...")
    model = distance_parser(vocab_size=args.vocab_size,
                            embed_size=args.embedsz,
                            hid_size=args.hidsz,
                            arc_size=len(ptb_parsed.arc_dictionary),
                            stag_size=len(ptb_parsed.stag_dictionary),
                            window_size=args.window_size,
                            dropout=args.dpout,
                            dropoute=args.dpoute,
                            dropoutr=args.dpoutr,
                            wordembed=wordembed)
    if args.cuda:
        model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0, 0.999),
                                 weight_decay=args.weight_decay)
    scheduler = ReduceLROnPlateau(optimizer, mode='max', patience=2, factor=0.5, min_lr=0.000001)

    print(" ")
    numparams = sum([numpy.prod(i.size()) for i in model.parameters()])
    print("Number of params: {0}\n{1:35}{2:35}Size".format(
        numparams, 'Name', 'Shape'))  # this includes tied parameters
    print("---------------------------------------------------------------------------")
    for item in model.state_dict().keys():
        this_param = model.state_dict()[item]
        print("{:60}{!s:35}{}".format(
            item, this_param.size(), numpy.prod(this_param.size())))
    print(" ")

    # initial training state; note that no checkpoint is actually restored
    # here, so training always starts from scratch
    best_valid_f1 = 0.0
    start_epoch = 0

    print("training ...")

    train_idxs, train_tags, train_stags, \
        train_arcs, train_distances, \
        train_sents, train_trees = ptb_parsed.batchify('train', args.bthsz)
    if args.devfre == -1:
        args.devfre = len(train_idxs)

    for epoch in range(start_epoch, start_epoch + args.epc):
        inds = list(range(len(train_idxs)))
        random.shuffle(inds)
        epc_train_idxs = [train_idxs[i] for i in inds]
        epc_train_tags = [train_tags[i] for i in inds]
        epc_train_stags = [train_stags[i] for i in inds]
        epc_train_arcs = [train_arcs[i] for i in inds]
        epc_train_distances = [train_distances[i] for i in inds]

        for ibatch, (idx, tag, stag, arc, dst) in \
                enumerate(
                    zip(
                        epc_train_idxs,
                        epc_train_tags,
                        epc_train_stags,
                        epc_train_arcs,
                        epc_train_distances,
                    )):

            if args.cuda:
                idx = idx.cuda()
                tag = tag.cuda()
                stag = stag.cuda()
                arc = arc.cuda()
                dst = dst.cuda()

            mask = (idx >= 0).float()
            idx = idx * mask.long()
            dstmask = (dst > 0).float()

            optimizer.zero_grad()
            pred_dst, pred_arc, pred_tag = model(idx, stag, mask)
            loss_rank = rankloss(pred_dst, dst, dstmask)
            loss_arc = arcloss(pred_arc, arc.view(-1))
            loss_tag = tagloss(pred_tag, tag.view(-1))

            loss = loss_rank + loss_arc + loss_tag
            loss.backward()

            nn.utils.clip_grad_norm_(model.parameters(), 1.)
            optimizer.step()

            if (ibatch + 1) % args.logfre == 0:
                print(train_log_template.format(epoch, ibatch + 1, loss.item(),
                                                loss_rank.item(), loss_arc.item(),
                                                loss_tag.item()))

        #####

        print("Evaluating valid... ")
        valid_loss, valid_arc_prec, valid_tag_prec, \
            valid_precision, valid_recall, valid_f1 = evaluate(model, ptb_parsed, 'valid')
        print("Evaluating test... ")
") 292 | test_loss, test_arc_prec, test_tag_prec, \ 293 | test_precision, test_recall, test_f1 = evaluate(model, ptb_parsed, 'test') 294 | print(valid_log_template.format( 295 | epoch, 296 | ' ', valid_loss, valid_arc_prec, valid_tag_prec, 297 | valid_precision, valid_recall, valid_f1, 298 | ' ', test_loss, test_arc_prec, test_tag_prec, 299 | test_precision, test_recall, test_f1)) 300 | 301 | if valid_f1 > best_valid_f1: 302 | best_valid_f1 = valid_f1 303 | torch.save({ 304 | 'epoch': epoch, 305 | 'valid_loss': valid_loss, 306 | 'valid_precision': valid_precision, 307 | 'valid_recall': valid_recall, 308 | 'valid_f1': valid_f1, 309 | 'model_state_dict': model.state_dict(), 310 | 'optimizer': optimizer.state_dict(), 311 | }, parameter_filepath) 312 | 313 | scheduler.step(valid_f1) 314 | -------------------------------------------------------------------------------- /EVALB/README: -------------------------------------------------------------------------------- 1 | ################################################################# 2 | # # 3 | # Bug fix and additional functionality for evalb # 4 | # # 5 | # This updated version of evalb fixes a bug in which sentences # 6 | # were incorrectly categorized as "length mismatch" when the # 7 | # the parse output had certain mislabeled parts-of-speech. # 8 | # # 9 | # The bug was the result of evalb treating one of the tags (in # 10 | # gold or test) as a label to be deleted (see sections [6],[7] # 11 | # for details), but not the corresponding tag in the other. # 12 | # This most often occurs with punctuation. See the subdir # 13 | # "bug" for an example gld and tst file demonstating the bug, # 14 | # as well as output of evalb with and without the bug fix. # 15 | # # 16 | # For the present version in case of length mismatch, the nodes # 17 | # causing the imbalance are reinserted to resolve the miscount. # 18 | # If the lengths of gold and test truly differ, the error is # 19 | # still reported. The parameter file "new.prm" (derived from # 20 | # COLLINS.prm) shows how to add new potential mislabelings for # 21 | # quotes (",``,',`). # 22 | # # 23 | # I have preserved DJB's revision for modern compilers except # 24 | # for the delcaration of "exit" which is provided by stdlib. # 25 | # # 26 | # Other changes: # 27 | # # 28 | # * output of F-Measure in addition to precision and recall # 29 | # (I did not update the documention in section [4] for this) # 30 | # # 31 | # * more comprehensive DEBUG output that includes bracketing # 32 | # information as evalb is processing each sentence # 33 | # (useful in working through this, and peraps other bugs). # 34 | # Use either the "-D" run-time switch or set DEBUG to 2 in # 35 | # the parameter file. # 36 | # # 37 | # * added DELETE_LABEL lines in new.prm for S1 nodes produced # 38 | # by the Charniak parser and "?", "!" punctuation produced by # 39 | # the Bikel parser. # 40 | # # 41 | # # 42 | # David Ellis (Brown) # 43 | # # 44 | # January.2006 # 45 | ################################################################# 46 | 47 | ################################################################# 48 | # # 49 | # Update of evalb for modern compilers # 50 | # # 51 | # This is an updated version of evalb, for use with modern C # 52 | # compilers. There are a few updates, each marked in the code: # 53 | # # 54 | # /* DJB: explanation of comment */ # 55 | # # 56 | # The updates are purely to help compilation with recent # 57 | # versions of GCC (and other C compilers). 
# changes to the algorithm itself.                              #
#                                                               #
# I have made these changes following recommendations from      #
# users of the Corpora Mailing List, especially Peet Morris and #
# Ramon Ziai.                                                   #
#                                                               #
#                  David Brooks (Birmingham)                    #
#                                                               #
#                       September.2005                          #
#################################################################

#################################################################
#                                                               #
#                    README file for evalb                      #
#                                                               #
#                    Satoshi Sekine (NYU)                       #
#                    Mike Collins (UPenn)                       #
#                                                               #
#                       October.1997                            #
#################################################################

Contents of this README:

   [0] COPYRIGHT
   [1] INTRODUCTION
   [2] INSTALLATION AND RUN
   [3] OPTIONS
   [4] OUTPUT FORMAT FROM THE SCORER
   [5] HOW TO CREATE A GOLDFILE FROM THE TREEBANK
   [6] THE PARAMETER FILE
   [7] MORE DETAILS ABOUT THE SCORING ALGORITHM


[0] COPYRIGHT

The authors abandon the copyright of this program. Everyone is
permitted to copy and distribute the program or a portion of the program
with no charge and no restrictions unless it is harmful to someone.

However, the authors would appreciate the users' kindness in using the
program properly and in letting the authors know about bugs or problems.

This software is provided "AS IS", and the authors make no warranties,
express or implied.

To legally enforce the abandonment of copyright, this package is released
under the Unlicense (see LICENSE).

[1] INTRODUCTION

Evaluation of bracketing looks simple, but in fact, there are minor
differences from system to system. This is a program to parameterize
such minor differences and to give an informative result.

"evalb" evaluates bracketing accuracy in a test-file against a gold-file.
It returns recall, precision, and tagging accuracy. It uses an identical
algorithm to that used in (Collins ACL97).


[2] INSTALLATION AND RUN

To compile the scorer, type

> make


To run the scorer:

> evalb -p Parameter_file Gold_file Test_file


For example, to use the sample files:

> evalb -p sample.prm sample.gld sample.tst



[3] OPTIONS

You can specify system parameters in the command line options.
Other options concerning the evaluation metrics should be specified
in the parameter file, described later.

  -p param_file   parameter file
  -d              debug mode
  -e n            number of errors to kill (default=10)
  -h              help



[4] OUTPUT FORMAT FROM THE SCORER

The scorer gives individual scores for each sentence, for
example:

  Sent.                       Matched  Bracket  Cross        Correct Tag
 ID  Len.  Stat. Recal  Prec. Bracket gold test Bracket Words  Tags Accracy
============================================================================
   1    8    0  100.00 100.00     5      5    5       0      6     5  83.33

At the end of the output the === Summary === section gives statistics
for all sentences, and for sentences <=40 words in length. The summary
contains the following information:

i) Number of sentences -- total number of sentences.

ii) Number of Error/Skip sentences -- should both be 0 if there is no
problem with the parsed/gold files.
iii) Number of valid sentences = Number of sentences - Number of Error/Skip
sentences

iv) Bracketing recall =      (number of correct constituents)
                         ----------------------------------------
                         (number of constituents in the goldfile)

v) Bracketing precision =     (number of correct constituents)
                          ------------------------------------------
                          (number of constituents in the parsed file)

vi) Complete match = percentage of sentences where recall and precision are
both 100%.

vii) Average crossing = (number of constituents crossing a goldfile constituent)
                        --------------------------------------------------------
                                       (number of sentences)

viii) No crossing = percentage of sentences which have 0 crossing brackets.

ix) 2 or less crossing = percentage of sentences which have <=2 crossing brackets.

x) Tagging accuracy = percentage of correct POS tags (but see [5].3 for exact
details of what is counted).



[5] HOW TO CREATE A GOLDFILE FROM THE PENN TREEBANK

The gold and parsed files are in a format similar to this:

(TOP (S (INTJ (RB No)) (, ,) (NP (PRP it)) (VP (VBD was) (RB n't) (NP (NNP Black) (NNP Monday))) (. .)))

To create a gold file from the treebank:

   tgrep -wn '/.*/' | tgrep_proc.prl

will produce a goldfile in the required format. ("tgrep -wn '/.*/'" prints
parse trees, "tgrep_proc.prl" just skips blank lines).

For example, to produce a goldfile for section 23 of the treebank:

   tgrep -wn '/.*/' | tail +90895 | tgrep_proc.prl | sed 2416q > sec23.gold



[6] THE PARAMETER (.prm) FILE

The .prm file sets options regarding the scoring method. COLLINS.prm gives
the same scoring behaviour as the scorer used in (Collins 97). The options
chosen were:

1) LABELED 1

to give labelled precision/recall figures, i.e. a constituent must have the
same span *and* label as a constituent in the goldfile.

2) DELETE_LABEL TOP

Don't count the "TOP" label (which is always given in the output of tgrep)
when scoring.

3) DELETE_LABEL -NONE-

Remove traces (and all constituents which dominate nothing but traces) when
scoring. For example

.... (VP (VBD reported) (SBAR (-NONE- 0) (S (-NONE- *T*-1)))) (. .)))

would be processed to give

.... (VP (VBD reported)) (. .)))


4)
DELETE_LABEL , -- for the purposes of scoring remove punctuation
DELETE_LABEL :
DELETE_LABEL ``
DELETE_LABEL ''
DELETE_LABEL .

5) DELETE_LABEL_FOR_LENGTH -NONE- -- don't include traces when calculating
                                     the length of a sentence (important
                                     when classifying a sentence as <=40
                                     words or >40 words)

6) EQ_LABEL ADVP PRT

Count ADVP and PRT as being the same label when scoring.




[7] MORE DETAILS ABOUT THE SCORING ALGORITHM


1) The scorer initially processes the files to remove all nodes specified
by DELETE_LABEL in the .prm file. It also recursively removes nodes which
dominate nothing due to all their children being removed. For example, if
-NONE- is specified as a label to be deleted,

.... (VP (VBD reported) (SBAR (-NONE- 0) (S (-NONE- *T*-1)))) (. .)))
would be processed to give

.... (VP (VBD reported)) (. .)))

2) The scorer also removes all functional tags attached to non-terminals
(functional tags are prefixed with "-" or "=" in the treebank). For example
"NP-SBJ" is processed to give "NP", "NP=2" is changed to "NP".


3) Tagging accuracy counts tags for all words *except* any tags which are
deleted by a DELETE_LABEL specification in the .prm file. (For example, for
COLLINS.prm, punctuation tagged as "," ":" etc. would not be included).

4) When calculating the length of a sentence, all words with POS tags not
included in the "DELETE_LABEL_FOR_LENGTH" list in the .prm file are
counted. (For COLLINS.prm, only "-NONE-" is specified in this list, so
traces are removed before calculating the length of the sentence).

5) There are some subtleties in scoring when either the goldfile or parsed
file contains multiple constituents for the same span which have the same
non-terminal label, e.g. (NP (NP the man)). If the goldfile contains n
constituents for the same span, and the parsed file contains m constituents
with that nonterminal, the scorer works as follows:

i) If m>n, then the precision is n/m, recall is 100%

ii) If n>m, then the precision is 100%, recall is m/n.

iii) If n==m, recall and precision are both 100%.
--------------------------------------------------------------------------------
/EVALB/evalb.c:
--------------------------------------------------------------------------------
/*****************************************************************/
/*  evalb [-p param_file] [-dh] [-e n] gold-file test-file       */
/*                                                               */
/*  Evaluate bracketing in test-file against gold-file.          */
/*  Return recall, precision, tagging accuracy.                  */
/*                                                               */
/*