├── src
│   ├── __init__.py
│   ├── embed_regularize.py
│   ├── utility.py
│   ├── demo.py
│   ├── weight_drop.py
│   ├── loss.py
│   ├── ctb.py
│   ├── model.py
│   ├── dataloader.py
│   ├── datacreate_ctb.py
│   ├── datacreate_ptb.py
│   ├── helpers.py
│   └── dp.py
├── EVALB
│   ├── evalb
│   ├── Makefile
│   ├── tgrep_proc.prl
│   ├── LICENSE
│   ├── sample
│   │   ├── sample.tst
│   │   ├── sample.gld
│   │   ├── sample.prm
│   │   └── sample.rsl
│   ├── COLLINS.prm
│   ├── new.prm
│   ├── README
│   └── evalb.c
├── .gitignore
└── README.md
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/EVALB/evalb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hantek/distance-parser/HEAD/EVALB/evalb
--------------------------------------------------------------------------------
/EVALB/Makefile:
--------------------------------------------------------------------------------
1 | all: evalb
2 |
3 | evalb: evalb.c
4 | gcc -Wall -g -o evalb evalb.c
5 |
--------------------------------------------------------------------------------
/EVALB/tgrep_proc.prl:
--------------------------------------------------------------------------------
1 | #!/usr/local/bin/perl
2 |
3 | while(<>)
4 | {
5 | if(m/TOP/) # only keep lines containing TOP (blank and other lines are skipped)
6 | {
7 | print;
8 | }
9 | }
10 |
--------------------------------------------------------------------------------
/src/embed_regularize.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.autograd import Variable
4 | import torch.nn.functional as F
5 |
6 |
7 | def embedded_dropout(embed, words, dropout=0.1, scale=None):
8 | if dropout:
9 | mask = embed.weight.data.new().resize_((embed.weight.size(0), 1)).bernoulli_(1 - dropout).expand_as(embed.weight) / (1 - dropout)
10 | mask = Variable(mask)
11 | masked_embed_weight = mask * embed.weight
12 | else:
13 | masked_embed_weight = embed.weight
14 | if scale:
15 | masked_embed_weight = scale.expand_as(masked_embed_weight) * masked_embed_weight
16 |
17 | padding_idx = embed.padding_idx
18 | if padding_idx is None:
19 | padding_idx = -1
20 | return F.embedding(words, masked_embed_weight,
21 | padding_idx, embed.max_norm, embed.norm_type,
22 | embed.scale_grad_by_freq, embed.sparse
23 | )
24 |
25 |
26 |
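27 | if __name__ == '__main__':
28 |     # Minimal smoke test (an illustrative addition, not part of the original
29 |     # pipeline). embedded_dropout zeroes entire rows of the embedding matrix,
30 |     # so every occurrence of a dropped word type maps to the zero vector,
31 |     # while surviving rows are rescaled by 1 / (1 - dropout).
32 |     V, h, bptt, bsz = 50, 4, 10, 2
33 |     embed = torch.nn.Embedding(V, h)
34 |     words = Variable(torch.LongTensor(np.random.randint(0, V, (bsz, bptt))))
35 |     print(embedded_dropout(embed, words, dropout=0.5).size())  # (2, 10, 4)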
--------------------------------------------------------------------------------
/src/utility.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # Filename: utility.py
3 | # Author:hankcs
4 | # Date: 2017-11-03 22:05
5 | import errno
6 | from os import makedirs
7 |
8 | import sys
9 |
10 |
11 | def make_sure_path_exists(path):
12 | try:
13 | makedirs(path)
14 | except OSError as exception:
15 | if exception.errno != errno.EEXIST:
16 | raise
17 |
18 |
19 | def eprint(*args, **kwargs):
20 | print(*args, file=sys.stderr, **kwargs)
21 | # write to stderr so progress messages do not pollute stdout
22 |
23 |
24 | def combine_files(fids, out, tb):
25 | print('%d files...' % len(fids))
26 | total_sentence = 0
27 | for n, file in enumerate(fids):
28 | if n % 10 == 0 or n == len(fids) - 1:
29 | print("%c%.2f%%\r" % (13, (n + 1) / float(len(fids)) * 100), end='')
30 | sents = tb.parsed_sents(file)
31 | for s in sents:
32 | out.write(s.pformat(margin=sys.maxsize))
33 | out.write(u'\n')
34 | total_sentence += 1
35 | print()
36 | print('%d sentences.' % total_sentence)
37 | print()
38 |
39 |
--------------------------------------------------------------------------------
/EVALB/LICENSE:
--------------------------------------------------------------------------------
1 | This is free and unencumbered software released into the public domain.
2 |
3 | Anyone is free to copy, modify, publish, use, compile, sell, or
4 | distribute this software, either in source code form or as a compiled
5 | binary, for any purpose, commercial or non-commercial, and by any
6 | means.
7 |
8 | In jurisdictions that recognize copyright laws, the author or authors
9 | of this software dedicate any and all copyright interest in the
10 | software to the public domain. We make this dedication for the benefit
11 | of the public at large and to the detriment of our heirs and
12 | successors. We intend this dedication to be an overt act of
13 | relinquishment in perpetuity of all present and future rights to this
14 | software under copyright law.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
22 | OTHER DEALINGS IN THE SOFTWARE.
23 |
24 | For more information, please refer to <http://unlicense.org/>
25 |
--------------------------------------------------------------------------------
/EVALB/sample/sample.tst:
--------------------------------------------------------------------------------
1 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
2 | (S (A (P this)) (B (Q is) (C (R a) (T test))))
3 | (S (A (P this)) (B (Q is) (A (R a) (U test))))
4 | (S (C (P this)) (B (Q is) (A (R a) (U test))))
5 | (S (A (P this)) (B (Q is) (R a) (A (T test))))
6 | (S (A (P this) (Q is)) (A (R a) (T test)))
7 | (S (P this) (Q is) (R a) (T test))
8 | (P this) (Q is) (R a) (T test)
9 | (S (A (P this)) (B (Q is) (A (A (R a) (T test)))))
10 | (S (A (P this)) (B (Q is) (A (A (A (A (A (R a) (T test))))))))
11 |
12 | (S (A (P this)) (B (Q was) (A (A (R a) (T test)))))
13 | (S (A (P this)) (B (Q is) (U not) (A (A (R a) (T test)))))
14 |
15 | (TOP (S (A (P this)) (B (Q is) (A (R a) (T test)))))
16 | (S (A (P this)) (NONE *) (B (Q is) (A (R a) (T test))))
17 | (S (A (P this)) (S (NONE abc) (A (NONE *))) (B (Q is) (A (R a) (T test))))
18 | (S (A (P this)) (B (Q is) (A (R a) (TT test))))
19 | (S (A (P This)) (B (Q is) (A (R a) (T test))))
20 | (S (A (P That)) (B (Q is) (A (R a) (T test))))
21 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
22 | (S (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))))
23 | (S (A (P this)) (B (Q is) (A (R a) (T test))) (-NONE- *))
24 | (S (A (P this)) (B (Q is) (A (R a) (T test))) (: *))
25 |
--------------------------------------------------------------------------------
/EVALB/sample/sample.gld:
--------------------------------------------------------------------------------
1 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
2 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
3 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
4 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
5 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
6 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
7 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
8 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
9 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
10 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
11 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
12 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
13 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
14 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
15 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
16 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
17 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
18 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
19 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
20 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
21 | (S (A-SBJ-1 (P this)) (B-WHATEVER (Q is) (A (R a) (T test))))
22 | (S (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))))
23 | (S (A (P this)) (B (Q is) (A (R a) (T test))) (-NONE- *))
24 | (S (A (P this)) (B (Q is) (A (R a) (T test))) (: *))
25 |
--------------------------------------------------------------------------------
/src/demo.py:
--------------------------------------------------------------------------------
1 | from dp import *
2 |
3 | if __name__ == '__main__':
4 | print("building model...")
5 | model = distance_parser(vocab_size=args.vocab_size,
6 | embed_size=args.embedsz,
7 | hid_size=args.hidsz,
8 | arc_size=len(ptb_parsed.arc_dictionary),
9 | stag_size=len(ptb_parsed.stag_dictionary),
10 | window_size=args.window_size,
11 | dropout=args.dpout,
12 | dropoute=args.dpoute,
13 | dropoutr=args.dpoutr)
14 | if args.cuda:
15 | model.cuda()
16 |
17 | if os.path.isfile(parameter_filepath):
18 | print("Resuming from file: {}".format(parameter_filepath))
19 | checkpoint = torch.load(parameter_filepath)
20 | start_epoch = checkpoint['epoch']
21 | valid_precision = checkpoint['valid_precision']
22 | valid_recall = checkpoint['valid_recall']
23 | best_valid_f1 = checkpoint['valid_f1']
24 | model.load_state_dict(checkpoint['model_state_dict'])
25 | print("loaded model: epoch {}, valid_loss {}, "
26 | "valid_precision {}, valid_recall {}, valid_f1 {}".format(
27 | start_epoch, checkpoint['valid_loss'], valid_precision, \
28 | valid_recall, best_valid_f1))
29 |
30 | print("Evaluating valid... ")
31 | valid_loss, valid_arc_prec, valid_tag_prec, \
32 | valid_precision, valid_recall, valid_f1 = evaluate(model, ptb_parsed, 'valid')
33 | print("Evaluating test... ")
34 | test_loss, test_arc_prec, test_tag_prec, \
35 | test_precision, test_recall, test_f1 = evaluate(model, ptb_parsed, 'test')
36 | print(valid_log_template.format(
37 | start_epoch,
38 | ' ', valid_loss, valid_arc_prec, valid_tag_prec,
39 | valid_precision, valid_recall, valid_f1,
40 | ' ', test_loss, test_arc_prec, test_tag_prec,
41 | test_precision, test_recall, test_f1))
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # vim and gedit cache:
2 | *.swp
3 | *.swo
4 | *~
5 |
6 | # cluster logs
7 | SMART_DISPATCH_LOGS/*
8 |
9 | # model params
10 | model/*
11 | params/*
12 | tmptrees/*
13 |
14 | # logs
15 | tblogs/*
16 | logs/*
17 |
18 | # data
19 | data/*
20 |
21 | # Byte-compiled / optimized / DLL files
22 | __pycache__/
23 | *.py[cod]
24 | *$py.class
25 | *.pyc
26 |
27 | # C extensions
28 | *.so
29 |
30 | # Distribution / packaging
31 | .Python
32 | env/
33 | build/
34 | develop-eggs/
35 | dist/
36 | downloads/
37 | eggs/
38 | .eggs/
39 | lib/
40 | lib64/
41 | parts/
42 | sdist/
43 | var/
44 | wheels/
45 | *.egg-info/
46 | .installed.cfg
47 | *.egg
48 |
49 | # PyInstaller
50 | # Usually these files are written by a python script from a template
51 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
52 | *.manifest
53 | *.spec
54 |
55 | # Installer logs
56 | pip-log.txt
57 | pip-delete-this-directory.txt
58 |
59 | # Unit test / coverage reports
60 | htmlcov/
61 | .tox/
62 | .coverage
63 | .coverage.*
64 | .cache
65 | nosetests.xml
66 | coverage.xml
67 | *.cover
68 | .hypothesis/
69 |
70 | # Translations
71 | *.mo
72 | *.pot
73 |
74 | # Django stuff:
75 | *.log
76 | local_settings.py
77 |
78 | # Flask stuff:
79 | instance/
80 | .webassets-cache
81 |
82 | # Scrapy stuff:
83 | .scrapy
84 |
85 | # Sphinx documentation
86 | docs/_build/
87 |
88 | # PyBuilder
89 | target/
90 |
91 | # Jupyter Notebook
92 | .ipynb_checkpoints
93 |
94 | # pyenv
95 | .python-version
96 |
97 | # celery beat schedule file
98 | celerybeat-schedule
99 |
100 | # SageMath parsed files
101 | *.sage.py
102 |
103 | # dotenv
104 | .env
105 |
106 | # virtualenv
107 | .venv
108 | venv/
109 | ENV/
110 |
111 | # Spyder project settings
112 | .spyderproject
113 | .spyproject
114 |
115 | # Rope project settings
116 | .ropeproject
117 |
118 | # mkdocs documentation
119 | /site
120 |
121 | # mypy
122 | .mypy_cache/
123 |
124 | #pycharm
125 | .idea/
126 |
127 | #pytorch
128 | *.pt
129 |
--------------------------------------------------------------------------------
/src/weight_drop.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Parameter
3 |
4 | class WeightDrop(torch.nn.Module):
5 | def __init__(self, module, weights, dropout=0, variational=False):
6 | super(WeightDrop, self).__init__()
7 | self.module = module
8 | self.weights = weights
9 | self.dropout = dropout
10 | self.variational = variational
11 | self._setup()
12 |
13 | def widget_demagnetizer_y2k_edition(*args, **kwargs):
14 | # We need to replace flatten_parameters with a nothing function
15 | # It must be a function rather than a lambda as otherwise pickling explodes
16 | # We can't write boring code though, so ... WIDGET DEMAGNETIZER Y2K EDITION!
17 | return
18 |
19 | def _setup(self):
20 | # Terrible temporary solution to an issue regarding compacting weights re: CUDNN RNN
21 | if issubclass(type(self.module), torch.nn.RNNBase):
22 | self.module.flatten_parameters = self.widget_demagnetizer_y2k_edition
23 |
24 | for name_w in self.weights:
25 | print('Applying weight drop of {} to {}'.format(self.dropout, name_w))
26 | w = getattr(self.module, name_w)
27 | del self.module._parameters[name_w]
28 | self.module.register_parameter(name_w + '_raw', Parameter(w.data))
29 |
30 | def _setweights(self):
31 | for name_w in self.weights:
32 | raw_w = getattr(self.module, name_w + '_raw')
33 | w = None
34 | if self.variational:
35 | mask = torch.autograd.Variable(torch.ones(raw_w.size(0), 1))
36 | if raw_w.is_cuda: mask = mask.cuda()
37 | mask = torch.nn.functional.dropout(mask, p=self.dropout, training=True)
38 | w = mask.expand_as(raw_w) * raw_w
39 | else:
40 | w = torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training)
41 | setattr(self.module, name_w, w)
42 |
43 | def forward(self, *args):
44 | self._setweights()
45 | return self.module.forward(*args)
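46 | 
47 | if __name__ == '__main__':
48 |     # Minimal smoke test (an illustrative addition, not part of the original
49 |     # file): wrap an LSTM so its hidden-to-hidden weights are re-dropped on
50 |     # every forward pass (DropConnect on the recurrent connections).
51 |     x = torch.autograd.Variable(torch.randn(2, 1, 10))
52 |     lstm = torch.nn.LSTM(10, 10)
53 |     wd_lstm = WeightDrop(lstm, ['weight_hh_l0'], dropout=0.9)
54 |     out, _ = wd_lstm(x)
55 |     print(out.size())  # (2, 1, 10)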
--------------------------------------------------------------------------------
/EVALB/sample/sample.prm:
--------------------------------------------------------------------------------
1 | ##------------------------------------------##
2 | ## Debug mode ##
3 | ## print out data for individual sentence ##
4 | ##------------------------------------------##
5 | DEBUG 0
6 |
7 | ##------------------------------------------##
8 | ## MAX error ##
9 | ## Number of errors to stop the process. ##
10 | ## This is useful if there could be ##
11 | ## tokenization errors. ##
12 | ## The process will stop when this number##
13 | ## of errors is accumulated. ##
14 | ##------------------------------------------##
15 | MAX_ERROR 10
16 |
17 | ##------------------------------------------##
18 | ## Cut-off length for statistics ##
19 | ## At the end of evaluation, the ##
20 | ## statistics for the sentences of length##
21 | ## less than or equal to this number will##
22 | ## be shown, on top of the statistics ##
23 | ## for all the sentences ##
24 | ##------------------------------------------##
25 | CUTOFF_LEN 40
26 |
27 | ##------------------------------------------##
28 | ## unlabeled or labeled bracketing ##
29 | ## 0: unlabeled bracketing ##
30 | ## 1: labeled bracketing ##
31 | ##------------------------------------------##
32 | LABELED 1
33 |
34 | ##------------------------------------------##
35 | ## Delete labels ##
36 | ## list of labels to be ignored. ##
37 | ## If it is a pre-terminal label, delete ##
38 | ## the word along with the brackets. ##
39 | ## If it is a non-terminal label, just ##
40 | ## delete the brackets (don't delete ##
41 | ## children). ##
42 | ##------------------------------------------##
43 | DELETE_LABEL TOP
44 | DELETE_LABEL -NONE-
45 | DELETE_LABEL ,
46 | DELETE_LABEL :
47 | DELETE_LABEL ``
48 | DELETE_LABEL ''
49 |
50 | ##------------------------------------------##
51 | ## Delete labels for length calculation ##
52 | ## list of labels to be ignored for ##
53 | ## length calculation purpose ##
54 | ##------------------------------------------##
55 | DELETE_LABEL_FOR_LENGTH -NONE-
56 |
57 |
58 | ##------------------------------------------##
59 | ## Equivalent labels, words ##
60 | ## the pairs are considered equivalent ##
61 | ## This is non-directional. ##
62 | ##------------------------------------------##
63 | EQ_LABEL T TT
64 |
65 | EQ_WORD This this
66 |
--------------------------------------------------------------------------------
/EVALB/COLLINS.prm:
--------------------------------------------------------------------------------
1 | ##------------------------------------------##
2 | ## Debug mode ##
3 | ## 0: No debugging ##
4 | ## 1: print data for individual sentence ##
5 | ##------------------------------------------##
6 | DEBUG 0
7 |
8 | ##------------------------------------------##
9 | ## MAX error ##
10 | ## Number of errors to stop the process. ##
11 | ## This is useful if there could be ##
12 | ## tokenization errors. ##
13 | ## The process will stop when this number##
14 | ## of errors is accumulated. ##
15 | ##------------------------------------------##
16 | MAX_ERROR 10
17 |
18 | ##------------------------------------------##
19 | ## Cut-off length for statistics ##
20 | ## At the end of evaluation, the ##
21 | ## statistics for the sentences of length##
22 | ## less than or equal to this number will##
23 | ## be shown, on top of the statistics ##
24 | ## for all the sentences ##
25 | ##------------------------------------------##
26 | CUTOFF_LEN 40
27 |
28 | ##------------------------------------------##
29 | ## unlabeled or labeled bracketing ##
30 | ## 0: unlabeled bracketing ##
31 | ## 1: labeled bracketing ##
32 | ##------------------------------------------##
33 | LABELED 1
34 |
35 | ##------------------------------------------##
36 | ## Delete labels ##
37 | ## list of labels to be ignored. ##
38 | ## If it is a pre-terminal label, delete ##
39 | ## the word along with the brackets. ##
40 | ## If it is a non-terminal label, just ##
41 | ## delete the brackets (don't delete ##
42 | ## children). ##
43 | ##------------------------------------------##
44 | DELETE_LABEL TOP
45 | DELETE_LABEL -NONE-
46 | DELETE_LABEL ,
47 | DELETE_LABEL :
48 | DELETE_LABEL ``
49 | DELETE_LABEL ''
50 | DELETE_LABEL .
51 |
52 | ##------------------------------------------##
53 | ## Delete labels for length calculation ##
54 | ## list of labels to be ignored for ##
55 | ## length calculation purpose ##
56 | ##------------------------------------------##
57 | DELETE_LABEL_FOR_LENGTH -NONE-
58 |
59 | ##------------------------------------------##
60 | ## Equivalent labels, words ##
61 | ## the pairs are considered equivalent ##
62 | ## This is non-directional. ##
63 | ##------------------------------------------##
64 | EQ_LABEL ADVP PRT
65 |
66 | # EQ_WORD Example example
67 |
--------------------------------------------------------------------------------
/src/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def _assert_no_grad(variable):
6 | assert not variable.requires_grad, \
7 | "nn criterions don't compute the gradient w.r.t. targets - please " \
8 | "mark these variables as not requiring gradients"
9 |
10 |
11 | def ScaledRankLoss(input, target, mask, epsilon):
12 | """
13 | Scaled, single-sided L1 (pairwise rank) loss over all pairs of distances.
14 | epsilon: log-scale parameter; the hinge margin used is exp(epsilon).
15 |
16 | """
17 | _assert_no_grad(target)
18 |
19 | diff = input[:, :, None] - input[:, None, :]
20 | target_diff_positive = ((target[:, :, None] - target[:, None, :]) > 0).float()
21 | target_diff_negative = - ((target[:, :, None] - target[:, None, :]) < 0).float()
22 |
23 | target_diff = target_diff_positive + target_diff_negative
24 | target_diff_zero = 1 - (target_diff_positive + (- target_diff_negative))
25 |
26 | mask = mask[:, :, None] * mask[:, None, :]
27 |
28 | eepsilon = torch.exp(epsilon)
29 | loss = F.relu(eepsilon - target_diff * diff) + \
30 | target_diff_zero * diff * diff / eepsilon ** 2 + \
31 | 1 / eepsilon
32 | loss = (loss * mask).sum() / (mask.sum() + 1e-9)
33 | return loss
34 |
35 |
36 | # def rankloss(input, target, mask):
37 | # diff = (input[:, :, None] - input[:, None, :])
38 | # ### eqloss: we modify the loss in the paper to account for "ties"
39 | # ### i.e. we don't train the ties.
40 | # target_sign = torch.sign(target[:, :, None] - target[:, None, :]).float()
41 | # mask = mask[:, :, None] * mask[:, None, :]
42 | # loss = F.relu(1. - target_sign * diff)
43 | # loss = (0.5 * loss * mask).sum() / (mask.sum() + 1e-9)
44 | # return loss
45 |
46 |
47 | def rankloss(input, target, mask, exp=False):
48 | diff = input[:, :, None] - input[:, None, :]
49 | target_diff = ((target[:, :, None] - target[:, None, :]) > 0).float()
50 | mask = mask[:, :, None] * mask[:, None, :] * target_diff
51 |
52 | if exp:
53 | loss = torch.exp(F.relu(target_diff - diff)) - 1
54 | else:
55 | loss = F.relu(target_diff - diff)
56 | loss = (loss * mask).sum() / (mask.sum() + 1e-9)
57 |
58 | return loss
59 |
60 |
61 | mse = torch.nn.MSELoss(reduce=False)
62 |
63 |
64 | def mseloss(input, target, mask):
65 | loss = mse(input, target)
66 | return (loss * mask).sum() / (mask.sum() + 1e-9)
67 |
68 |
69 | arcloss = torch.nn.CrossEntropyLoss(ignore_index=0)
70 | tagloss = torch.nn.CrossEntropyLoss(ignore_index=0)
71 | bce = torch.nn.BCELoss(size_average=False)
72 |
73 |
74 | def labelloss(input, target, mask):
75 | loss = bce(input * mask, target * mask)
76 | return loss / (mask.sum() + 1e-9)
77 |
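78 | if __name__ == '__main__':
79 |     # Toy check (an illustrative addition, not part of the original code):
80 |     # rankloss penalizes only pairwise ordering, so a prediction that ranks
81 |     # the gold distances correctly with a margin of at least 1 gets zero loss.
82 |     pred = torch.autograd.Variable(torch.FloatTensor([[3., 1., 2.]]))
83 |     gold = torch.autograd.Variable(torch.FloatTensor([[30., 10., 20.]]))
84 |     mask = torch.autograd.Variable(torch.ones(1, 3))
85 |     print(rankloss(pred, gold, mask).item())  # 0.0: same ranking, margin >= 1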
--------------------------------------------------------------------------------
/EVALB/sample/sample.rsl:
--------------------------------------------------------------------------------
1 | Sent. Matched Bracket Cross Correct Tag
2 | ID Len. Stat. Recal Prec. Bracket gold test Bracket Words Tags Accracy
3 | ============================================================================
4 | 1 4 0 100.00 100.00 4 4 4 0 4 4 100.00
5 | 2 4 0 75.00 75.00 3 4 4 0 4 4 100.00
6 | 3 4 0 100.00 100.00 4 4 4 0 4 3 75.00
7 | 4 4 0 75.00 75.00 3 4 4 0 4 3 75.00
8 | 5 4 0 75.00 75.00 3 4 4 0 4 4 100.00
9 | 6 4 0 50.00 66.67 2 4 3 1 4 4 100.00
10 | 7 4 0 25.00 100.00 1 4 1 0 4 4 100.00
11 | 8 4 0 0.00 0.00 0 4 0 0 4 4 100.00
12 | 9 4 0 100.00 80.00 4 4 5 0 4 4 100.00
13 | 10 4 0 100.00 50.00 4 4 8 0 4 4 100.00
14 | 11 4 2 0.00 0.00 0 0 0 0 4 0 0.00
15 | 12 4 1 0.00 0.00 0 0 0 0 4 0 0.00
16 | 13 4 1 0.00 0.00 0 0 0 0 4 0 0.00
17 | 14 4 2 0.00 0.00 0 0 0 0 4 0 0.00
18 | 15 4 0 100.00 100.00 4 4 4 0 4 4 100.00
19 | 16 4 1 0.00 0.00 0 0 0 0 4 0 0.00
20 | 17 4 1 0.00 0.00 0 0 0 0 4 0 0.00
21 | 18 4 0 100.00 100.00 4 4 4 0 4 4 100.00
22 | 19 4 0 100.00 100.00 4 4 4 0 4 4 100.00
23 | 20 4 1 0.00 0.00 0 0 0 0 4 0 0.00
24 | 21 4 0 100.00 100.00 4 4 4 0 4 4 100.00
25 | 22 44 0 100.00 100.00 34 34 34 0 44 44 100.00
26 | 23 4 0 100.00 100.00 4 4 4 0 4 4 100.00
27 | 24 5 0 100.00 100.00 4 4 4 0 4 4 100.00
28 | ============================================================================
29 | 87.76 90.53 86 98 95 16 108 106 98.15
30 | === Summary ===
31 |
32 | -- All --
33 | Number of sentence = 24
34 | Number of Error sentence = 5
35 | Number of Skip sentence = 2
36 | Number of Valid sentence = 17
37 | Bracketing Recall = 87.76
38 | Bracketing Precision = 90.53
39 | Complete match = 52.94
40 | Average crossing = 0.06
41 | No crossing = 94.12
42 | 2 or less crossing = 100.00
43 | Tagging accuracy = 98.15
44 |
45 | -- len<=40 --
46 | Number of sentence = 23
47 | Number of Error sentence = 5
48 | Number of Skip sentence = 2
49 | Number of Valid sentence = 16
50 | Bracketing Recall = 81.25
51 | Bracketing Precision = 85.25
52 | Complete match = 50.00
53 | Average crossing = 0.06
54 | No crossing = 93.75
55 | 2 or less crossing = 100.00
56 | Tagging accuracy = 96.88
57 |
--------------------------------------------------------------------------------
/EVALB/new.prm:
--------------------------------------------------------------------------------
1 | ##------------------------------------------##
2 | ## Debug mode ##
3 | ## 0: No debugging ##
4 | ## 1: print data for individual sentence ##
5 | ## 2: print detailed bracketing info ##
6 | ##------------------------------------------##
7 | DEBUG 0
8 |
9 | ##------------------------------------------##
10 | ## MAX error ##
11 | ## Number of errors to stop the process. ##
12 | ## This is useful if there could be ##
13 | ## tokenization errors. ##
14 | ## The process will stop when this number##
15 | ## of errors is accumulated. ##
16 | ##------------------------------------------##
17 | MAX_ERROR 10
18 |
19 | ##------------------------------------------##
20 | ## Cut-off length for statistics ##
21 | ## At the end of evaluation, the ##
22 | ## statistics for the sentences of length##
23 | ## less than or equal to this number will##
24 | ## be shown, on top of the statistics ##
25 | ## for all the sentences ##
26 | ##------------------------------------------##
27 | CUTOFF_LEN 40
28 |
29 | ##------------------------------------------##
30 | ## unlabeled or labeled bracketing ##
31 | ## 0: unlabeled bracketing ##
32 | ## 1: labeled bracketing ##
33 | ##------------------------------------------##
34 | LABELED 1
35 |
36 | ##------------------------------------------##
37 | ## Delete labels ##
38 | ## list of labels to be ignored. ##
39 | ## If it is a pre-terminal label, delete ##
40 | ## the word along with the brackets. ##
41 | ## If it is a non-terminal label, just ##
42 | ## delete the brackets (don't delete ##
43 | ## children). ##
44 | ##------------------------------------------##
45 | DELETE_LABEL TOP
46 | DELETE_LABEL S1
47 | DELETE_LABEL -NONE-
48 | DELETE_LABEL ,
49 | DELETE_LABEL :
50 | DELETE_LABEL ``
51 | DELETE_LABEL ''
52 | DELETE_LABEL .
53 | DELETE_LABEL ?
54 | DELETE_LABEL !
55 |
56 | ##------------------------------------------##
57 | ## Delete labels for length calculation ##
58 | ## list of labels to be ignored for ##
59 | ## length calculation purpose ##
60 | ##------------------------------------------##
61 | DELETE_LABEL_FOR_LENGTH -NONE-
62 |
63 | ##------------------------------------------##
64 | ## Labels to be considered for misquote ##
65 | ## (could be possessive or quote) ##
66 | ##------------------------------------------##
67 | QUOTE_LABEL ``
68 | QUOTE_LABEL ''
69 | QUOTE_LABEL POS
70 |
71 | ##------------------------------------------##
72 | ## These ones are less common, but ##
73 | ## are on occasion output by parsers: ##
74 | ##------------------------------------------##
75 | QUOTE_LABEL NN
76 | QUOTE_LABEL CD
77 | QUOTE_LABEL VBZ
78 | QUOTE_LABEL :
79 |
80 | ##------------------------------------------##
81 | ## Equivalent labels, words ##
82 | ## the pairs are considered equivalent ##
83 | ## This is non-directional. ##
84 | ##------------------------------------------##
85 | EQ_LABEL ADVP PRT
86 |
87 | # EQ_WORD Example example
88 |
--------------------------------------------------------------------------------
/src/ctb.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # Filename: ctb.py
3 | # Author: hantek, hankcs
4 |
5 | import os
6 | import argparse
7 | from os import listdir
8 | from os.path import isfile, join, isdir
9 | import nltk
10 |
11 | from utility import make_sure_path_exists, eprint, combine_files
12 |
13 |
14 | def convert(ctb_root, out_root):
15 | ctb_root = join(ctb_root, 'bracketed')
16 | fids = [f for f in listdir(ctb_root) if isfile(join(ctb_root, f)) and \
17 | (f.endswith('.nw') or \
18 | f.endswith('.mz') or \
19 | f.endswith('.wb'))]
20 | make_sure_path_exists(out_root)
21 |
22 | for f in fids:
23 | with open(join(ctb_root, f), 'r') as src, \
24 | open(join(out_root, f.split('.')[0] + '.fid'), 'w') as out:
25 | # encoding='GB2312'
26 | in_s_tag = False
27 | try:
28 | for line in src:
29 | if line.startswith('<S ID=') or line.startswith('<seg id='):
30 | in_s_tag = True
31 | elif line.startswith('</S>') or line.startswith('</seg>'):
32 | in_s_tag = False
33 | elif line.startswith('<'):
34 | continue
35 | elif in_s_tag and len(line) > 1:
36 | out.write(line)
37 | except: # some CTB files do not decode cleanly; skip the unreadable remainder
38 | pass
39 |
40 |
41 | def combine_fids(fids, out_path):
42 | print('Generating ' + out_path)
43 | files = []
44 | for fid in fids:
45 | f = 'chtb_%04d.fid' % fid
46 | if isfile(join(ctb_in_nltk, f)):
47 | files.append(f)
48 | with open(out_path, 'w') as out:
49 | combine_files(files, out, ctb)
50 |
51 |
52 | if __name__ == '__main__':
53 | parser = argparse.ArgumentParser(
54 | description='Combine Chinese Treebank 5.1 fid files into train/dev/test set')
55 | parser.add_argument("--ctb", required=True,
56 | help='The root path to Chinese Treebank 5.1')
57 | parser.add_argument("--output", required=True,
58 | help='The folder where to store the output train.txt/dev.txt/test.txt')
59 |
60 | args = parser.parse_args()
61 |
62 | ctb_in_nltk = None
63 | for root in nltk.data.path:
64 | if isdir(root):
65 | ctb_in_nltk = root
66 |
67 | if ctb_in_nltk is None:
68 | eprint('You should run nltk.download(\'ptb\') to fetch some data first!')
69 | exit(1)
70 |
71 | ctb_in_nltk = join(ctb_in_nltk, 'corpora')
72 | ctb_in_nltk = join(ctb_in_nltk, 'ctb')
73 |
74 | print('Converting CTB: removing xml tags...')
75 | convert(args.ctb, ctb_in_nltk)
76 | print('Importing to nltk...\n')
77 | from nltk.corpus import BracketParseCorpusReader, LazyCorpusLoader
78 |
79 | ctb = LazyCorpusLoader(
80 | 'ctb', BracketParseCorpusReader, r'chtb_.*\.*',
81 | tagset='unknown')
82 |
83 | training = list(range(1, 270 + 1)) + list(range(440, 1151 + 1))
84 | development = list(range(301, 325 + 1))
85 | test = list(range(271, 300 + 1))
86 |
87 | root_path = args.output
88 | if not os.path.isdir(root_path):
89 | os.mkdir(root_path)
90 | combine_fids(training, join(root_path, 'train.txt'))
91 | combine_fids(development, join(root_path, 'dev.txt'))
92 | combine_fids(test, join(root_path, 'test.txt'))
93 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Distance Parser
2 | The distance parser is a supervised constituency parser based on syntactic distance (see the toy example at the end of this README).
3 | This repo is a working implementation of the distance parser that reproduces the results reported in the ACL 2018 paper
4 | [Straight to the Tree: Constituency Parsing with Neural Syntactic Distance](https://arxiv.org/abs/1806.04168).
5 | We provide models with proper configurations for the PTB and CTB datasets, as well as their preprocessing scripts.
6 |
7 | ## Requirements
8 | - [PyTorch](https://pytorch.org/): we use PyTorch 0.4.0 with Python 3.6.
9 | - [Stanford POS tagger](https://nlp.stanford.edu/software/stanford-postagger-full-2018-02-27.zip): we use the full Stanford Tagger, version 3.9.1, build 2018-02-27.
10 | - [NLTK](http://www.nltk.org/): we use NLTK 3.2.5.
11 | - [EVALB](https://nlp.cs.nyu.edu/evalb/): a compiled EVALB is integrated in this repo; it is forked from the latest version of EVALB, which can be accessed through [this link](https://nlp.cs.nyu.edu/evalb/EVALB.tgz).
12 |
13 | ## Datasets and Preprocessing
14 |
15 | ### Preprocessing PTB
16 | We use the same preprocessed PTB files as the [self-attentive parser](https://github.com/nikitakit/self-attentive-parser) repo. [GloVe embeddings](https://nlp.stanford.edu/projects/glove/) are optional; they are only needed for the ablation experiments.
17 |
18 | To preprocess PTB, please follow the steps below:
19 |
20 | 1. Download the 3 PTB data files from https://github.com/nikitakit/self-attentive-parser/tree/master/data, and put them in the `data/ptb` folder.
21 |
22 | 2. Run the following command to prepare the PTB data:
23 | ```
24 | python datacreate_ptb.py ../data/ptb /path/to/glove.840B.300d.txt
25 | ```
26 |
27 | ### Preprocessing CTB
28 | We use the standard train/valid/test split specified in [Liu and Zhang, 2017](https://arxiv.org/pdf/1707.05000.pdf) for our CTB experiments.
29 |
30 | To preprocess the CTB, please follow the steps below:
31 |
32 | 1. Download and unzip the Chinese Treebank dataset from https://wakespace.lib.wfu.edu/handle/10339/39379
33 |
34 | 2. If you don't have any corpus data in NLTK before, download some to initialize your `nltk_data` folder, such as:
35 | ```
36 | python -c "import nltk; nltk.download('ptb')"
37 | ```
38 |
39 | 3. Run the following command to link the dataset to NLTK, and generate the train/valid/test split in the repo:
40 | ```
41 | python ctb.py --ctb /path/to/your/ctb8.0/data --output data/ctb_liusplit
42 | ```
43 |
44 | 4. Integrate the Stanford Tagger for data preprocessing. Download the Stanford tagger from https://nlp.stanford.edu/software/stanford-postagger-full-2018-02-27.zip and unzip it.
45 |
46 | 5. Run the following command to generate the preprocessed files:
47 | ```
48 | python datacreate_ctb.py ../data/ctb_liusplit /path/to/stanford/tagger/
49 | ```
50 |
51 | ## Experiments
52 | To reproduce the PTB results in Table 1, run
53 | ```
54 | cd src
55 | python dp.py --cuda --datapath ../data/ptb --savepath ../ptbresults --epc 200 --lr 0.001 --bthsz 20 --hidsz 1200 --embedsz 400 --window_size 2 --dpout 0.3 --dpoute 0.1 --dpoutr 0.2 --weight_decay 1e-6
56 | ```
57 |
58 | To reproduce the CTB results in Table 2, run
59 | ```
60 | cd src
61 | python dp.py --cuda --datapath ../data/ctb_liusplit --savepath ../ctbresults --epc 200 --lr 0.001 --bthsz 20 --hidsz 1200 --embedsz 400 --window_size 2 --dpout 0.4 --dpoute 0.1 --dpoutr 0.1 --weight_decay 1e-6
62 | ```
63 |
64 | ## Pre-trained models
65 | We provide pre-trained models for the convenience of users. The following steps download the two pre-trained models to your repo:
66 | ```
67 | mkdir results/
68 | cd results/
69 | wget http://lisaweb.iro.umontreal.ca/transfert/lisa/users/linzhou/distance_parser_pretrained_model/ctb.th
70 | wget http://lisaweb.iro.umontreal.ca/transfert/lisa/users/linzhou/distance_parser_pretrained_model/ptb.th
71 | ```
72 | To re-evaluate the pre-trained models, run:
73 | ```
74 | cd src/
75 | python demo.py --cuda --datapath ../data/ptb/ --filename ptb # this command reproduces the 92.0 F1 score for PTB
76 | python demo.py --cuda --datapath ../data/ctb_liusplit/ --filename ctb # this command reproduces the 86.5 F1 score for CTB
77 | ```
78 | Note that the model files have to be in the `results` folder in order for the `demo.py` script to load them automatically.
79 |
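80 | ## How syntactic distances encode a tree
81 | For a binarized parse tree, assign one number to each split point between adjacent words: the height of their lowest common ancestor. Recursively splitting the sentence at the largest remaining distance then recovers the tree, so the parser only has to predict a correct *ranking* of the distances. The sketch below is a minimal, self-contained illustration of this idea; `tree2dist` and `dist2tree` are toy stand-ins, not the actual functions in `src/helpers.py`.
82 | ```
83 | def tree2dist(tree):
84 |     """Return (distances, height) for a binary tree of nested lists/strings."""
85 |     if isinstance(tree, str):  # a leaf (word) contains no split point
86 |         return [], 0
87 |     left_d, left_h = tree2dist(tree[0])
88 |     right_d, right_h = tree2dist(tree[1])
89 |     height = max(left_h, right_h) + 1  # this node sits one level above its children
90 |     return left_d + [height] + right_d, height
91 | 
92 | def dist2tree(words, dists):
93 |     """Rebuild the tree by splitting at the largest distance (leftmost tie)."""
94 |     if not dists:
95 |         return words[0]
96 |     i = max(range(len(dists)), key=lambda j: dists[j])
97 |     return [dist2tree(words[:i + 1], dists[:i]),
98 |             dist2tree(words[i + 1:], dists[i + 1:])]
99 | 
100 | tree = [['the', 'cat'], [['sat', 'down'], 'today']]
101 | dists, _ = tree2dist(tree)
102 | print(dists)  # [1, 3, 1, 2]
103 | print(dist2tree(['the', 'cat', 'sat', 'down', 'today'], dists) == tree)  # True
104 | ```
105 | 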
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
1 | import numpy
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
5 |
6 | from embed_regularize import embedded_dropout
7 | from weight_drop import WeightDrop
8 |
9 |
10 | class Shuffle(nn.Module):
11 | def __init__(self, permutation, contiguous=True):
12 | super(Shuffle, self).__init__()
13 | self.permutation = permutation
14 | self.contiguous = contiguous
15 |
16 | def forward(self, input):
17 | shuffled = input.permute(*self.permutation)
18 | if self.contiguous:
19 | return shuffled.contiguous()
20 | else:
21 | return shuffled
22 |
23 |
24 | class LayerNormalization(nn.Module):
25 | ''' Layer normalization module '''
26 |
27 | def __init__(self, d_hid, eps=1e-3):
28 | super(LayerNormalization, self).__init__()
29 |
30 | self.eps = eps
31 | self.a_2 = nn.Parameter(torch.ones(d_hid), requires_grad=True)
32 | self.b_2 = nn.Parameter(torch.zeros(d_hid), requires_grad=True)
33 |
34 | def forward(self, z):
35 | if z.size(1) == 1:
36 | return z
37 |
38 | mu = torch.mean(z, keepdim=True, dim=-1)
39 | sigma = torch.std(z, keepdim=True, dim=-1)
40 | ln_out = (z - mu.expand_as(z)) / (sigma.expand_as(z) + self.eps)
41 | ln_out = ln_out * self.a_2.expand_as(ln_out) + self.b_2.expand_as(ln_out)
42 |
43 | return ln_out
44 |
45 |
46 | class distance_parser(nn.Module):
47 | def __init__(self,
48 | vocab_size, embed_size, hid_size,
49 | arc_size, stag_size, window_size,
50 | wordembed=None, dropout=0.2, dropoute=0.1, dropoutr=0.1):
51 | super(distance_parser, self).__init__()
52 | self.vocab_size = vocab_size
53 | self.embed_size = embed_size
54 | self.hid_size = hid_size
55 | self.arc_size = arc_size
56 | self.stag_size = stag_size
57 | self.window_size = window_size
58 | self.drop = nn.Dropout(dropout)
59 | self.dropoute = dropoute
60 | self.dropoutr = dropoutr
61 | self.encoder = nn.Embedding(vocab_size, embed_size)
62 | if wordembed is not None:
63 | self.encoder.weight.data = torch.FloatTensor(wordembed)
64 |
65 | self.tag_encoder = nn.Embedding(stag_size, embed_size)
66 |
67 | self.word_rnn = nn.LSTM(2 * embed_size, hid_size, num_layers=2, batch_first=True, dropout=dropout,
68 | bidirectional=True)
69 | self.word_rnn = WeightDrop(self.word_rnn, ['weight_hh_l0', 'weight_hh_l1'], dropout=dropoutr)
70 |
71 | self.conv1 = nn.Sequential(nn.Dropout(dropout),
72 | nn.Conv1d(hid_size * 2,
73 | hid_size,
74 | window_size),
75 | nn.ReLU())
76 |
77 | self.arc_rnn = nn.LSTM(hid_size, hid_size, num_layers=2, batch_first=True, dropout=dropout,
78 | bidirectional=True)
79 | self.arc_rnn = WeightDrop(self.arc_rnn, ['weight_hh_l0', 'weight_hh_l1'], dropout=dropoutr)
80 |
81 | self.distance = nn.Sequential(
82 | nn.Dropout(dropout),
83 | nn.Linear(hid_size * 2, hid_size),
84 | nn.ReLU(),
85 | nn.Dropout(dropout),
86 | nn.Linear(hid_size, 1),
87 | )
88 |
89 | self.terminal = nn.Sequential(
90 | nn.Dropout(dropout),
91 | nn.Linear(hid_size * 2, hid_size),
92 | nn.ReLU(),
93 | )
94 |
95 | self.non_terminal = nn.Sequential(
96 | nn.Dropout(dropout),
97 | nn.Linear(hid_size * 2, hid_size),
98 | nn.ReLU(),
99 | )
100 |
101 | self.arc = nn.Sequential(
102 | nn.Dropout(dropout),
103 | nn.Linear(hid_size, arc_size),
104 | )
105 |
106 | def forward(self, words, stag, mask):
107 | """
108 | words, stag: Variables of LongTensor, shape (bsize, ntoken)
109 | mask: Variable of FloatTensor, 1 for real tokens and 0 for padding, shape (bsize, ntoken)
110 | """
111 |
112 | bsz, ntoken = words.size()
113 | emb_words = embedded_dropout(self.encoder, words, dropout=self.dropoute if self.training else 0)
114 | emb_words = self.drop(emb_words)
115 |
116 | emb_stags = embedded_dropout(self.tag_encoder, stag, dropout=self.dropoute if self.training else 0)
117 | emb_stags = self.drop(emb_stags)
118 |
119 |
120 | def run_rnn(input, rnn, lengths):
121 | sorted_idx = numpy.argsort(lengths)[::-1].tolist()
122 | rnn_input = pack_padded_sequence(input[sorted_idx], lengths[sorted_idx], batch_first=True)
123 | rnn_out, _ = rnn(rnn_input) # (bsize, ntoken, hidsize*2)
124 | rnn_out, _ = pad_packed_sequence(rnn_out, batch_first=True)
125 | rnn_out = rnn_out[numpy.argsort(sorted_idx).tolist()]
126 |
127 | return rnn_out
128 |
129 | sent_lengths = (mask.sum(dim=1)).data.cpu().numpy().astype('int')
130 | dst_lengths = sent_lengths - 1
131 | emb_plus_tag = torch.cat([emb_words, emb_stags], dim=-1)
132 |
133 | rnn1_out = run_rnn(emb_plus_tag, self.word_rnn, sent_lengths)
134 |
135 | terminal = self.terminal(rnn1_out.view(-1, self.hid_size*2))
136 | tag = self.arc(terminal) # (bsize*ntoken, arc_size); tags share the arc label space
137 |
138 | conv_out = self.conv1(rnn1_out.permute(0, 2, 1)).permute(0, 2, 1) # (bsize, ndst, hidsize)
139 | rnn2_out = run_rnn(conv_out, self.arc_rnn, dst_lengths)
140 |
141 | non_terminal = self.non_terminal(rnn2_out.view(-1, self.hid_size*2))
142 | distance = self.distance(rnn2_out.view(-1, self.hid_size*2)).squeeze(dim=-1) # (bsize, ndst)
143 | arc = self.arc(non_terminal) # (bsize*(ntoken-1), arc_size)
144 | return distance.view(bsz, ntoken - 1), arc.contiguous().view(-1, self.arc_size), tag.view(-1, self.arc_size)
145 |
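146 | if __name__ == '__main__':
147 |     # Tiny shape check (an illustrative addition, not part of the original
148 |     # file): with window_size=2 the convolution shortens each sentence by one
149 |     # position, so n tokens yield n - 1 syntactic distances, one per split
150 |     # point between adjacent tokens.
151 |     model = distance_parser(vocab_size=20, embed_size=8, hid_size=16,
152 |                             arc_size=5, stag_size=5, window_size=2)
153 |     words = torch.autograd.Variable(torch.LongTensor(2, 7).random_(0, 20))
154 |     stags = torch.autograd.Variable(torch.LongTensor(2, 7).random_(0, 5))
155 |     mask = torch.autograd.Variable(torch.ones(2, 7))
156 |     dst, arc, tag = model(words, stags, mask)
157 |     print(dst.size(), arc.size(), tag.size())  # (2, 6), (12, 5), (14, 5)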
--------------------------------------------------------------------------------
/src/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 |
4 | import torch
5 |
6 | from helpers import *
7 |
8 |
9 | class Dictionary(object):
10 | def __init__(self):
11 | self.word2idx = {'<unk>': 0}
12 | self.idx2word = ['<unk>']
13 | self.word2frq = {}
14 |
15 | def add_word(self, word):
16 | if word not in self.word2idx:
17 | self.idx2word.append(word)
18 | self.word2idx[word] = len(self.idx2word) - 1
19 | if word not in self.word2frq:
20 | self.word2frq[word] = 1
21 | else:
22 | self.word2frq[word] += 1
23 | return self.word2idx[word]
24 |
25 | def __len__(self):
26 | return len(self.idx2word)
27 |
28 | def __getitem__(self, item):
29 | if item in self.word2idx:
30 | return self.word2idx[item]
31 | else:
32 | return self.word2idx['<unk>']
33 |
34 | def rebuild_by_freq(self, thd=3):
35 | self.word2idx = {'<unk>': 0}
36 | self.idx2word = ['<unk>']
37 | 
38 | for k, v in self.word2frq.items():
39 | if v >= thd and (k not in self.idx2word):
40 | self.idx2word.append(k)
41 | self.word2idx[k] = len(self.idx2word) - 1
42 |
43 | print('Number of words:', len(self.idx2word))
44 | return len(self.idx2word)
45 |
46 | def class_weight(self):
47 | frq = [self.word2frq[self.idx2word[i]] for i in range(len(self.idx2word))]
48 | frq = numpy.array(frq).astype('float')
49 | weight = numpy.sqrt(frq.max() / frq)
50 | weight = numpy.clip(weight, a_min=0., a_max=5.)
51 |
52 | return weight
53 |
54 |
55 | class PTBLoader(object):
56 | '''Data path is assumed to be a directory with
57 | pkl files and a corpora subdirectory.
58 | '''
59 | def __init__(self, data_path=None, use_glove=False):
60 | assert data_path is not None
61 | # make path available for nltk
62 | nltk.data.path.append(data_path)
63 | dict_filepath = os.path.join(data_path, 'dict.pkl')
64 | data_filepath = os.path.join(data_path, 'parsed.pkl')
65 |
66 | print("loading dictionary ...")
67 | self.dictionary = pickle.load(open(dict_filepath, "rb"))
68 | if use_glove:
69 | glove_filepath = os.path.join(data_path, 'ptb_glove.npy')
70 | print("loading preprocessed glove file ...")
71 | f_we = open(glove_filepath, 'rb')
72 | self.wordembed_matrix = numpy.load(f_we)
73 | f_we.close()
74 | else:
75 | self.wordembed_matrix = None
76 |
77 | # build tree and distance
78 | print("loading tree and distance ...")
79 | file_data = open(data_filepath, 'rb')
80 | self.train, self.arc_dictionary, self.stag_dictionary = pickle.load(file_data)
81 | self.valid = pickle.load(file_data)
82 | self.test = pickle.load(file_data)
83 | file_data.close()
84 |
85 | def batchify(self, dataname, batch_size, sort=False):
86 | sents, trees = None, None
87 | if dataname == 'train':
88 | idxs, tags, stags, arcs, distances, sents, trees = self.train
89 | elif dataname == 'valid':
90 | idxs, tags, stags, arcs, distances, _, _ = self.valid
91 | elif dataname == 'test':
92 | idxs, tags, stags, arcs, distances, _, _ = self.test
93 | else:
94 | raise ValueError('need a correct dataname')
95 |
96 | assert len(idxs) == len(distances)
97 | assert len(idxs) == len(tags)
98 |
99 | batchified_idxs, batchified_tags, batchified_stags, \
100 | batchified_arcs, batchified_dsts, \
101 | = [], [], [], [], []
102 | batchified_sents, batchified_trees = [], []
103 | for i in range(0, len(idxs), batch_size):
104 | if i + batch_size >= len(idxs): continue
105 |
106 | if sents is not None:
107 | batchified_sents.append(sents[i: i + batch_size])
108 | batchified_trees.append(trees[i: i + batch_size])
109 |
110 | extracted_idxs = idxs[i: i + batch_size]
111 | extracted_tags = tags[i: i + batch_size]
112 | extracted_stags = stags[i: i + batch_size]
113 |
114 | extracted_arcs = arcs[i: i + batch_size]
115 | extracted_dsts = distances[i: i + batch_size]
116 |
117 | longest_idx = max([len(i) for i in extracted_idxs])
118 | longest_arc = longest_idx - 1
119 |
120 | minibatch_idxs, minibatch_tags, minibatch_stags, \
121 | minibatch_arcs, minibatch_dsts, \
122 | = [], [], [], [], []
123 | for idx, tag, stag, \
124 | arc, dst \
125 | in zip(extracted_idxs, extracted_tags, extracted_stags,
126 | extracted_arcs, extracted_dsts):
127 | padded_idx = idx + [-1] * (longest_idx - len(idx))
128 | padded_tag = tag + [0] * (longest_idx - len(tag))
129 | padded_stag = stag + [0] * (longest_idx - len(stag))
130 |
131 | padded_arc = arc + [0] * (longest_arc - len(arc))
132 | padded_dst = dst + [0] * (longest_arc - len(dst))
133 |
134 | minibatch_idxs.append(padded_idx)
135 | minibatch_tags.append(padded_tag)
136 | minibatch_stags.append(padded_stag)
137 |
138 | minibatch_arcs.append(padded_arc)
139 | minibatch_dsts.append(padded_dst)
140 |
141 | minibatch_idxs = torch.LongTensor(minibatch_idxs)
142 | minibatch_tags = torch.LongTensor(minibatch_tags)
143 | minibatch_stags = torch.LongTensor(minibatch_stags)
144 |
145 | minibatch_arcs = torch.LongTensor(minibatch_arcs)
146 | minibatch_dsts = torch.FloatTensor(minibatch_dsts)
147 |
148 | batchified_idxs.append(minibatch_idxs)
149 | batchified_tags.append(minibatch_tags)
150 | batchified_stags.append(minibatch_stags)
151 | 
152 | batchified_arcs.append(minibatch_arcs)
153 | batchified_dsts.append(minibatch_dsts)
154 |
155 | if sents is not None:
156 | return batchified_idxs, batchified_tags, batchified_stags, \
157 | batchified_arcs, batchified_dsts, \
158 | batchified_sents, batchified_trees
159 | return batchified_idxs, batchified_tags, batchified_stags, \
160 | batchified_arcs, batchified_dsts
161 |
162 |
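163 | # Note on padding (added commentary, not original code): within a minibatch,
164 | # every sentence is right-padded to the longest one. Word indices are padded
165 | # with -1, while tag/stag/arc ids are padded with 0, the <unk> index that the
166 | # CrossEntropyLoss in loss.py is configured to ignore; distances are padded
167 | # with 0, to be masked out in the rank losses.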
--------------------------------------------------------------------------------
/src/datacreate_ctb.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 |
4 | from nltk.tree import Tree
5 | from nltk.tag import StanfordPOSTagger
6 |
7 | from helpers import *
8 |
9 |
10 |
11 | def load_trees(path, strip_top=True, strip_spmrl_features=True):
12 | trees = []
13 | with open(path) as infile:
14 | for line in infile:
15 | trees.append(Tree.fromstring(line))
16 |
17 | if strip_top:
18 | for i, tree in enumerate(trees):
19 | if tree.label() in ("TOP", "ROOT"):
20 | assert len(tree) == 1
21 | trees[i] = tree[0]
22 | return trees
23 |
24 |
25 | class CTBCreator(object):
26 | '''Data path is assumed to be a directory with
27 | pkl files and a corpora subdirectory.
28 | '''
29 | def __init__(self,
30 | wordembed_dim=300,
31 | embeddingstd=0.1,
32 | data_path=None,
33 | tagger_path=None):
34 | assert data_path is not None
35 | assert tagger_path is not None
36 | dict_filepath = os.path.join(data_path, 'dict.pkl')
37 | data_filepath = os.path.join(data_path, 'parsed.pkl')
38 | train_filepath = os.path.join(data_path, "train.txt")
39 | valid_filepath = os.path.join(data_path, "dev.txt")
40 | test_filepath = os.path.join(data_path, "test.txt")
41 |
42 | self.st = StanfordPOSTagger(os.path.join(tagger_path, 'models/chinese-distsim.tagger'),
43 | os.path.join(tagger_path, 'stanford-postagger.jar'))
44 |
45 | print("building dictionary ...")
46 |
47 | self.dictionary = Dictionary()
48 |
49 | print("loading trees from {}".format(train_filepath))
50 | train_trees = load_trees(train_filepath)
51 | print("loading trees from {}".format(valid_filepath))
52 | valid_trees = load_trees(valid_filepath)
53 | print("loading trees from {}".format(test_filepath))
54 | test_trees = load_trees(test_filepath)
55 |
56 | self.add_words(train_trees)
57 | self.dictionary.rebuild_by_freq()
58 | self.arc_dictionary = Dictionary()
59 | self.stag_dictionary = Dictionary()
60 | self.train = self.preprocess(train_trees, is_train=True)
61 | self.valid = self.preprocess(valid_trees, is_train=False)
62 | self.test = self.preprocess(test_trees, is_train=False)
63 | with open(dict_filepath, "wb") as file_dict:
64 | pickle.dump(self.dictionary, file_dict)
65 | with open(data_filepath, "wb") as file_data:
66 | pickle.dump((self.train, self.arc_dictionary,
67 | self.stag_dictionary), file_data)
68 | pickle.dump(self.valid, file_data)
69 | pickle.dump(self.test, file_data)
70 |
71 | print(len(self.arc_dictionary.idx2word))
72 | print(self.arc_dictionary.idx2word)
73 |
74 | def add_words(self, trees):
75 | words, tags = [], []
76 | for tree in trees:
77 | tree = process_NONE(tree)
78 | words, tags = zip(*tree.pos())
79 | words = ['<s>'] + list(words) + ['</s>']
80 | for w in words:
81 | self.dictionary.add_word(w)
82 |
83 | def preprocess(self, parse_trees, is_train=False):
84 | sens_idx = []
85 | sens_tag = []
86 | sens_stag = []
87 | sens_arc = []
88 | distances = []
89 | sens = []
90 | trees = []
91 |
92 | print('\nConverting trees ...')
93 | for i, tree in enumerate(parse_trees):
94 | tree = process_NONE(tree)
95 | if i % 10 == 0:
96 | print("Done %d/%d\r" % (i, len(parse_trees)), end='')
97 | word_lexs, _ = zip(*tree.pos())
98 | idx = []
99 | for word in (['<s>'] + list(word_lexs) + ['</s>']):
100 | idx.append(self.dictionary[word])
101 |
102 | listerized_tree, arcs, tags = tree2list(tree)
103 | tags = ['<unk>'] + tags + ['<unk>']
104 | arcs = ['<unk>'] + arcs + ['<unk>']
105 |
106 | if type(listerized_tree) is str:
107 | listerized_tree = [listerized_tree]
108 | distances_sent, _ = distance(listerized_tree)
109 | distances_sent = [0] + distances_sent + [0]
110 |
111 | idx_arcs = []
112 | for arc in arcs:
113 | arc = precess_arc(arc)
114 | arc_id = self.arc_dictionary.add_word(arc) if is_train else self.arc_dictionary[arc]
115 | idx_arcs.append(arc_id)
116 |
117 | # the "tags" are the collapsed unary chains, i.e. FRAG+DT
118 | # at evaluation, we swap the word tag "DT" with the true tag in "stags" (see after)
119 | idx_tags = []
120 | for tag in tags:
121 | tag = precess_arc(tag)
122 | tag_id = self.arc_dictionary.add_word(tag) if is_train else self.arc_dictionary[tag]
123 | idx_tags.append(tag_id)
124 |
125 | assert len(distances_sent) == len(idx) - 1
126 | assert len(arcs) == len(idx) - 1
127 | assert len(idx) == len(word_lexs) + 2
128 |
129 | sens.append(word_lexs)
130 | trees.append(tree)
131 | sens_idx.append(idx)
132 | sens_tag.append(idx_tags)
133 | sens_arc.append(idx_arcs)
134 | distances.append(distances_sent)
135 |
136 | print('\nLabelling POS tags ...')
137 | st_outputs = self.st.tag_sents(sens)
138 | for i, word_tags in enumerate(st_outputs):
139 | if i % 10 == 0:
140 | print("Done %d/%d\r" % (i, len(parse_trees)), end='')
141 | word_tags = [t[1].split('#')[1] for t in word_tags]
142 | stags = ['<unk>'] + list(word_tags) + ['<unk>']
143 |
144 | # the "stags" are the original word tags included in the data files
145 | # we keep track of them so that, during evaluation, we can swap them with the original ones.
146 | idx_stags = []
147 | for stag in stags:
148 | stag_id = self.stag_dictionary.add_word(stag) if is_train else self.stag_dictionary[stag]
149 | idx_stags.append(stag_id)
150 |
151 | sens_stag.append(idx_stags)
152 |
153 | return sens_idx, sens_tag, sens_stag, \
154 | sens_arc, distances, sens, trees
155 |
156 |
157 | if __name__ == '__main__':
158 | import sys
159 | CTBCreator(data_path=sys.argv[1], tagger_path=sys.argv[2])
160 |
--------------------------------------------------------------------------------
/src/datacreate_ptb.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 |
4 | from nltk.tree import Tree
5 |
6 | from helpers import *
7 |
8 |
9 | def load_trees(path, strip_top=True, strip_spmrl_features=True):
10 | trees = []
11 | with open(path) as infile:
12 | for line in infile:
13 | trees.append(Tree.fromstring(line))
14 |
15 | if strip_top:
16 | for i, tree in enumerate(trees):
17 | if tree.label() in ("TOP", "ROOT"):
18 | assert len(tree) == 1
19 | trees[i] = tree[0]
20 | return trees
21 |
22 |
23 | class PTBCreator(object):
24 | '''Data path is assumed to be a directory with
25 | pkl files and a corpora subdirectory.
26 | '''
27 | def __init__(self,
28 | wordembed_dim=300,
29 | embeddingstd=0.1,
30 | data_path=None,
31 | glove_path=None):
32 | assert data_path is not None
33 | dict_filepath = os.path.join(data_path, 'dict.pkl')
34 | data_filepath = os.path.join(data_path, 'parsed.pkl')
35 | train_filepath = os.path.join(data_path, "02-21.10way.clean")
36 | valid_filepath = os.path.join(data_path, "22.auto.clean")
37 | test_filepath = os.path.join(data_path, "23.auto.clean")
38 | embed_filepath = os.path.join(data_path, "ptb_glove.npy") #'../data/ptb/ptb_glove.npy'
39 |
40 | print("building dictionary ...")
41 |
42 | self.dictionary = Dictionary()
43 |
44 | print("loading trees from {}".format(train_filepath))
45 | train_trees = load_trees(train_filepath)
46 | print("loading trees from {}".format(valid_filepath))
47 | valid_trees = load_trees(valid_filepath)
48 | print("loading trees from {}".format(test_filepath))
49 | test_trees = load_trees(test_filepath)
50 |
51 | self.add_words(train_trees)
52 | self.dictionary.rebuild_by_freq()
53 | self.arc_dictionary = Dictionary()
54 | self.stag_dictionary = Dictionary()
55 | self.train = self.preprocess(train_trees, is_train=True)
56 | self.valid = self.preprocess(valid_trees, is_train=False)
57 | self.test = self.preprocess(test_trees, is_train=False)
58 | with open(dict_filepath, "wb") as file_dict:
59 | pickle.dump(self.dictionary, file_dict)
60 | with open(data_filepath, "wb") as file_data:
61 | pickle.dump((self.train, self.arc_dictionary,
62 | self.stag_dictionary), file_data)
63 | pickle.dump(self.valid, file_data)
64 | pickle.dump(self.test, file_data)
65 |
66 |
67 | if glove_path is not None:
68 | maxvocabsize = len(self.dictionary)
69 | print("loading raw GloVe file ...")
70 | wv = {}
71 | vec = open(glove_path, 'r')
72 | for line in vec.readlines():
73 | line = line.split(' ')
74 | wv[line[0]] = numpy.asarray(
75 | [float(x) for x in line[1:]]).astype('float32')
76 | vec.close()
77 |
78 | self.wordembed_matrix = embeddingstd * \
79 | numpy.random.randn(maxvocabsize, wordembed_dim).astype('float32')
80 | key_error = 0
81 | for key, value in enumerate(self.dictionary.idx2word):
82 | try:
83 | self.wordembed_matrix[key] = wv[value]
84 | except KeyError:
85 | key_error += 1
86 | print("Total vocab size: %d, tokens not found in glove: %d" % (
87 | maxvocabsize, key_error))
88 | del wv
89 |
90 | print("dumping augmented word embedding matrix ...")
91 | f_we = open(embed_filepath, 'wb')
92 | numpy.save(f_we, self.wordembed_matrix)
93 | f_we.close()
94 |
95 | print(len(self.arc_dictionary.idx2word))
96 | print(self.arc_dictionary.idx2word)
97 |
98 | def add_words(self, trees):
99 | words, tags = [], []
100 | for tree in trees:
101 | words, tags = zip(*tree.pos())
102 | words = ['<s>'] + list(words) + ['</s>']
103 | for w in words:
104 | self.dictionary.add_word(w)
105 |
106 | def preprocess(self, parse_trees, is_train=False):
107 | sens_idx = []
108 | sens_tag = []
109 | sens_stag = []
110 | sens_arc = []
111 | distances = []
112 | sens = []
113 | trees = []
114 |
115 | print('\nConverting trees ...')
116 | for i, tree in enumerate(parse_trees):
117 | if i % 10 == 0:
118 | print("Done %d/%d\r" % (i, len(parse_trees)), end='')
119 | word_lexs, word_tags = zip(*tree.pos())
120 | idx = []
121 | for word in (['<s>'] + list(word_lexs) + ['</s>']):
122 | idx.append(self.dictionary[word])
123 |
124 | listerized_tree, arcs, tags = tree2list(tree)
125 | stags = ['<unk>'] + list(word_tags) + ['<unk>']
126 | tags = ['<unk>'] + tags + ['<unk>']
127 | arcs = ['<unk>'] + arcs + ['<unk>']
128 |
129 | if type(listerized_tree) is str:
130 | listerized_tree = [listerized_tree]
131 | distances_sent, _ = distance(listerized_tree)
132 | distances_sent = [0] + distances_sent + [0]
133 |
134 | idx_arcs = []
135 | for arc in arcs:
136 | arc = precess_arc(arc)
137 | arc_id = self.arc_dictionary.add_word(arc) if is_train else self.arc_dictionary[arc]
138 | idx_arcs.append(arc_id)
139 |
140 | # the "stags" are the original word tags included in the data files
141 | # we keep track of them so that, during evaluation, we can swap them with the original ones.
142 | idx_stags = []
143 | for stag in stags:
144 | stag_id = self.stag_dictionary.add_word(stag) if is_train else self.stag_dictionary[stag]
145 | idx_stags.append(stag_id)
146 |
147 | # the "tags" are the collapsed unary chains, i.e. FRAG+DT
148 | # at evaluation, we swap the word tag "DT" with the true tag in "stags" (see after)
149 | idx_tags = []
150 | for tag in tags:
151 | tag = precess_arc(tag)
152 | tag_id = self.arc_dictionary.add_word(tag) if is_train else self.arc_dictionary[tag]
153 | idx_tags.append(tag_id)
154 |
155 | assert len(distances_sent) == len(idx) - 1
156 | assert len(arcs) == len(idx) - 1
157 | assert len(idx) == len(word_lexs) + 2
158 | assert len(stags) == len(tags)
159 |
160 | sens.append(word_lexs)
161 | trees.append(tree)
162 | sens_idx.append(idx)
163 | sens_tag.append(idx_tags)
164 | sens_arc.append(idx_arcs)
165 | sens_stag.append(idx_stags)
166 | distances.append(distances_sent)
167 |
168 | return sens_idx, sens_tag, sens_stag, \
169 | sens_arc, distances, sens, trees
170 |
171 |
172 | if __name__ == '__main__':
173 | import sys
174 | PTBCreator(data_path=sys.argv[1], glove_path=sys.argv[2] if len(sys.argv) > 2 else None)
175 |
--------------------------------------------------------------------------------
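
Before the helpers themselves, here is a small illustration of what `PTBCreator.preprocess` extracts from a single tree. This is a sketch, not repository code: it assumes `nltk` is installed, `/src` is on the import path, and the `<s>`/`</s>`/`<empty>` token conventions used above.

```python
# Illustrative sketch of the per-sentence tree -> (list, arcs, tags, distances)
# conversion performed in PTBCreator.preprocess (toy input, not repo code).
import nltk
from helpers import tree2list, distance

t = nltk.Tree.fromstring("(S (NP (DT the) (NN cat)) (VP (VBZ sleeps)))")
listified, arcs, tags = tree2list(t)
# listified == [['the', 'cat'], 'sleeps']
# arcs      == ['NP', 'S']                   one constituent label per gap
# tags      == ['<empty>', '<empty>', 'VP']  unary chain above each POS tag
dists, _ = distance(listified)
# dists == [1, 2]: the gap before 'sleeps' carries the larger distance,
# so the top-level split of the reconstructed tree falls there.
```
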
/src/helpers.py:
--------------------------------------------------------------------------------
1 | import re
2 | import sys
3 | import nltk
4 | import numpy
5 |
6 |
7 | word_tags = ['CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR',
8 | 'JJS', 'LS', 'MD', 'NN', 'NNS', 'NNP', 'NNPS', 'PDT',
9 | 'POS', 'PRP', 'PRP$', 'RB', 'RBR', 'RBS', 'RP',
10 | 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP',
11 | 'VBZ', 'WDT', 'WP', 'WP$', 'WRB']
12 | currency_tags_words = ['#', '$', 'C$', 'A$']
13 | ellipsis = ['*', '*?*', '0', '*T*', '*ICH*', '*U*', '*RNR*', '*EXP*',
14 | '*PPA*', '*NOT*']
15 | punctuation_tags = ['.', ',', ':', '-LRB-', '-RRB-', '\'\'', '``']
16 | punctuation_words = ['.', ',', ':', '-LRB-', '-RRB-', '\'\'', '``',
17 | '--', ';', '-', '?', '!', '...', '-LCB-',
18 | '-RCB-']
19 | delated_tags = ['TOP', '-NONE-', ',', ':', '``', '\'\'']  # labels deleted when scoring (cf. DELETE_LABEL in COLLINS.prm)
20 |
21 |
22 | def precess_arc(label):
23 | labels = label.split('+')
24 | new_arc = []
25 | for l in labels:
26 | if l == 'ADVP':  # EVALB's COLLINS.prm scores ADVP and PRT as the same label
27 | l = 'PRT'
28 | # if len(new_arc) > 0 and l == new_arc[-1]:
29 | # continue
30 | new_arc.append(l)
31 | label = '+'.join(new_arc)
32 | return label
33 |
34 |
35 | def process_NONE(tree):  # drop -NONE- leaves and any constituents left empty
36 | if isinstance(tree, nltk.Tree):
37 | label = tree.label()
38 | if label == '-NONE-':
39 | return None
40 | else:
41 | tr = []
42 | for node in tree:
43 | new_node = process_NONE(node)
44 | if new_node is not None:
45 | tr.append(new_node)
46 | if tr == []:
47 | return None
48 | else:
49 | return nltk.Tree(label, tr)
50 | else:
51 | return tree
52 |
53 |
54 | class Dictionary(object):
55 | def __init__(self):
56 | self.word2idx = {'<unk>': 0}
57 | self.idx2word = ['<unk>']
58 | self.word2frq = {}
59 |
60 | def add_word(self, word):
61 | if word not in self.word2idx:
62 | self.idx2word.append(word)
63 | self.word2idx[word] = len(self.idx2word) - 1
64 | if word not in self.word2frq:
65 | self.word2frq[word] = 1
66 | else:
67 | self.word2frq[word] += 1
68 | return self.word2idx[word]
69 |
70 | def __len__(self):
71 | return len(self.idx2word)
72 |
73 | def __getitem__(self, item):
74 | if item in self.word2idx:
75 | return self.word2idx[item]
76 | else:
77 | return self.word2idx['<unk>']
78 |
79 | def rebuild_by_freq(self, thd=3):
80 | self.word2idx = {'<unk>': 0}
81 | self.idx2word = ['<unk>']
82 |
83 | for k, v in self.word2frq.items():
84 | if v >= thd and k not in self.word2idx:
85 | self.idx2word.append(k)
86 | self.word2idx[k] = len(self.idx2word) - 1
87 |
88 | print('Number of words:', len(self.idx2word))
89 | return len(self.idx2word)
90 |
91 | def class_weight(self):
92 | frq = [self.word2frq[self.idx2word[i]] for i in range(len(self.idx2word))]
93 | frq = numpy.array(frq).astype('float')
94 | weight = numpy.sqrt(frq.max() / frq)  # rarer classes get larger weights
95 | weight = numpy.clip(weight, a_min=0., a_max=5.)
96 | return weight
97 |
98 |
99 | class FScore(object):
100 | def __init__(self, recall, precision, fscore):
101 | self.recall = recall
102 | self.precision = precision
103 | self.fscore = fscore
104 |
105 | def __str__(self):
106 | return "(Recall={:.2f}, Precision={:.2f}, FScore={:.2f})".format(
107 | self.recall, self.precision, self.fscore)
108 |
109 |
110 | def build_nltktree(depth, arc, tag, sen, arcdict, tagdict, stagdict, stags=None):
111 | """stags are the stanford predicted tags present in the train/valid/test files.
112 | """
113 | assert len(sen) > 0
114 | assert len(depth) == len(sen) - 1, ("%s_%s" % (len(depth), len(sen)))
115 | if stags:
116 | assert len(stags) == len(tag)
117 |
118 | if len(sen) == 1:
119 | tag_list = str(tagdict[tag[0]]).split('+')
120 | tag_list.reverse()
121 | # if stags, put the real stanford pos TAG for the word and leave the
122 | # unary chain on top.
123 | if stags is not None:
124 | assert len(stags) > 0
125 | tag_list.insert(0, str(stagdict[stags[0]]))
126 | word = str(sen[0])
127 | for t in tag_list:
128 | word = nltk.Tree(t, [word])
129 | assert isinstance(word, nltk.Tree)
130 | return word
131 | else:
132 | idx = numpy.argmax(depth)  # split the sentence where the syntactic distance is largest
133 | node0 = build_nltktree(
134 | depth[:idx], arc[:idx], tag[:idx + 1], sen[:idx + 1],
135 | arcdict, tagdict, stagdict, stags[:idx + 1] if stags else None)
136 | node1 = build_nltktree(
137 | depth[idx + 1:], arc[idx + 1:], tag[idx + 1:], sen[idx + 1:],
138 | arcdict, tagdict, stagdict, stags[idx + 1:] if stags else None)
139 |
140 | if node0.label() != '<empty>' and node1.label() != '<empty>':  # splice out artificial nodes
141 | tr = [node0, node1]
142 | elif node0.label() == '<empty>' and node1.label() != '<empty>':
143 | tr = [c for c in node0] + [node1]
144 | elif node0.label() != '<empty>' and node1.label() == '<empty>':
145 | tr = [node0] + [c for c in node1]
146 | elif node0.label() == '<empty>' and node1.label() == '<empty>':
147 | tr = [c for c in node0] + [c for c in node1]
148 |
149 | arc_list = str(arcdict[arc[idx]]).split('+')
150 | arc_list.reverse()
151 | for a in arc_list:
152 | if isinstance(tr, nltk.Tree):
153 | tr = [tr]
154 | tr = nltk.Tree(a, tr)
155 |
156 | return tr
157 |
158 |
159 | def MRG(tr):
160 | if isinstance(tr, str):
161 | return '( %s )' % tr
162 | # return tr + ' '
163 | else:
164 | s = '('
165 | for subtr in tr:
166 | s += MRG(subtr) + ' '
167 | s += ')'
168 | return s
169 |
170 |
171 | def get_brackets(tree, start_idx=0, root=False):
172 | assert isinstance(tree, nltk.Tree)
173 | label = tree.label()
174 | label = label.replace('ADVP', 'PRT')
175 |
176 | brackets = set()
177 | if isinstance(tree[0], nltk.Tree):
178 | end_idx = start_idx
179 | for node in tree:
180 | node_brac, next_idx = get_brackets(node, end_idx)
181 | brackets.update(node_brac)
182 | end_idx = next_idx
183 | if not root:
184 | brackets.add((start_idx, end_idx, label))
185 | else:
186 | end_idx = start_idx + 1
187 |
188 | return brackets, end_idx
189 |
190 |
191 | def normalize(x):
192 | return x / (sum(x) + 1e-8)
193 |
194 |
195 | def tree2list(tree, parent_arc=[]):
196 | if isinstance(tree, nltk.Tree):
197 | label = tree.label()
198 | if isinstance(tree[0], nltk.Tree):
199 | label = re.split('-|=', tree.label())[0]
200 | root_arc_list = parent_arc + [label]
201 | root_arc = '+'.join(root_arc_list)
202 | if len(tree) == 1:
203 | root, arc, tag = tree2list(tree[0], parent_arc=root_arc_list)
204 | elif len(tree) == 2:
205 | c0, arc0, tag0 = tree2list(tree[0])
206 | c1, arc1, tag1 = tree2list(tree[1])
207 | root = [c0, c1]
208 | arc = arc0 + [root_arc] + arc1
209 | tag = tag0 + tag1
210 | else:
211 | c0, arc0, tag0 = tree2list(tree[0])
212 | c1, arc1, tag1 = tree2list(nltk.Tree('<empty>', tree[1:]))
213 | # `bin` here was the Python builtin, so the original `if bin == 0:`
214 | # branch could never fire; n-ary nodes are therefore always
215 | # binarized via the artificial '<empty>' subtree:
216 | root = [c0, c1]
217 | arc = arc0 + [root_arc] + arc1
218 | tag = tag0 + tag1
219 | return root, arc, tag
220 | else:
221 | if len(parent_arc) == 1:
222 | parent_arc.insert(0, '<empty>')
223 | # parent_arc[-1] = '<empty>'
224 | del parent_arc[-1]  # drop the POS tag; what remains is the unary chain above it
225 | return str(tree), [], ['+'.join(parent_arc)]
226 |
227 |
228 | def distance(root):
229 | if isinstance(root, list):
230 | dist_list = []
231 | depth_list = []
232 | for child in root:
233 | dist, depth = distance(child)
234 | dist_list.append(dist)
235 | depth_list.append(depth)
236 |
237 | max_depth = max(depth_list)
238 |
239 | out = dist_list[0]
240 | for dist in dist_list[1:]:
241 | out.append(max_depth)  # a gap between two children gets the height of this subtree
242 | out.extend(dist)
243 | return out, max_depth + 1
244 | else:
245 | return [], 1
246 |
--------------------------------------------------------------------------------
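
`build_nltktree` only ever indexes into the `idx2word` tables it is handed, so the encode/decode round trip can be sketched with plain lists standing in for the real `Dictionary` objects. Again an illustrative sketch under the same assumptions as above, not repository code:

```python
# Toy round trip: encode with tree2list/distance, decode with build_nltktree.
import nltk
from helpers import tree2list, distance, build_nltktree

t = nltk.Tree.fromstring("(S (NP (DT the) (NN cat)) (VP (VBZ sleeps)))")
listified, arcs, tags = tree2list(t)
dists, _ = distance(listified)

rebuilt = build_nltktree(
    dists,                      # one distance per gap between adjacent words
    list(range(len(arcs))),     # arc "ids" are just positions here, ...
    list(range(len(tags))),     # ... and so are tag "ids"
    ['the', 'cat', 'sleeps'],
    arcs, tags, None)           # the raw label lists double as idx2word tables
print(rebuilt)                  # (S (NP the cat) (VP sleeps))
```

Note that the POS layer disappears because `stags=None` here; at evaluation time dp.py passes the gold stags, so the POS tags are re-inserted under the unary chains.
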
/src/dp.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | import os
4 | import random
5 |
6 | import torch.nn as nn
7 | from torch.optim.lr_scheduler import ReduceLROnPlateau
8 |
9 | from dataloader import PTBLoader
10 | from helpers import *
11 | from loss import *
12 | from model import distance_parser
13 |
14 |
15 | def get_args():
16 | parser = argparse.ArgumentParser(
17 | description='Syntactic distance based neural parser')
18 | parser.add_argument('--epc', type=int, default=100)
19 | parser.add_argument('--lr', type=float, default=.001)
20 | parser.add_argument('--bthsz', type=int, default=20)
21 | parser.add_argument('--hidsz', type=int, default=1200)
22 | parser.add_argument('--embedsz', type=int, default=400)
23 | parser.add_argument('--window_size', type=int, default=2)
24 | parser.add_argument('--dpout', type=float, default=0.3)
25 | parser.add_argument('--dpoute', type=float, default=0.1)
26 | parser.add_argument('--dpoutr', type=float, default=0.2)
27 | parser.add_argument('--seed', type=int, default=1234)
28 | parser.add_argument('--weight_decay', type=float, default=1e-6)
29 | parser.add_argument('--use_glove', action='store_true')
30 | parser.add_argument('--logfre', type=int, default=200)
31 | parser.add_argument('--devfre', type=int, default=-1)
32 | parser.add_argument('--cuda', action='store_true', dest='cuda')
33 | parser.add_argument('--datapath', type=str, default='../data/ptb')
34 | parser.add_argument('--savepath', type=str, default='../results')
35 | parser.add_argument('--filename', type=str, default=None)
36 |
37 | args = parser.parse_args()
38 | # set seed and return args
39 | random.seed(args.seed)
40 | torch.random.manual_seed(args.seed)
41 | if args.cuda and torch.cuda.is_available():
42 | torch.cuda.random.manual_seed(args.seed)
43 | return args
44 |
45 |
46 | def evaluate(model, data, mode='valid'):
47 | import tempfile
48 | model.eval()
49 | if mode == 'valid':
50 | idxs, tags, stags, arcs, dsts = data.batchify(mode, 1)
51 | _, _, _, _, _, sents, trees = data.valid
52 | elif mode == 'test':
53 | idxs, tags, stags, arcs, dsts = data.batchify(mode, 1)
54 | _, _, _, _, _, sents, trees = data.test
55 |
56 | temp_path = tempfile.TemporaryDirectory(prefix="evalb-")
57 | temp_file_path = os.path.join(temp_path.name, "pred_trees.txt")
58 | temp_targ_path = os.path.join(temp_path.name, "true_trees.txt")
59 | temp_eval_path = os.path.join(temp_path.name, "evals.txt")
60 |
61 | print("Temp: {}, {}".format(temp_file_path, temp_targ_path))
62 | temp_tree_file = open(temp_file_path, "w")
63 | temp_targ_file = open(temp_targ_path, "w")
64 |
65 | set_loss = 0.0
66 | set_counter = 0
67 | set_arc_prec = 0.0
68 | arc_counter = 0
69 | set_tag_prec = 0.0
70 | tag_counter = 0
71 | for idx, tag, stag, arc, dst, sent, tree in zip(
72 | idxs, tags, stags, arcs, dsts, sents, trees):
73 |
74 | if args.cuda:
75 | idx = idx.cuda()
76 | tag = tag.cuda()
77 | stag = stag.cuda()
78 | arc = arc.cuda()
79 | dst = dst.cuda()
80 |
81 | mask = (idx >= 0).float()
82 | idx = idx * mask.long()
83 | dstmask = (dst > 0).float()
84 | pred_dst, pred_arc, pred_tag = model(idx, stag, mask)
85 |
86 | loss = rankloss(pred_dst, dst, dstmask)
87 | set_loss += loss.item()
88 | set_counter += 1
89 |
90 | _, pred_arc_idx = torch.max(pred_arc, dim=-1)
91 | arc_prec = ((arc == pred_arc_idx).float() * dstmask).sum()
92 | set_arc_prec += arc_prec.item()
93 | arc_counter += dstmask.sum().item()
94 |
95 | _, pred_tag_idx = torch.max(pred_tag, dim=-1)
96 | pred_tag_idx[0], pred_tag_idx[-1] = -1, -1  # exclude the two boundary positions from tag accuracy
97 | tag_prec = (tag == pred_tag_idx).float().sum()
98 | set_tag_prec += tag_prec.item()
99 | tag_counter += (tag != 0).float().sum().item()
100 |
101 | pred_tree = build_nltktree(
102 | pred_dst.data.squeeze().cpu().numpy().tolist()[1:-1],
103 | pred_arc_idx.data.squeeze().cpu().numpy().tolist()[1:-1],
104 | pred_tag_idx.data.squeeze().cpu().numpy().tolist()[1:-1],
105 | sent,
106 | ptb_parsed.arc_dictionary.idx2word,
107 | ptb_parsed.arc_dictionary.idx2word,  # collapsed tags live in the arc dictionary too
108 | ptb_parsed.stag_dictionary.idx2word,
109 | stags=stag.data.squeeze().cpu().numpy().tolist()[1:-1]
110 | )
111 |
112 | def process_str_tree(str_tree):
113 | return re.sub(r'[ \n]+', ' ', str_tree)
114 |
115 | temp_tree_file.write(process_str_tree(str(pred_tree)) + '\n')
116 | temp_targ_file.write(process_str_tree(str(tree)) + '\n')
117 |
118 | # execute the evalb command:
119 | temp_tree_file.close()
120 | temp_targ_file.close()
121 |
122 | evalb_dir = os.path.join(os.getcwd(), "../EVALB")
123 | evalb_param_path = os.path.join(evalb_dir, "COLLINS.prm")
124 | evalb_program_path = os.path.join(evalb_dir, "evalb")
125 | command = "{} -p {} {} {} > {}".format(
126 | evalb_program_path,
127 | evalb_param_path,
128 | temp_targ_path,
129 | temp_file_path,
130 | temp_eval_path)
131 |
132 | import subprocess
133 | subprocess.run(command, shell=True)
134 | fscore = FScore(math.nan, math.nan, math.nan)
135 |
136 | with open(temp_eval_path) as infile:
137 | for line in infile:
138 | match = re.match(r"Bracketing Recall\s+=\s+(\d+\.\d+)", line)
139 | if match:
140 | fscore.recall = float(match.group(1))
141 | match = re.match(r"Bracketing Precision\s+=\s+(\d+\.\d+)", line)
142 | if match:
143 | fscore.precision = float(match.group(1))
144 | match = re.match(r"Bracketing FMeasure\s+=\s+(\d+\.\d+)", line)
145 | if match:
146 | fscore.fscore = float(match.group(1))
147 | break
148 |
149 | temp_path.cleanup()
150 |
151 | set_loss /= set_counter
152 | set_arc_prec /= arc_counter
153 | set_tag_prec /= tag_counter
154 |
155 | model.train()
156 |
157 | return (set_loss, set_arc_prec, set_tag_prec,
158 | fscore.precision, fscore.recall, fscore.fscore)
159 |
160 |
161 | args = get_args()
162 |
163 | if args.filename is None:
164 | filename = sorted(str(args)[10:-1].split(', '))  # strip the "Namespace(...)" wrapper
165 | filename = [i for i in filename if ('dir' not in i) and
166 | ('tblog' not in i) and
167 | ('fre' not in i) and
168 | ('cuda' not in i) and
169 | ('nlookback' not in i)]
170 | filename = __file__.split('.')[0].split('/')[-1] + '_' + \
171 | '_'.join(filename).replace('=', '') \
172 | .replace('/', '') \
173 | .replace('\'', '') \
174 | .replace('..', '') \
175 | .replace('\"', '')
176 | else:
177 | filename = args.filename
178 |
179 | if not os.path.isdir(args.savepath):
180 | os.mkdir(args.savepath)
181 | parameter_filepath = os.path.join(args.savepath, filename + '.th')
182 | print('model path:', parameter_filepath)
183 |
184 | print(args)
185 | print("loading data ...")
186 | ptb_parsed = PTBLoader(data_path=args.datapath, use_glove=args.use_glove)
187 |
188 | wordembed = ptb_parsed.wordembed_matrix
189 | args.vocab_size = len(ptb_parsed.dictionary)
190 |
191 | train_log_template = 'epoch {:<3d} batch {:<4d} loss {:<.6f} rank {:<.6f} arc {:<.6f} tag {:<.6f}'
192 | valid_log_template = \
193 | '*** epoch {:<3d} \tloss \tarc prec \ttag prec \tprecision\trecall \tlf1 \n' \
194 | '{:10}DEV\t{:<.6f}\t{:<.6f}\t{:<.6f}\t{:<.6f}\t{:<.6f}\t{:<.6f}\n' \
195 | '{:10}TEST\t{:<.6f}\t{:<.6f}\t{:<.6f}\t{:<.6f}\t{:<.6f}\t{:<.6f}'
196 |
197 | if __name__ == '__main__':
198 | print("building model...")
199 | model = distance_parser(vocab_size=args.vocab_size,
200 | embed_size=args.embedsz,
201 | hid_size=args.hidsz,
202 | arc_size=len(ptb_parsed.arc_dictionary),
203 | stag_size=len(ptb_parsed.stag_dictionary),
204 | window_size=args.window_size,
205 | dropout=args.dpout,
206 | dropoute=args.dpoute,
207 | dropoutr=args.dpoutr,
208 | wordembed=wordembed)
209 | if args.cuda:
210 | model.cuda()
211 |
212 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0, 0.999),
213 | weight_decay=args.weight_decay)
214 | scheduler = ReduceLROnPlateau(optimizer, mode='max', patience=2, factor=0.5, min_lr=0.000001)
215 |
216 | print(" ")
217 | numparams = sum([numpy.prod(i.size()) for i in model.parameters()])
218 | print("Number of params: {0}\n{1:35}{2:35}Size".format(
219 | numparams, 'Name', 'Shape')) # this includes tied parameters
220 | print("---------------------------------------------------------------------------")
221 | for item in model.state_dict().keys():
222 | this_param = model.state_dict()[item]
223 | print("{:60}{!s:35}{}".format(
224 | item, this_param.size(), numpy.prod(this_param.size())))
225 | print(" ")
226 |
227 | # set up initial training state (best F1 so far and starting epoch)
228 | best_valid_f1 = 0.0
229 | start_epoch = 0
230 |
231 | print("training ...")
232 |
233 | train_idxs, train_tags, train_stags, \
234 | train_arcs, train_distances, \
235 | train_sents, train_trees = ptb_parsed.batchify('train', args.bthsz)
236 | if args.devfre == -1:
237 | args.devfre = len(train_idxs)
238 |
239 | for epoch in range(start_epoch, start_epoch + args.epc):
240 | inds = list(range(len(train_idxs)))
241 | random.shuffle(inds)
242 | epc_train_idxs = [train_idxs[i] for i in inds]
243 | epc_train_tags = [train_tags[i] for i in inds]
244 | epc_train_stags = [train_stags[i] for i in inds]
245 | epc_train_arcs = [train_arcs[i] for i in inds]
246 | epc_train_distances = [train_distances[i] for i in inds]
247 |
248 | for ibatch, (idx, tag, stag, arc, dst) in \
249 | enumerate(
250 | zip(
251 | epc_train_idxs,
252 | epc_train_tags,
253 | epc_train_stags,
254 | epc_train_arcs,
255 | epc_train_distances,
256 | )):
257 |
258 | if args.cuda:
259 | idx = idx.cuda()
260 | tag = tag.cuda()
261 | stag = stag.cuda()
262 | arc = arc.cuda()
263 | dst = dst.cuda()
264 |
265 | mask = (idx >= 0).float()
266 | idx = idx * mask.long()
267 | dstmask = (dst > 0).float()
268 |
269 | optimizer.zero_grad()
270 | pred_dst, pred_arc, pred_tag = model(idx, stag, mask)
271 | loss_rank = rankloss(pred_dst, dst, dstmask)
272 | loss_arc = arcloss(pred_arc, arc.view(-1))
273 | loss_tag = tagloss(pred_tag, tag.view(-1))
274 |
275 | loss = loss_rank + loss_arc + loss_tag
276 | loss.backward()
277 |
278 | nn.utils.clip_grad_norm_(model.parameters(), 1.)
279 | optimizer.step()
280 |
281 | if (ibatch + 1) % args.logfre == 0:
282 | print(train_log_template.format(epoch, ibatch + 1, loss.item(),
283 | loss_rank.item(), loss_arc.item(),
284 | loss_tag.item()))
285 |
286 | #####
287 |
288 | print("Evaluating valid... ")
289 | valid_loss, valid_arc_prec, valid_tag_prec, \
290 | valid_precision, valid_recall, valid_f1 = evaluate(model, ptb_parsed, 'valid')
291 | print("Evaluating test... ")
292 | test_loss, test_arc_prec, test_tag_prec, \
293 | test_precision, test_recall, test_f1 = evaluate(model, ptb_parsed, 'test')
294 | print(valid_log_template.format(
295 | epoch,
296 | ' ', valid_loss, valid_arc_prec, valid_tag_prec,
297 | valid_precision, valid_recall, valid_f1,
298 | ' ', test_loss, test_arc_prec, test_tag_prec,
299 | test_precision, test_recall, test_f1))
300 |
301 | if valid_f1 > best_valid_f1:
302 | best_valid_f1 = valid_f1
303 | torch.save({
304 | 'epoch': epoch,
305 | 'valid_loss': valid_loss,
306 | 'valid_precision': valid_precision,
307 | 'valid_recall': valid_recall,
308 | 'valid_f1': valid_f1,
309 | 'model_state_dict': model.state_dict(),
310 | 'optimizer': optimizer.state_dict(),
311 | }, parameter_filepath)
312 |
313 | scheduler.step(valid_f1)
314 |
--------------------------------------------------------------------------------
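
/src/loss.py is not reproduced in this listing; the training loop above only assumes that `rankloss(pred_dst, dst, dstmask)` scores predicted against gold distances wherever the mask is 1. Purely as a hypothetical illustration of that interface (not the repository's actual implementation), a pairwise hinge ranking loss over distances could look like:

```python
import torch

def pairwise_rank_loss(pred, gold, mask, margin=1.0):
    """Hypothetical stand-in for rankloss() in /src/loss.py (not shown here).

    Encourages `pred` to preserve the relative ordering of the gold
    syntactic distances within each sentence. All tensors are (batch, n);
    `mask` is 1 for real gaps and 0 for padding, like dstmask in dp.py.
    """
    d_pred = pred.unsqueeze(-1) - pred.unsqueeze(-2)  # (batch, n, n) pair diffs
    d_gold = gold.unsqueeze(-1) - gold.unsqueeze(-2)
    valid = mask.unsqueeze(-1) * mask.unsqueeze(-2)   # both gaps are unpadded
    sign = torch.sign(d_gold)                         # desired ordering per pair
    hinge = torch.relu(margin - sign * d_pred)        # penalize wrong or weak order
    loss = hinge * (sign != 0).float() * valid        # ignore ties and padding
    return loss.sum() / valid.sum().clamp(min=1.0)
```
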
/EVALB/README:
--------------------------------------------------------------------------------
1 | #################################################################
2 | # #
3 | # Bug fix and additional functionality for evalb #
4 | # #
5 | # This updated version of evalb fixes a bug in which sentences #
6 | # were incorrectly categorized as "length mismatch" when the #
7 | # parse output had certain mislabeled parts-of-speech. #
8 | # #
9 | # The bug was the result of evalb treating one of the tags (in #
10 | # gold or test) as a label to be deleted (see sections [6],[7] #
11 | # for details), but not the corresponding tag in the other. #
12 | # This most often occurs with punctuation. See the subdir #
13 | # "bug" for an example gld and tst file demonstating the bug, #
14 | # as well as output of evalb with and without the bug fix. #
15 | # #
16 | # For the present version in case of length mismatch, the nodes #
17 | # causing the imbalance are reinserted to resolve the miscount. #
18 | # If the lengths of gold and test truly differ, the error is #
19 | # still reported. The parameter file "new.prm" (derived from #
20 | # COLLINS.prm) shows how to add new potential mislabelings for #
21 | # quotes (",``,',`). #
22 | # #
23 | # I have preserved DJB's revision for modern compilers except #
24 | # for the declaration of "exit" which is provided by stdlib. #
25 | # #
26 | # Other changes: #
27 | # #
28 | # * output of F-Measure in addition to precision and recall #
29 | # (I did not update the documentation in section [4] for this) #
30 | # #
31 | # * more comprehensive DEBUG output that includes bracketing #
32 | # information as evalb is processing each sentence #
33 | # (useful in working through this, and perhaps other bugs). #
34 | # Use either the "-D" run-time switch or set DEBUG to 2 in #
35 | # the parameter file. #
36 | # #
37 | # * added DELETE_LABEL lines in new.prm for S1 nodes produced #
38 | # by the Charniak parser and "?", "!" punctuation produced by #
39 | # the Bikel parser. #
40 | # #
41 | # #
42 | # David Ellis (Brown) #
43 | # #
44 | # January.2006 #
45 | #################################################################
46 |
47 | #################################################################
48 | # #
49 | # Update of evalb for modern compilers #
50 | # #
51 | # This is an updated version of evalb, for use with modern C #
52 | # compilers. There are a few updates, each marked in the code: #
53 | # #
54 | # /* DJB: explanation of comment */ #
55 | # #
56 | # The updates are purely to help compilation with recent #
57 | # versions of GCC (and other C compilers). There are *NO* other #
58 | # changes to the algorithm itself. #
59 | # #
60 | # I have made these changes following recommendations from #
61 | # users of the Corpora Mailing List, especially Peet Morris and #
62 | # Ramon Ziai. #
63 | # #
64 | # David Brooks (Birmingham) #
65 | # #
66 | # September.2005 #
67 | #################################################################
68 |
69 | #################################################################
70 | # #
71 | # README file for evalb #
72 | # #
73 | # Satoshi Sekine (NYU) #
74 | # Mike Collins (UPenn) #
75 | # #
76 | # October.1997 #
77 | #################################################################
78 |
79 | Contents of this README:
80 |
81 | [0] COPYRIGHT
82 | [1] INTRODUCTION
83 | [2] INSTALLATION AND RUN
84 | [3] OPTIONS
85 | [4] OUTPUT FORMAT FROM THE SCORER
86 | [5] HOW TO CREATE A GOLDFILE FROM THE TREEBANK
87 | [6] THE PARAMETER FILE
88 | [7] MORE DETAILS ABOUT THE SCORING ALGORITHM
89 |
90 |
91 | [0] COPYRIGHT
92 |
93 | The authors abandon the copyright of this program. Everyone is
94 | permitted to copy and distribute the program or a portion of the program
95 | with no charge and no restrictions unless it is harmful to someone.
96 |
97 | However, the authors would be grateful if users apply the program
98 | properly and let them know about any bugs or problems.
99 |
100 | This software is provided "AS IS", and the authors make no warranties,
101 | express or implied.
102 |
103 | To legally enforce the abandonment of copyright, this package is released
104 | under the Unlicense (see LICENSE).
105 |
106 | [1] INTRODUCTION
107 |
108 | Evaluation of bracketing looks simple, but in fact, there are minor
109 | differences from system to system. This is a program to parameterize
110 | such minor differences and to give an informative result.
111 |
112 | "evalb" evaluates bracketing accuracy in a test-file against a gold-file.
113 | It returns recall, precision, and tagging accuracy. It uses an identical
114 | algorithm to that used in (Collins ACL97).
115 |
116 |
117 | [2] Installation and Run
118 |
119 | To compile the scorer, type
120 |
121 | > make
122 |
123 |
124 | To run the scorer:
125 |
126 | > evalb -p Parameter_file Gold_file Test_file
127 |
128 |
129 | For example to use the sample files:
130 |
131 | > evalb -p sample.prm sample.gld sample.tst
132 |
133 |
134 |
135 | [3] OPTIONS
136 |
137 | You can specify system parameters in the command line options.
138 | Other options concerning the evaluation metrics should be specified
139 | in the parameter file, described later.
140 |
141 | -p param_file parameter file
142 | -d debug mode
143 | -e n number of errors to tolerate before aborting (default=10)
144 | -h help
145 |
146 |
147 |
148 | [4] OUTPUT FORMAT FROM THE SCORER
149 |
150 | The scorer gives individual scores for each sentence, for
151 | example:
152 |
153 | Sent. Matched Bracket Cross Correct Tag
154 | ID Len. Stat. Recal Prec. Bracket gold test Bracket Words Tags Accracy
155 | ============================================================================
156 | 1 8 0 100.00 100.00 5 5 5 0 6 5 83.33
157 |
158 | At the end of the output the === Summary === section gives statistics
159 | for all sentences, and for sentences <=40 words in length. The summary
160 | contains the following information:
161 |
162 | i) Number of sentences -- total number of sentences.
163 |
164 | ii) Number of Error/Skip sentences -- should both be 0 if there is no
165 | problem with the parsed/gold files.
166 |
167 | iii) Number of valid sentences = Number of sentences - Number of Error/Skip
168 | sentences
169 |
170 | iv) Bracketing recall = (number of correct constituents)
171 | ----------------------------------------
172 | (number of constituents in the goldfile)
173 |
174 | v) Bracketing precision = (number of correct constituents)
175 | ----------------------------------------
176 | (number of constituents in the parsed file)
177 |
178 | vi) Complete match = percentage of sentences where recall and precision are
179 | both 100%.
180 |
181 | vii) Average crossing = (number of constituents crossing a goldfile constituent)
182 | ----------------------------------------------------
183 | (number of sentences)
184 |
185 | viii) No crossing = percentage of sentences which have 0 crossing brackets.
186 |
187 | ix) 2 or less crossing = percentage of sentences which have <=2 crossing brackets.
188 |
189 | x) Tagging accuracy = percentage of correct POS tags (but see [5].3 for exact
190 | details of what is counted).
191 |
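As a worked example: if the goldfile contains 12 constituents, the parsed
file contains 11, and 10 of those match in both span and label, then
bracketing recall = 10/12 = 83.33% and bracketing precision = 10/11 = 90.91%.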
192 |
193 |
194 | [5] HOW TO CREATE A GOLDFILE FROM THE PENN TREEBANK
195 |
196 |
197 | The gold and parsed files are in a format similar to this:
198 |
199 | (TOP (S (INTJ (RB No)) (, ,) (NP (PRP it)) (VP (VBD was) (RB n't) (NP (NNP Black) (NNP Monday))) (. .)))
200 |
201 | To create a gold file from the treebank:
202 |
203 | tgrep -wn '/.*/' | tgrep_proc.prl
204 |
205 | will produce a goldfile in the required format. ("tgrep -wn '/.*/'" prints
206 | parse trees, "tgrep_proc.prl" just skips blank lines).
207 |
208 | For example, to produce a goldfile for section 23 of the treebank:
209 |
210 | tgrep -wn '/.*/' | tail +90895 | tgrep_proc.prl | sed 2416q > sec23.gold
211 |
212 |
213 |
214 | [6] THE PARAMETER (.prm) FILE
215 |
216 |
217 | The .prm file sets options regarding the scoring method. COLLINS.prm gives
218 | the same scoring behaviour as the scorer used in (Collins 97). The options
219 | chosen were:
220 |
221 | 1) LABELED 1
222 |
223 | to give labelled precision/recall figures, i.e. a constituent must have the
224 | same span *and* label as a constituent in the goldfile.
225 |
226 | 2) DELETE_LABEL TOP
227 |
228 | Don't count the "TOP" label (which is always given in the output of tgrep)
229 | when scoring.
230 |
231 | 3) DELETE_LABEL -NONE-
232 |
233 | Remove traces (and all constituents which dominate nothing but traces) when
234 | scoring. For example
235 |
236 | .... (VP (VBD reported) (SBAR (-NONE- 0) (S (-NONE- *T*-1)))) (. .)))
237 |
238 | would be processed to give
239 |
240 | .... (VP (VBD reported)) (. .)))
241 |
242 |
243 | 4)
244 | DELETE_LABEL , -- for the purposes of scoring remove punctuation
245 | DELETE_LABEL :
246 | DELETE_LABEL ``
247 | DELETE_LABEL ''
248 | DELETE_LABEL .
249 |
250 | 5) DELETE_LABEL_FOR_LENGTH -NONE- -- don't include traces when calculating
251 | the length of a sentence (important
252 | when classifying a sentence as <=40
253 | words or >40 words)
254 |
255 | 6) EQ_LABEL ADVP PRT
256 |
257 | Count ADVP and PRT as being the same label when scoring.
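
Putting these options together, a minimal parameter file in the same
format as COLLINS.prm could contain:

   LABELED 1
   DELETE_LABEL TOP
   DELETE_LABEL -NONE-
   DELETE_LABEL .
   DELETE_LABEL ,
   DELETE_LABEL_FOR_LENGTH -NONE-
   EQ_LABEL ADVP PRT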
258 |
259 |
260 |
261 |
262 | [7] MORE DETAILS ABOUT THE SCORING ALGORITHM
263 |
264 |
265 | 1) The scorer initially processes the files to remove all nodes specified
266 | by DELETE_LABEL in the .prm file. It also recursively removes nodes which
267 | dominate nothing due to all their children being removed. For example, if
268 | -NONE- is specified as a label to be deleted,
269 |
270 | .... (VP (VBD reported) (SBAR (-NONE- 0) (S (-NONE- *T*-1)))) (. .)))
271 |
272 | would be processed to give
273 |
274 | .... (VP (VBD reported)) (. .)))
275 |
276 | 2) The scorer also removes all functional tags attached to non-terminals
277 | (functional tags are prefixed with "-" or "=" in the treebank). For example
278 | "NP-SBJ" is processed to give "NP", "NP=2" is changed to "NP".
279 |
280 |
281 | 3) Tagging accuracy counts tags for all words *except* any tags which are
282 | deleted by a DELETE_LABEL specification in the .prm file. (For example, for
283 | COLLINS.prm, punctuation tagged as "," ":" etc. would not be included).
284 |
285 | 4) When calculating the length of a sentence, all words with POS tags not
286 | included in the "DELETE_LABEL_FOR_LENGTH" list in the .prm file are
287 | counted. (For COLLINS.prm, only "-NONE-" is specified in this list, so
288 | traces are removed before calculating the length of the sentence).
289 |
290 | 5) There are some subtleties in scoring when either the goldfile or parsed
291 | file contains multiple constituents for the same span which have the same
292 | non-terminal label, e.g. (NP (NP the man)). If the goldfile contains n
293 | constituents for the same span, and the parsed file contains m constituents
294 | with that nonterminal, the scorer works as follows:
295 |
296 | i) If m>n, then the precision is n/m, recall is 100%
297 |
298 | ii) If n>m, then the precision is 100%, recall is m/n.
299 |
300 | iii) If n==m, recall and precision are both 100%.
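
For example, if the goldfile contains (NP (NP the man)), i.e. two NP
constituents over the same span (n=2), and the parsed file contains only
(NP the man) (m=1), then for those brackets precision is 100% and recall
is 1/2 = 50%.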
301 |
--------------------------------------------------------------------------------
/EVALB/evalb.c:
--------------------------------------------------------------------------------
1 | /*****************************************************************/
2 | /* evalb [-p param_file] [-dh] [-e n] gold-file test-file */
3 | /* */
4 | /* Evaluate bracketing in test-file against gold-file. */
5 | /* Return recall, precision, tagging accuracy. */
6 | /* */
7 | /*