├── .idea
│   ├── encodings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── neural-chinese-address-parsing.iml
│   └── vcs.xml
├── README.md
├── conlleval.pl
├── data
│   ├── anno
│   │   ├── anno-cn.md
│   │   └── anno-en.md
│   ├── dev.txt
│   ├── giga.vec100
│   ├── labels.txt
│   ├── test.txt
│   └── train.txt
├── exp_dytree.sh
├── giga.emb
├── log.jpg
└── src
    ├── evaluate.py
    ├── latent.py
    ├── latenttrees.py
    ├── main_dyRBT.py
    ├── parse.py
    ├── trees.py
    ├── util.py
    └── vocabulary.py
/README.md:
--------------------------------------------------------------------------------
1 | # Neural Chinese Address Parsing
2 | This page contains the code used in the work ["Neural Chinese Address Parsing"](http://statnlp.org/research/sp) published at [NAACL 2019](https://naacl2019.org/program/accepted/).
3 |
4 |
5 | ## Contents
6 | 1. [Usage](#usage)
7 | 2. [SourceCode](#sourcecode)
8 | 3. [Data](#data)
9 | 4. [Citation](#citation)
10 | 5. [Credits](#credits)
11 |
12 |
13 | ## Usage
14 |
15 | Prerequisites: Python (3.5 or later), DyNet (2.0 or later)
16 |
17 | Run the following command to try out the APLT(sp=7) model in the paper.
18 | ```sh
19 | ./exp_dytree.sh
20 | ```
21 | After training completes, run the following command to display the results on the test data. The performance output by conlleval.pl is shown below.
22 | ```sh
23 | perl conlleval.pl < addr_dytree_giga_0.4_200_1_chardyRBTC_dytree_1_houseno_0_0.test.txt
24 | ```
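The file fed to conlleval.pl is written by src/evaluate.py with one character per line: the character, the gold tag, and the predicted tag, separated by tabs (this copy of conlleval.pl is configured with a tab delimiter), and a blank line between addresses. A hypothetical excerpt, with made-up tags for illustration:
```
浙	B-prov	B-prov
江	I-prov	I-prov
省	I-prov	I-prov
杭	B-city	B-city
州	I-city	I-city
市	I-city	I-city
```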
25 |
26 | 
27 |
28 | ## SourceCode
29 |
30 | The source code, which can be found under the "src" folder, is written in Python using the DyNet library.
31 |
32 |
33 | ## Data
34 |
35 | The **data** is stored in the "data" folder, which contains "train.txt", "dev.txt" and "test.txt". The embedding file "giga.vec100" is also located in that folder.
36 |
37 | **The annotation guidelines** are in the folder ["data/anno"](https://github.com/leodotnet/neural-chinese-address-parsing/blob/master/data/anno). Both [Chinese](https://github.com/leodotnet/neural-chinese-address-parsing/blob/master/data/anno/anno-cn.md) and [English](https://github.com/leodotnet/neural-chinese-address-parsing/blob/master/data/anno/anno-en.md) versions are available.
38 |
39 | ## Citation
40 | If you use this software for research, please cite our paper as follows:
41 |
42 | ```
43 | @InProceedings{chineseaddressparsing19li,
44 | author = "Li, Hao and Lu, Wei and Xie, Pengjun and Li, Linlin",
45 | title = "Neural Chinese Address Parsing",
46 | booktitle = "Proc. of NAACL",
47 | year = "2019",
48 | }
49 | ```
50 |
51 |
52 | ## Credits
53 | The code in this repository is based on https://github.com/mitchellstern/minimal-span-parser
54 |
55 | Email [hao_li@mymail.sutd.edu.sg](mailto:hao_li@mymail.sutd.edu.sg) for any inquiries.
56 |
--------------------------------------------------------------------------------
/conlleval.pl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/perl -w
2 | # conlleval: evaluate result of processing CoNLL-2000 shared task
3 | # usage: conlleval [-l] [-r] [-d delimiterTag] [-o oTag] < file
4 | # README: http://cnts.uia.ac.be/conll2000/chunking/output.html
5 | # options: l: generate LaTeX output for tables like in
6 | # http://cnts.uia.ac.be/conll2003/ner/example.tex
7 | # r: accept raw result tags (without B- and I- prefix;
8 | # assumes one word per chunk)
9 | # d: alternative delimiter tag (default is single space)
10 | # o: alternative outside tag (default is O)
11 | # note: the file should contain lines with items separated
12 | # by $delimiter characters (default space). The final
13 | # two items should contain the correct tag and the
14 | # guessed tag in that order. Sentences should be
15 | # separated from each other by empty lines or lines
16 | # with $boundary fields (default -X-).
17 | # url: http://lcg-www.uia.ac.be/conll2000/chunking/
18 | # started: 1998-09-25
19 | # version: 2004-01-26
20 | # author: Erik Tjong Kim Sang
21 |
22 | use strict;
23 |
24 | my $false = 0;
25 | my $true = 42;
26 |
27 | my $boundary = "-X-"; # sentence boundary
28 | my $correct; # current corpus chunk tag (I,O,B)
29 | my $correctChunk = 0; # number of correctly identified chunks
30 | my $correctTags = 0; # number of correct chunk tags
31 | my $correctType; # type of current corpus chunk tag (NP,VP,etc.)
32 | my $delimiter = "\t"; # field delimiter
33 | my $FB1 = 0.0; # FB1 score (Van Rijsbergen 1979)
34 | my $firstItem; # first feature (for sentence boundary checks)
35 | my $foundCorrect = 0; # number of chunks in corpus
36 | my $foundGuessed = 0; # number of identified chunks
37 | my $guessed; # current guessed chunk tag
38 | my $guessedType; # type of current guessed chunk tag
39 | my $i; # miscellaneous counter
40 | my $inCorrect = $false; # currently processed chunk is correct until now
41 | my $lastCorrect = "O"; # previous chunk tag in corpus
42 | my $latex = 0; # generate LaTeX formatted output
43 | my $lastCorrectType = ""; # type of previously identified chunk tag
44 | my $lastGuessed = "O"; # previously identified chunk tag
45 | my $lastGuessedType = ""; # type of previous chunk tag in corpus
46 | my $lastType; # temporary storage for detecting duplicates
47 | my $line; # line
48 | my $nbrOfFeatures = -1; # number of features per line
49 | my $precision = 0.0; # precision score
50 | my $oTag = "O"; # outside tag, default O
51 | my $raw = 0; # raw input: add B to every token
52 | my $recall = 0.0; # recall score
53 | my $tokenCounter = 0; # token counter (ignores sentence breaks)
54 |
55 | my %correctChunk = (); # number of correctly identified chunks per type
56 | my %foundCorrect = (); # number of chunks in corpus per type
57 | my %foundGuessed = (); # number of identified chunks per type
58 |
59 | my @features; # features on line
60 | my @sortedTypes; # sorted list of chunk type names
61 |
62 | # sanity check
63 | while (@ARGV and $ARGV[0] =~ /^-/) {
64 | if ($ARGV[0] eq "-l") { $latex = 1; shift(@ARGV); }
65 | elsif ($ARGV[0] eq "-r") { $raw = 1; shift(@ARGV); }
66 | elsif ($ARGV[0] eq "-d") {
67 | shift(@ARGV);
68 | if (not defined $ARGV[0]) {
69 | die "conlleval: -d requires delimiter character";
70 | }
71 | $delimiter = shift(@ARGV);
72 | } elsif ($ARGV[0] eq "-o") {
73 | shift(@ARGV);
74 | if (not defined $ARGV[0]) {
75 | die "conlleval: -o requires delimiter character";
76 | }
77 | $oTag = shift(@ARGV);
78 | } else { die "conlleval: unknown argument $ARGV[0]\n"; }
79 | }
80 | if (@ARGV) { die "conlleval: unexpected command line argument\n"; }
81 | # process input
82 | while (<STDIN>) {
83 | chomp($line = $_);
84 | @features = split(/$delimiter/,$line);
85 | if ($nbrOfFeatures < 0) { $nbrOfFeatures = $#features; }
86 | elsif ($nbrOfFeatures != $#features and @features != 0) {
87 | printf STDERR "unexpected number of features: %d (%d)\n",
88 | $#features+1,$nbrOfFeatures+1;
89 | exit(1);
90 | }
91 | if (@features == 0 or
92 | $features[0] eq $boundary) { @features = ($boundary,"O","O"); }
93 | if (@features < 2) {
94 | die "conlleval: unexpected number of features in line $line\n";
95 | }
96 | if ($raw) {
97 | if ($features[$#features] eq $oTag) { $features[$#features] = "O"; }
98 | if ($features[$#features-1] eq $oTag) { $features[$#features-1] = "O"; }
99 | if ($features[$#features] ne "O") {
100 | $features[$#features] = "B-$features[$#features]";
101 | }
102 | if ($features[$#features-1] ne "O") {
103 | $features[$#features-1] = "B-$features[$#features-1]";
104 | }
105 | }
106 | # 20040126 ET code which allows hyphens in the types
107 | if ($features[$#features] =~ /^([^-]*)-(.*)$/) {
108 | $guessed = $1;
109 | $guessedType = $2;
110 | } else {
111 | $guessed = $features[$#features];
112 | $guessedType = "";
113 | }
114 | pop(@features);
115 | if ($features[$#features] =~ /^([^-]*)-(.*)$/) {
116 | $correct = $1;
117 | $correctType = $2;
118 | } else {
119 | $correct = $features[$#features];
120 | $correctType = "";
121 | }
122 | pop(@features);
123 | # ($guessed,$guessedType) = split(/-/,pop(@features));
124 | # ($correct,$correctType) = split(/-/,pop(@features));
125 | $guessedType = $guessedType ? $guessedType : "";
126 | $correctType = $correctType ? $correctType : "";
127 | $firstItem = shift(@features);
128 |
129 | # 1999-06-26 sentence breaks should always be counted as out of chunk
130 | if ( $firstItem eq $boundary ) { $guessed = "O"; }
131 |
132 | if ($inCorrect) {
133 | if ( &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and
134 | &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and
135 | $lastGuessedType eq $lastCorrectType) {
136 | $inCorrect=$false;
137 | $correctChunk++;
138 | $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ?
139 | $correctChunk{$lastCorrectType}+1 : 1;
140 | } elsif (
141 | &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) !=
142 | &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) or
143 | $guessedType ne $correctType ) {
144 | $inCorrect=$false;
145 | }
146 | }
147 |
148 | if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and
149 | &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and
150 | $guessedType eq $correctType) { $inCorrect = $true; }
151 |
152 | if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) ) {
153 | $foundCorrect++;
154 | $foundCorrect{$correctType} = $foundCorrect{$correctType} ?
155 | $foundCorrect{$correctType}+1 : 1;
156 | }
157 | if ( &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) ) {
158 | $foundGuessed++;
159 | $foundGuessed{$guessedType} = $foundGuessed{$guessedType} ?
160 | $foundGuessed{$guessedType}+1 : 1;
161 | }
162 | if ( $firstItem ne $boundary ) {
163 | if ( $correct eq $guessed and $guessedType eq $correctType ) {
164 | $correctTags++;
165 | }
166 | $tokenCounter++;
167 | }
168 |
169 | $lastGuessed = $guessed;
170 | $lastCorrect = $correct;
171 | $lastGuessedType = $guessedType;
172 | $lastCorrectType = $correctType;
173 | }
174 | if ($inCorrect) {
175 | $correctChunk++;
176 | $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ?
177 | $correctChunk{$lastCorrectType}+1 : 1;
178 | }
179 |
180 | if (not $latex) {
181 | # compute overall precision, recall and FB1 (default values are 0.0)
182 | $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0);
183 | $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0);
184 | $FB1 = 2*$precision*$recall/($precision+$recall)
185 | if ($precision+$recall > 0);
186 |
187 | # print overall performance
188 | printf "processed $tokenCounter tokens with $foundCorrect phrases; ";
189 | printf "found: $foundGuessed phrases; correct: $correctChunk.\n";
190 | if ($tokenCounter>0) {
191 | printf "accuracy: %6.2f%%; ",100*$correctTags/$tokenCounter;
192 | printf "precision: %6.2f%%; ",$precision;
193 | printf "recall: %6.2f%%; ",$recall;
194 | printf "FB1: %6.2f\n",$FB1;
195 | }
196 | }
197 |
198 | # sort chunk type names
199 | undef($lastType);
200 | @sortedTypes = ();
201 | foreach $i (sort (keys %foundCorrect,keys %foundGuessed)) {
202 | if (not($lastType) or $lastType ne $i) {
203 | push(@sortedTypes,($i));
204 | }
205 | $lastType = $i;
206 | }
207 | # print performance per chunk type
208 | if (not $latex) {
209 | for $i (@sortedTypes) {
210 | $correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0;
211 | if (not($foundGuessed{$i})) { $foundGuessed{$i} = 0; $precision = 0.0; }
212 | else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; }
213 | if (not($foundCorrect{$i})) { $recall = 0.0; }
214 | else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; }
215 | if ($precision+$recall == 0.0) { $FB1 = 0.0; }
216 | else { $FB1 = 2*$precision*$recall/($precision+$recall); }
217 | printf "%17s: ",$i;
218 | printf "precision: %6.2f%%; ",$precision;
219 | printf "recall: %6.2f%%; ",$recall;
220 | printf "FB1: %6.2f %d\n",$FB1,$foundGuessed{$i};
221 | }
222 | } else {
223 | print " & Precision & Recall & F\$_{\\beta=1} \\\\\\hline";
224 | for $i (@sortedTypes) {
225 | $correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0;
226 | if (not($foundGuessed{$i})) { $precision = 0.0; }
227 | else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; }
228 | if (not($foundCorrect{$i})) { $recall = 0.0; }
229 | else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; }
230 | if ($precision+$recall == 0.0) { $FB1 = 0.0; }
231 | else { $FB1 = 2*$precision*$recall/($precision+$recall); }
232 | printf "\n%-7s & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\",
233 | $i,$precision,$recall,$FB1;
234 | }
235 | print "\\hline\n";
236 | $precision = 0.0;
237 | $recall = 0;
238 | $FB1 = 0.0;
239 | $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0);
240 | $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0);
241 | $FB1 = 2*$precision*$recall/($precision+$recall)
242 | if ($precision+$recall > 0);
243 | printf "Overall & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\\\hline\n",
244 | $precision,$recall,$FB1;
245 | }
246 |
247 | exit 0;
248 |
249 | # endOfChunk: checks if a chunk ended between the previous and current word
250 | # arguments: previous and current chunk tags, previous and current types
251 | # note: this code is capable of handling other chunk representations
252 | # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong
253 | # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006
254 |
255 | sub endOfChunk {
256 | my $prevTag = shift(@_);
257 | my $tag = shift(@_);
258 | my $prevType = shift(@_);
259 | my $type = shift(@_);
260 | my $chunkEnd = $false;
261 |
262 | if ( $prevTag eq "B" and $tag eq "B" ) { $chunkEnd = $true; }
263 | if ( $prevTag eq "B" and $tag eq "O" ) { $chunkEnd = $true; }
264 | if ( $prevTag eq "I" and $tag eq "B" ) { $chunkEnd = $true; }
265 | if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; }
266 |
267 | if ( $prevTag eq "E" and $tag eq "E" ) { $chunkEnd = $true; }
268 | if ( $prevTag eq "E" and $tag eq "I" ) { $chunkEnd = $true; }
269 | if ( $prevTag eq "E" and $tag eq "O" ) { $chunkEnd = $true; }
270 | if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; }
271 |
272 | if ($prevTag ne "O" and $prevTag ne "." and $prevType ne $type) {
273 | $chunkEnd = $true;
274 | }
275 |
276 | # corrected 1998-12-22: these chunks are assumed to have length 1
277 | if ( $prevTag eq "]" ) { $chunkEnd = $true; }
278 | if ( $prevTag eq "[" ) { $chunkEnd = $true; }
279 |
280 | return($chunkEnd);
281 | }
282 |
283 | # startOfChunk: checks if a chunk started between the previous and current word
284 | # arguments: previous and current chunk tags, previous and current types
285 | # note: this code is capable of handling other chunk representations
286 | # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong
287 | # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006
288 |
289 | sub startOfChunk {
290 | my $prevTag = shift(@_);
291 | my $tag = shift(@_);
292 | my $prevType = shift(@_);
293 | my $type = shift(@_);
294 | my $chunkStart = $false;
295 |
296 | if ( $prevTag eq "B" and $tag eq "B" ) { $chunkStart = $true; }
297 | if ( $prevTag eq "I" and $tag eq "B" ) { $chunkStart = $true; }
298 | if ( $prevTag eq "O" and $tag eq "B" ) { $chunkStart = $true; }
299 | if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; }
300 |
301 | if ( $prevTag eq "E" and $tag eq "E" ) { $chunkStart = $true; }
302 | if ( $prevTag eq "E" and $tag eq "I" ) { $chunkStart = $true; }
303 | if ( $prevTag eq "O" and $tag eq "E" ) { $chunkStart = $true; }
304 | if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; }
305 |
306 | if ($tag ne "O" and $tag ne "." and $prevType ne $type) {
307 | $chunkStart = $true;
308 | }
309 |
310 | # corrected 1998-12-22: these chunks are assumed to have length 1
311 | if ( $tag eq "[" ) { $chunkStart = $true; }
312 | if ( $tag eq "]" ) { $chunkStart = $true; }
313 |
314 | return($chunkStart);
315 | }
316 |
--------------------------------------------------------------------------------
/data/anno/anno-cn.md:
--------------------------------------------------------------------------------
1 | # Address Normalization Annotation Guideline
2 | ## Annotation Goal
3 | Address structuring: recognize the administrative divisions, other address elements, and auxiliary locating words in an address, and tag them with English labels. The annotation covers the span and the label of each element.
4 |
5 |
6 | ## Annotation Examples
7 | ####【prov】
8 | 1. Interpretation: province-level administrative division, i.e. a province or autonomous region;
9 | 2. Example 1:
10 | 1. Query: 内蒙古赤峰市锦山镇
11 | 2. Structured: prov=内蒙古 city=赤峰市 town=锦山镇
12 |
13 | ####【city】
14 | 1. Interpretation: prefecture-level administrative division, e.g. a prefecture-level city, municipality, prefecture, or autonomous prefecture;
15 | 2. Example 1:
16 | 1. Query: 杭州市富阳区戴家墩路91号东阳诚心木线(富阳店)
17 | 2. Structured: city=杭州市 district=富阳区 road=戴家墩路 roadno=91号 poi=东阳诚心木线(富阳店)
18 | 3. Example 2:
19 | 1. Query: 西藏自治区日喀则地区定日县柑碑村712号
20 | 2. Structured: prov=西藏自治区 city=日喀则地区 district=定日县 community=柑碑村 roadno=712号
21 |
22 | ####【district】
23 | 1. Interpretation: county-level administrative division, e.g. an urban district, county-level city, or county;
24 | 2. Example 1:
25 | 1. Query: 东城区福新东路245号
26 | 2. Structured: district=东城区 road=福新东路 roadno=245号
27 |
28 | ####【devzone】
29 | 1. Interpretation: informal administrative division, ranked between district and town or after town; refers specifically to economic and technological development zones;
30 | 2. Example 1:
31 | 1. Query: 内蒙古自治区呼和浩特市土默特左旗金川开发区公元仰山9号楼2单元202
32 | 2. Structured: prov=内蒙古自治区 city=呼和浩特市 district=土默特左旗 town=察素齐镇 devzone=金川开发区 poi=公元仰山 houseno=9号楼 cellno=2单元 roomno=202室
33 | 3. Example 2:
34 | 1. Query: 南宁市青秀区仙葫经济开发区开泰路148号广西警察学院仙葫校区
35 | 2. Structured: city=南宁市 district=青秀区 devzone=仙葫经济开发区 road=开泰路 roadno=148号 poi=广西警察学院仙葫校区
36 |
37 | ####【town】
38 | 1. Interpretation: township-level administrative division, e.g. a town, subdistrict, or township.
39 | 2. Example 1:
40 | 1. Query: 上海市 静安区 共和新路街道 柳营路669弄14号1102
41 | 2. Structured: city=上海市 district=静安区 town=共和新路街道 road=柳营路 roadno=669弄 houseno=14号 roomno=1102室
42 | 3. Example 2:
43 | 1. Query: 五常街道顾家桥社区河西北9号衣服鞋子店
44 | 2. Structured: town=五常街道 community=顾家桥社区 road=河西北 roadno=9号 poi=衣服鞋子店
45 |
46 | ####【community】
47 | 1. Interpretation: community or natural village;
48 | 2. Example 1:
49 | 1. Query: 张庆乡北胡乔村
50 | 2. Structured: town=张庆乡 community=北胡乔村
51 | 3. Example 2:
52 | 1. Query: 五常街道顾家桥社区河西北9号衣服鞋子店
53 | 2. Structured: town=五常街道 community=顾家桥社区 road=河西北 roadno=9号 poi=衣服鞋子店
54 |
55 | ####【road】
56 | 1. Interpretation: road or village group;
57 | 2. Example 1:
58 | 1. Query: 静安区江场三路238号1613室
59 | 2. Structured: district=静安区 road=江场三路 roadno=238号 roomno=1613室
60 | 3. Example 2:
61 | 1. Query: 沿山村5组
62 | 2. Structured: community=沿山村 road=5组
63 | 4. Example 3:
64 | 1. Query: 江宁区江宁滨江开发区中环大道10号环宇人力行政部
65 | 2. Structured: district=江宁区 town=江宁街道 devzone=江宁滨江开发区 road=中环大道 roadno=10号 poi=环宇人力行政部
66 |
67 | ####【roadno】
68 | 1. Interpretation: house number on the main road, main-road number, group number, or group number with an attached sub-number;
69 | 2. Example 1:
70 | 1. Query: 江宁区江宁滨江开发区中环大道10号环宇人力行政部
71 | 2. Structured: district=江宁区 town=江宁街道 devzone=江宁滨江开发区 road=中环大道 roadno=10号 poi=环宇人力行政部
72 | 3. Example 2:
73 | 1. Query: 沿山村5组6号
74 | 2. Structured: community=沿山村 roadno=5组6号
75 | 4. Note: for the three patterns "X组X号", "X组X号附X号", and "X号附X号", label the whole span as roadno; do not split it into road and roadno.
76 |
77 | ####【subroad】
78 | 1. Interpretation: sub-road, branch road, or side road;
79 | 2. Example 1:
80 | 1. Query: 浙江省台州市临海市江南大道创业大道288号
81 | 2. Structured: prov=浙江省 city=台州市 district=临海市 town=江南街道 road=江南大道 subroad=创业大道 subroadno=288号
82 |
83 |
84 | ####【subroadno】
85 | 1. Interpretation: number on a sub-road (branch or side road): house number, road number, group number, or group number with an attached sub-number;
86 | 2. Example 1: see the previous entry.
87 |
88 | ####【poi】
89 | 1. Interpretation: point of interest;
90 | 2. Example 1:
91 | 1. Query: 浙江省杭州市余杭区五常街道文一西路969号阿里巴巴西溪园区
92 | 2. Structured: prov=浙江省 city=杭州市 district=余杭区 town=五常街道 road=文一西路 roadno=969号 poi=阿里巴巴西溪园区
93 |
94 | ####【subpoi】
95 | 1. Interpretation: sub point of interest;
96 | 2. Example 1:
97 | 1. Query: 新疆维吾尔自治区 昌吉回族自治州 昌吉市 延安北路街道 延安南路石油小区东门
98 | 2. Structured: prov=新疆维吾尔自治区 city=昌吉回族自治州 district=昌吉市 town=延安北路街道 road=延安南路 poi=石油小区 subpoi=东门
99 | 3. Example 2:
100 | 1. Query: 西湖区新金都城市花园西雅园10幢3底层
101 | 2. Structured: district=西湖区 town=文新街道 road=古墩路 roadno=415号 poi=新金都城市花园 subpoi=西雅园 houseno=10幢 floorno=3底层
102 | 4. Example 3:
103 | 1. Query: 广宁伯街2号金泽大厦东区15层
104 | 2. Structured: road=广宁伯街 roadno=2号 poi=金泽大厦 subpoi=东区 floorno=15层
105 | 5. Note: when another poi follows a poi and geographically lies inside or near the preceding one, it can be treated as subpoi; typical cases are "XX小区【东门】", "XX大厦【收发室】", "阿里巴巴西溪园区【全家】", "西溪花园【蒹葭苑】", "西溪北苑【东区】". Pay special attention to cases like "XXX小区一期" and "XXX小区二期": the whole span is one poi, and "一期" is not a subpoi.
106 |
107 | ####【houseno】
108 | 1. Interpretation: building number;
109 | 2. Example 1:
110 | 1. Query: 阿里巴巴西溪园区6号楼小邮局
111 | 2. Structured: poi=阿里巴巴西溪园区 houseno=6号楼 person=小邮局
112 | 3. Example 2:
113 | 1. Query: 四川省 成都市 金牛区 沙河源街道 金牛区九里堤街道 金府机电城A区3栋16号
114 | 2. Structured: prov=四川省 city=成都市 district=金牛区 town=沙河源街道 road=金府路 poi=金府机电城 subpoi=A区 houseno=3栋 cellno=16号
115 | 4. Example 3:
116 | 1. Query: 竹海水韵春风里12-3-1001
117 | 2. Structured: poi=竹海水韵 subpoi=春风里 houseno=12幢 cellno=3单元 roomno=1001室
118 | 5. Note: it usually follows a poi, as in "智慧产业园二期【一号楼】" and "春风里【7】-2-801".
119 |
120 | ####【cellno】
121 | 1. Interpretation: unit number;
122 | 2. Example 1:
123 | 1. Query: 竹海水韵春风里12-3-1001
124 | 2. Structured: poi=竹海水韵 subpoi=春风里 houseno=12幢 cellno=3单元 roomno=1001室
125 | 3. Example 2:
126 | 1. Query: 蒋村花园新达苑18幢二单元101
127 | 2. Structured: town=蒋村街道 community=蒋村花园社区 road=晴川街 poi=蒋村花园 subpoi=新达苑 houseno=18幢 cellno=二单元 roomno=101室
128 |
129 | ####【floorno】
130 | 1. Interpretation: floor number;
131 | 2. Example 1:
132 | 1. Query: 北京市东城区东中街29号东环广场B座5层信达资本
133 | 2. Structured: city=北京市 district=东城区 town=东直门街道 road=东中街 roadno=29号 poi=东环广场 houseno=B座 floorno=5层 person=信达资本
134 |
135 | ####【roomno】
136 | 1. Interpretation: room number or household number;
137 | 2. Example 1: see the examples under cellno
138 |
139 | ####【person】
140 | 1. Interpretation: a "POI" inside a building, e.g. an enterprise, legal entity, or shop name.
141 | 2. Example 1:
142 | 1. Query: 北京 北京市 西城区 广安门外街道 马连道马正和大厦3层我的未来网总部
143 | 2. Structured: city=北京市 district=西城区 town=广安门外街道 road=马连道 poi=马正和大厦 floorno=3层 person=我的未来网总部
144 | 3. Example 2:
145 | 1. Query: 浙江省 杭州市 余杭区 良渚街道沈港路11号2楼 常春藤公司
146 | 2. Structured: prov=浙江省 city=杭州市 district=余杭区 town=良渚街道 road=沈港路 roadno=11号 floorno=2楼 person=常春藤公司
147 |
148 |
149 | ####【assist】
150 | 1. Interpretation: general auxiliary locating words, such as 门口 (gate), 旁边 (beside), 附近 (nearby), 楼下 (downstairs), 边上 (next to), and other generic or vague locating words;
151 | 2. Example 1:
152 | 1. Query: 广西柳州市城中区潭中东路勿忘我网吧门口
153 | 2. Structured: prov=广西壮族自治区 city=柳州市 district=城中区 town=潭中街道 road=潭中东路 poi=勿忘我网吧 assist=门口
154 |
155 | ####【redundant】
156 | 1. Interpretation: meaningless or redundant words that do not help locate the address;
157 | 2. Example 1:
158 | 1. Query: 浙江省 杭州市 滨江区 浙江省 杭州市 滨江区 六和路东50米 六和路东信大道口自行车租赁点
159 | 2. Structured: prov=浙江省 city=杭州市 district=滨江区 redundant=东 浙江省杭州市滨江区 town=浦沿街道 road=六和路 subroad=东信大道 intersection=口 poi=自行车租赁点
160 |
161 | 3. Example 2:
162 | 1. Query: 浙江省 杭州市 滨江区 六和路 ---- 东信大道口自行车租赁点
163 | 2. Structured: prov=浙江省 city=杭州市 district=滨江区 town=浦沿街道 road=六和路 redundant=---- subroad=东信大道 intersection=口 poi=自行车租赁点
164 |
165 | ####【otherinfo】
166 | 1. Interpretation: other information that cannot be classified.
167 |
168 | ## Label Ordering
169 | There is a strong ordering relation among the labels; for example, city cannot appear before prov. Specifically, the following partial orders hold:
170 |
171 | #### 1. prov > city > district > town > comm > road > roadno > poi > houseno > cellno > floorno > roomno
172 | #### 2. district > devzone
173 | #### 3. devzone > comm
174 | #### 4. road > subroad
175 | #### 5. poi > subpoi
176 |
177 |
--------------------------------------------------------------------------------
/data/anno/anno-en.md:
--------------------------------------------------------------------------------
1 | # Chinese Address Parsing Annotation Guideline
2 | ## Goal
3 | Recognize all the chunks (e.g. city, road, etc.) in the given Chinese address. For each chunk, the boundary and the label are required to be annotated.
4 |
5 | ####【prov】
6 | 1. Interpretation: province, autonomous region
7 | 2. Example 1:
8 | 1. Query: 内蒙古赤峰市锦山镇
9 | 2. Annotation: prov=内蒙古 city=赤峰市 district=喀喇沁旗 town=锦山镇
10 | 3. Example 2:
11 | 1. Query: 渭南市大荔县户家乡边章营村
12 | 2. Annotation: prov=陕西省 city=渭南市 district=大荔县 town=户家乡 community=边章营村
13 |
14 | ####【city】
15 | 1. Interpretation: city, municipality, prefecture, autonomous prefecture
16 | 2. Example 1:
17 | 1. Query: 杭州市富阳区戴家墩路91号东阳诚心木线(富阳店)
18 | 2. Annotation: city=杭州市 district=富阳区 road=戴家墩路 roadno=91号 poi=东阳诚心木线(富阳店)
19 | 3. Example 2:
20 | 1. Query: 西藏自治区日喀则地区定日县柑碑村712号
21 | 2. Annotation: prov=西藏自治区 city=日喀则地区 district=定日县 community=柑碑村 roadno=712号
22 |
23 | ####【district】
24 | 1. Interpretation: district
25 | 2. Example 1:
26 | 1. Query: 东城区福新东路245号
27 | 2. Annotation: district=东城区 road=福新东路 roadno=245号
28 |
29 | ####【devzone】
30 | 1. Interpretation: economic development zone
31 | 2. Example 1:
32 | 1. Query: 内蒙古自治区呼和浩特市土默特左旗金川开发区公元仰山9号楼2单元202
33 | 2. Annotation: prov=内蒙古自治区 city=呼和浩特市 district=土默特左旗 town=察素齐镇 devzone=金川开发区 poi=公元仰山 houseno=9号楼 cellno=2单元 roomno=202室
34 | 3. Example 2:
35 | 1. Query: 南宁市青秀区仙葫经济开发区开泰路148号广西警察学院仙葫校区
36 | 2. Annotation: city=南宁市 district=青秀区 devzone=仙葫经济开发区 road=开泰路 roadno=148号 poi=广西警察学院仙葫校区
37 |
38 | ####【town】
39 | 1. Interpretation: town, subdistrict (administrative street), township;
40 | 2. Example 1:
41 | 1. Query: 上海市 静安区 共和新路街道 柳营路669弄14号1102
42 | 2. Annotation: city=上海市 district=静安区 town=共和新路街道 road=柳营路 roadno=669弄 houseno=14号 roomno=1102室
43 | 3. Example 2:
44 | 1. Query: 五常街道顾家桥社区河西北9号衣服鞋子店
45 | 2. Annotation: town=五常街道 community=顾家桥社区 road=河西北 roadno=9号 poi=衣服鞋子店
46 |
47 | ####【community】
48 | 1. Interpretation: community
49 | 2. Example 1:
50 | 1. Query: 张庆乡北胡乔村
51 | 2. Annotation: town=张庆乡 community=北胡乔村
52 | 3. Example 2:
53 | 1. Query: 五常街道顾家桥社区河西北9号衣服鞋子店
54 | 2. Annotation: town=五常街道 community=顾家桥社区 road=河西北 roadno=9号 poi=衣服鞋子店
55 |
56 | ####【road】
57 | 1. Interpretation: road
58 | 2. Example 1:
59 | 1. Query: 静安区江场三路238号1613室
60 | 2. Annotation: district=静安区 road=江场三路 roadno=238号 roomno=1613室
61 | 3. Example 2:
62 | 1. Query: 沿山村5组
63 | 2. Annotation: community=沿山村 road=5组
64 | 4. Example 3:
65 | 1. Query: 江宁区江宁滨江开发区中环大道10号环宇人力行政部
66 | 2. Annotation: district=江宁区 town=江宁街道 devzone=江宁滨江开发区 road=中环大道 roadno=10号 poi=环宇人力行政部
67 |
68 | ####【roadno】
69 | 1. Interpretation: road number
70 | 2. Example 1:
71 | 1. Query: 江宁区江宁滨江开发区中环大道10号环宇人力行政部
72 | 2. Annotation: district=江宁区 town=江宁街道 devzone=江宁滨江开发区 road=中环大道 roadno=10号 poi=环宇人力行政部
73 | 3. Example 2:
74 | 1. Query: 沿山村5组6号
75 | 2. Annotation: community=沿山村 roadno=5组6号
76 |
77 | ####【subroad】
78 | 1. Interpretation: subroad
79 | 2. Example 1:
80 | 1. Query: 浙江省台州市临海市江南大道创业大道288号
81 | 2. Annotation: prov=浙江省 city=台州市 district=临海市 town=江南街道 road=江南大道 subroad=创业大道 subroadno=288号
82 |
83 |
84 | ####【subroadno】
85 | 1. Interpretation: subroad number
86 | 2. Example 1: See the previous example.
87 |
88 | ####【poi】
89 | 1. Interpretation: point of interest;
90 | 2. Example 1:
91 | 1. Query: 浙江省杭州市余杭区五常街道文一西路969号阿里巴巴西溪园区
92 | 2. Annotation: prov=浙江省 city=杭州市 district=余杭区 town=五常街道 road=文一西路 roadno=969号 poi=阿里巴巴西溪园区
93 |
94 | ####【subpoi】
95 | 1. Interpretation: sub-poi;
96 | 2. Example 1:
97 | 1. Query: 新疆维吾尔自治区 昌吉回族自治州 昌吉市 延安北路街道 延安南路石油小区东门
98 | 2. Annotation: prov=新疆维吾尔自治区 city=昌吉回族自治州 district=昌吉市 town=延安北路街道 road=延安南路 poi=石油小区 subpoi=东门
99 | 3. Example 2:
100 | 1. Query: 西湖区新金都城市花园西雅园10幢3底层
101 | 2. Annotation: district=西湖区 town=文新街道 road=古墩路 roadno=415号 poi=新金都城市花园 subpoi=西雅园 houseno=10幢 floorno=3底层
102 | 4. Example 3:
103 | 1. Query: 广宁伯街2号金泽大厦东区15层
104 | 2. Annotation: road=广宁伯街 roadno=2号 poi=金泽大厦 subpoi=东区 floorno=15层
105 | 5. Comment: We regard the second poi appearing after the first poi as subpoi if they are located in the same region.
106 |
107 | ####【houseno】
108 | 1. Interpretation: house number;
109 | 2. Example 1:
110 | 1. Query: 阿里巴巴西溪园区6号楼小邮局
111 | 2. Annotation: poi=阿里巴巴西溪园区 houseno=6号楼 person=小邮局
112 | 3. Example 2:
113 | 1. Query: 四川省 成都市 金牛区 沙河源街道 金牛区九里堤街道 金府机电城A区3栋16号
114 | 2. Annotation: prov=四川省 city=成都市 district=金牛区 town=沙河源街道 road=金府路 poi=金府机电城 subpoi=A区 houseno=3栋 cellno=16号
115 | 4. Example 3:
116 | 1. Query: 竹海水韵春风里12-3-1001
117 | 2. Annotation: poi=竹海水韵 subpoi=春风里 houseno=12幢 cellno=3单元 roomno=1001室
118 | 5. Comment: It usually appears after poi.
119 |
120 | ####【cellno】
121 | 1. Interpretation: cell number;
122 | 2. Example 1:
123 | 1. Query: 竹海水韵春风里12-3-1001
124 | 2. Annotation: poi=竹海水韵 subpoi=春风里 houseno=12幢 cellno=3单元 roomno=1001室
125 | 3. Example 2:
126 | 1. Query: 蒋村花园新达苑18幢二单元101
127 | 2. Annotation: town=蒋村街道 community=蒋村花园社区 road=晴川街 poi=蒋村花园 subpoi=新达苑 houseno=18幢 cellno=二单元 roomno=101室
128 |
129 | ####【floorno】
130 | 1. Interpretation: floor number;
131 | 2. Example 1:
132 | 1. Query: 北京市东城区东中街29号东环广场B座5层信达资本
133 | 2. Annotation: city=北京市 district=东城区 town=东直门街道 road=东中街 roadno=29号 poi=东环广场 houseno=B座 floorno=5层 person=信达资本
134 |
135 | ####【roomno】
136 | 1. Interpretation: room number;
137 | 2. Example 1: See the example in cell number
138 |
139 | ####【person】
140 | 1. Interpretation: name of a company or company representative
141 | 2. Example 1:
142 | 1. Query: 北京 北京市 西城区 广安门外街道 马连道马正和大厦3层我的未来网总部
143 | 2. Annotation: city=北京市 district=西城区 town=广安门外街道 road=马连道 poi=马正和大厦 floorno=3层 person=我的未来网总部
144 | 3. Example 2:
145 | 1. Query: 浙江省 杭州市 余杭区 良渚街道沈港路11号2楼 常春藤公司
146 | 2. Annotation: prov=浙江省 city=杭州市 district=余杭区 town=良渚街道 road=沈港路 roadno=11号 floorno=2楼 person=常春藤公司
147 |
148 | ####【assist】
149 | 1. Interpretation: auxiliary words that help locate the address, e.g. 旁边 (beside), 对面 (opposite)
150 | 2. Example 1:
151 | 1. Query: 广西柳州市城中区潭中东路勿忘我网吧门口
152 | 2. Annotation: prov=广西壮族自治区 city=柳州市 district=城中区 town=潭中街道 road=潭中东路 poi=勿忘我网吧 assist=门口
153 |
154 | ####【redundant】
155 | 1. Interpretation: useless and redundant words
156 | 2. Example 1:
157 | 1. Query: 浙江省 杭州市 滨江区 浙江省 杭州市 滨江区 六和路东信大道自行车租赁点
158 | 2. Annotation: prov=浙江省 city=杭州市 district=滨江区 redundant=浙江省杭州市滨江区 road=六和路 subroad=东信大道 poi=自行车租赁点
159 |
160 | 3. Example 2:
161 | 1. Query: 浙江省 杭州市 滨江区 六和路 ---- 东信大道自行车租赁点
162 | 2. Annotation: prov=浙江省 city=杭州市 district=滨江区 road=六和路 redundant=---- subroad=东信大道 poi=自行车租赁点
163 |
164 | ####【otherinfo】
165 | 1. Interpretation: other information which cannot be classified
166 |
167 | ## Partial Order
168 | There exist partial orders among the labels; for example, city appears after prov in an address. We summarize the following partial orders; a small validation sketch follows the list.
169 |
170 | #### 1. prov > city > district > town > comm > road > roadno > poi > houseno > cellno > floorno > roomno
171 | #### 2. district > devzone
172 | #### 3. devzone > comm
173 | #### 4. road > subroad
174 | #### 5. poi > subpoi
175 |
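The following sketch is an illustration only, not part of the released code: a minimal Python check of a flat label sequence against the five partial orders above. The names `MAIN_CHAIN`, `SIDE_RULES`, and `check_order` are hypothetical, and `comm` is this guideline's shorthand for the `community` label in data/labels.txt.

```python
# Hypothetical sketch: validate the five partial orders over a label sequence.
MAIN_CHAIN = ["prov", "city", "district", "town", "comm", "road",
              "roadno", "poi", "houseno", "cellno", "floorno", "roomno"]
RANK = {label: i for i, label in enumerate(MAIN_CHAIN)}  # ranks for rule 1

# (before, after) pairs from rules 2-5: 'before' must precede 'after'.
SIDE_RULES = [("district", "devzone"), ("devzone", "comm"),
              ("road", "subroad"), ("poi", "subpoi")]

def check_order(labels):
    """Return True if a left-to-right label sequence respects the partial orders."""
    # Rule 1: labels on the main chain must appear in non-decreasing rank.
    ranks = [RANK[l] for l in labels if l in RANK]
    if any(a > b for a, b in zip(ranks, ranks[1:])):
        return False
    # Rules 2-5: whenever both labels occur, 'before' must come first.
    for before, after in SIDE_RULES:
        if before in labels and after in labels and labels.index(before) > labels.index(after):
            return False
    return True

print(check_order(["prov", "city", "district", "road", "roadno", "poi"]))  # True
print(check_order(["city", "prov"]))  # False: city may not precede prov
```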
176 |
--------------------------------------------------------------------------------
/data/labels.txt:
--------------------------------------------------------------------------------
1 | country
2 | prov
3 | city
4 | district
5 | devzone
6 | town
7 | community
8 | road
9 | subroad
10 | roadno
11 | subroadno
12 | poi
13 | subpoi
14 | houseno
15 | cellno
16 | floorno
17 | roomno
18 | person
19 | assist
20 | redundant
21 | otherinfo
22 |
--------------------------------------------------------------------------------
/exp_dytree.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | model="chartdyRBTC"
3 | treetype="dytree"
4 |
5 | for pretrain in giga
6 | do
7 |
8 | for dropout in 0.4 #0.1 0.2 0.3 #0.4
9 | do
10 |
11 | for lstmdim in 200
12 | do
13 |
14 | for batchsize in 1 #10 #20
15 | do
16 |
17 | for normal in 1
18 | do
19 |
20 | for RBTlabel in houseno
21 | do
22 | #prov #city district devzone town community road subroad roadno subroadno poi subpoi #houseno cellno floorno roomno person otherinfo assist redundant #country #
23 |
24 |
25 | for nontlabelstyle in 0 #0 #1 3 #1 2
26 | do
27 |
28 | for zerocostchunk in 0 #0 1
29 | do
30 |
31 | log="addr_"$treetype"_"$pretrain"_"$dropout"_"$lstmdim"_"$batchsize"_"$model"_"$treetype"_"$normal"_"$RBTlabel"_"$nontlabelstyle"_"$zerocostchunk
32 | echo $log".log"
33 |
34 | nohup python3 src/main_dyRBT.py train --parser-type $model --model-path-base models/$model-model --lstm-dim $lstmdim --label-hidden-dim $lstmdim --split-hidden-dim $lstmdim --pretrainemb $pretrain --batch-size $batchsize --epochs 30 --treetype $treetype --expname $log --normal $normal --checks-per-epoch 4 --RBTlabel $RBTlabel --nontlabelstyle $nontlabelstyle --dropout $dropout --zerocostchunk $zerocostchunk --loadmodel none >> $log".log" 2>&1 &
35 |
36 | #
37 |
38 | done
39 | done
40 | done
41 | done
42 | done
43 | done
44 | done
45 | done
46 |
47 |
48 |
--------------------------------------------------------------------------------
/giga.emb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leodotnet/neural-chinese-address-parsing/54cf6bde941a152cc096f54294254fc4ee288fa7/giga.emb
--------------------------------------------------------------------------------
/log.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leodotnet/neural-chinese-address-parsing/54cf6bde941a152cc096f54294254fc4ee288fa7/log.jpg
--------------------------------------------------------------------------------
/src/evaluate.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os.path
3 | import re
4 | import subprocess
5 | import tempfile
6 |
7 | import trees
8 |
9 |
10 | class FScore(object):
11 | def __init__(self, recall, precision, fscore):
12 | self.recall = recall
13 | self.precision = precision
14 | self.fscore = fscore
15 |
16 | def __str__(self):
17 | return "(Recall={:.2f}%, Precision={:.2f}%, FScore={:.2f}%)".format(
18 | self.recall * 100, self.precision * 100, self.fscore * 100)
19 |
20 | def evalb(evalb_dir, gold_trees, predicted_trees, expname = "default"):
21 | assert os.path.exists(evalb_dir)
22 | evalb_program_path = os.path.join(evalb_dir, "evalb")
23 | evalb_param_path = os.path.join(evalb_dir, "COLLINS.prm")
24 | assert os.path.exists(evalb_program_path)
25 | assert os.path.exists(evalb_param_path)
26 |
27 | assert len(gold_trees) == len(predicted_trees)
28 | for gold_tree, predicted_tree in zip(gold_trees, predicted_trees):
29 | assert isinstance(gold_tree, trees.TreebankNode)
30 | assert isinstance(predicted_tree, trees.TreebankNode)
31 | gold_leaves = list(gold_tree.leaves())
32 | predicted_leaves = list(predicted_tree.leaves())
33 | assert len(gold_leaves) == len(predicted_leaves)
34 | assert all(
35 | gold_leaf.word == predicted_leaf.word
36 | for gold_leaf, predicted_leaf in zip(gold_leaves, predicted_leaves))
37 |
38 | temp_dir = tempfile.TemporaryDirectory(prefix="evalb-")
39 | temp_dir = 'EVALB/'  # overrides the temporary directory above so the EVALB input/output files persist
40 | gold_path = os.path.join(temp_dir, expname + ".gold.txt")
41 | predicted_path = os.path.join(temp_dir, expname + ".predicted.txt")
42 | output_path = os.path.join(temp_dir, expname + ".output.txt")
43 |
44 | with open(gold_path, "w", encoding='utf-8') as outfile:
45 | for tree in gold_trees:
46 | outfile.write("{}\n".format(tree.linearize()))
47 |
48 | with open(predicted_path, "w", encoding='utf-8') as outfile:
49 | for tree in predicted_trees:
50 | outfile.write("{}\n".format(tree.linearize()))
51 |
52 | command = "{} -p {} {} {} > {}".format(
53 | evalb_program_path,
54 | evalb_param_path,
55 | gold_path,
56 | predicted_path,
57 | output_path,
58 | )
59 | subprocess.run(command, shell=True)
60 |
61 | fscore = FScore(math.nan, math.nan, math.nan)
62 | with open(output_path) as infile:
63 | for line in infile:
64 | match = re.match(r"Bracketing Recall\s+=\s+(\d+\.\d+)", line)
65 | if match:
66 | fscore.recall = float(match.group(1))
67 | match = re.match(r"Bracketing Precision\s+=\s+(\d+\.\d+)", line)
68 | if match:
69 | fscore.precision = float(match.group(1))
70 | match = re.match(r"Bracketing FMeasure\s+=\s+(\d+\.\d+)", line)
71 | if match:
72 | fscore.fscore = float(match.group(1))
73 | break
74 |
75 | success = (
76 | not math.isnan(fscore.fscore) or
77 | fscore.recall == 0.0 or
78 | fscore.precision == 0.0)
79 |
80 | if success:
81 | pass
82 | #temp_dir.cleanup()
83 | else:
84 | print("Error reading EVALB results.")
85 | print("Gold path: {}".format(gold_path))
86 | print("Predicted path: {}".format(predicted_path))
87 | print("Output path: {}".format(output_path))
88 |
89 | return fscore
90 |
91 |
92 | def count_common_chunks(chunk1, chunk2):
93 | common = 0
94 | for c1 in chunk1:
95 | for c2 in chunk2:
96 | if c1 == c2:
97 | common += 1
98 |
99 | return common
100 |
101 |
102 | def get_performance(match_num, gold_num, pred_num):
103 | p = (match_num + 0.0) / pred_num if pred_num > 0 else 0.0
104 | r = (match_num + 0.0) / gold_num if gold_num > 0 else 0.0
105 |
106 | try:
107 | f1 = 2 * p * r / (p + r)
108 | except ZeroDivisionError:
109 | f1 = 0.0
110 |
111 | return p, r, f1
112 |
113 |
114 | def get_text_from_chunks(chunks):
115 | text = []
116 | for chunk in chunks:
117 | text += chunk[3]
118 |
119 | return text
120 |
121 |
122 | def chunk2seq(chunks):  # chunks: list of (label, start_pos, end_pos, text_list) tuples
123 | seq = []
124 | for label, start_pos, end_pos, text_list in chunks:
125 | seq.append('B-' + label)
126 | for i in range(start_pos, end_pos - 1):
127 | seq.append('I-' + label)
128 |
129 | return seq
130 |
131 | def chunks2str(chunks):
132 | #print(chunks)
133 | return ' '.join(["({} {})".format(label, ''.join(text_list)) for label, _, _, text_list in chunks])
134 |
135 | def eval_chunks(evalb_dir, gold_trees, predicted_trees, output_filename = 'dev.out.txt'):
136 | match_num = 0
137 | gold_num = 0
138 | pred_num = 0
139 |
140 |
141 |
142 | fout = open(output_filename, 'w', encoding='utf-8')
143 |
144 | invalid_predicted_tree = 0
145 |
146 | for gold_tree, predicted_tree in zip(gold_trees, predicted_trees):
147 |
148 | # print(colored(gold_tree.linearize(), 'red'))
149 | # print(colored(predicted_tree.linearize(), 'yellow'))
150 |
151 | gold_chunks = gold_tree.to_chunks()
152 | predict_chunks = predicted_tree.to_chunks()
153 | input_seq = get_text_from_chunks(gold_chunks)
154 | gold_seq = chunk2seq(gold_chunks)
155 | predict_seq = chunk2seq(predict_chunks)
156 |
157 | if len(gold_seq) != len(predict_seq):
158 | invalid_predicted_tree += 1
159 | # print(colored('Error:', 'red'))
160 | # print(input_seq)
161 | # exit()
162 |
163 | o_list = zip(input_seq, gold_seq, predict_seq)
164 | fout.write('\n'.join(['\t'.join(x) for x in o_list]))
165 | fout.write('\n\n')
166 |
167 |
168 | # print(colored(chunks2str(gold_chunks), 'red'))
169 | # print(colored(chunks2str(predict_chunks), 'yellow'))
170 | # print()
171 |
172 | match_num += count_common_chunks(gold_chunks, predict_chunks)
173 | gold_num += len(gold_chunks)
174 | pred_num += len(predict_chunks)
175 |
176 | fout.close()
177 |
178 | #p, r, f1 = get_performance(match_num, gold_num, pred_num)
179 | # fscore = FScore(r, p, f1)
180 | # print('P,R,F: [{0:.2f}, {1:.2f}, {2:.2f}]'.format(p * 100, r * 100, f1 * 100), flush=True)
181 | print(output_filename)
182 |
183 | print('invalid_predicted_tree:',invalid_predicted_tree)
184 | p = r = f1 = 0.0  # defaults in case the conlleval output cannot be parsed
185 | cmdline = ["perl", "conlleval.pl"]
186 | cmd = subprocess.Popen(cmdline, stdin=open(output_filename, 'r', encoding='utf-8'), stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
187 | stdout, stderr = cmd.communicate()
188 | lines = stdout.decode("utf-8")
189 | print(lines)
190 | for line in lines.split('\n'):
191 | if line.startswith('accuracy'):
192 | #print('line:', line)
193 | for item in line.strip().split(';'):
194 | items = item.strip().split(':')
195 | if items[0].startswith('precision'):
196 | p = float(items[1].strip()[:-1]) / 100
197 | elif items[0].startswith('recall'):
198 | r = float(items[1].strip()[:-1]) / 100
199 | elif items[0].startswith('FB1'):
200 | f1 = float(items[1].strip()) / 100
201 |
202 | fscore = FScore(r, p, f1)
203 | print('P,R,F: [{0:.2f}, {1:.2f}, {2:.2f}]'.format(p * 100, r * 100, f1 * 100), flush=True)
204 | return fscore
205 |
206 |
207 |
208 |
209 |
210 |
211 | def eval_chunks2(evalb_dir, gold_chunks_list, pred_chunk_list, output_filename = 'dev.out.txt'):
212 | match_num = 0
213 | gold_num = 0
214 | pred_num = 0
215 |
216 |
217 | fout = open(output_filename, 'w', encoding='utf-8')
218 |
219 | invalid_predicted_tree = 0
220 |
221 | for gold_chunks, predict_chunks in zip(gold_chunks_list, pred_chunk_list):
222 |
223 | # print(colored(gold_tree.linearize(), 'red'))
224 | # print(colored(predicted_tree.linearize(), 'yellow'))
225 |
226 | #gold_chunks = gold_tree.to_chunks()
227 | #predict_chunks = predicted_tree.to_chunks()
228 | input_seq = get_text_from_chunks(gold_chunks)
229 | gold_seq = chunk2seq(gold_chunks)
230 | predict_seq = chunk2seq(predict_chunks)
231 |
232 | if len(gold_seq) != len(predict_seq):
233 | invalid_predicted_tree += 1
234 | # print(colored('Error:', 'red'))
235 | # print(input_seq)
236 | # exit()
237 |
238 | o_list = zip(input_seq, gold_seq, predict_seq)
239 | fout.write('\n'.join(['\t'.join(x) for x in o_list]))
240 | fout.write('\n\n')
241 |
242 |
243 | # print(colored(chunks2str(gold_chunks), 'red'))
244 | # print(colored(chunks2str(predict_chunks), 'yellow'))
245 | # print()
246 |
247 | match_num += count_common_chunks(gold_chunks, predict_chunks)
248 | gold_num += len(gold_chunks)
249 | pred_num += len(predict_chunks)
250 |
251 | fout.close()
252 |
253 | #p, r, f1 = get_performance(match_num, gold_num, pred_num)
254 | # fscore = FScore(r, p, f1)
255 | # print('P,R,F: [{0:.2f}, {1:.2f}, {2:.2f}]'.format(p * 100, r * 100, f1 * 100), flush=True)
256 | print(output_filename)
257 |
258 | print('invalid_predicted_tree:',invalid_predicted_tree)
259 | p = r = f1 = 0.0  # defaults in case the conlleval output cannot be parsed
260 | cmdline = ["perl", "conlleval.pl"]
261 | cmd = subprocess.Popen(cmdline, stdin=open(output_filename, 'r', encoding='utf-8'), stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
262 | stdout, stderr = cmd.communicate()
263 | lines = stdout.decode("utf-8")
264 | print(lines)
265 | for line in lines.split('\n'):
266 | if line.startswith('accuracy'):
267 | #print('line:', line)
268 | for item in line.strip().split(';'):
269 | items = item.strip().split(':')
270 | if items[0].startswith('precision'):
271 | p = float(items[1].strip()[:-1]) / 100
272 | elif items[0].startswith('recall'):
273 | r = float(items[1].strip()[:-1]) / 100
274 | elif items[0].startswith('FB1'):
275 | f1 = float(items[1].strip()) / 100
276 |
277 | fscore = FScore(r, p, f1)
278 | print('P,R,F: [{0:.2f}, {1:.2f}, {2:.2f}]'.format(p * 100, r * 100, f1 * 100), flush=True)
279 | return fscore
280 |
--------------------------------------------------------------------------------
/src/latent.py:
--------------------------------------------------------------------------------
1 | import random
2 | #import vocabulary
3 |
4 |
5 | ASSIST = 'assist'
6 | REDUNDANT = 'redundant'
7 | NO = 'no'
8 |
9 |
10 | class latent_tree_builder:
11 | def __init__(self, label_vocab, RBT_order_before_label, non_terminal_label_mode = 0):
12 | '''
13 | :param label_vocab:
14 | :param RBT_order_before_label:
15 | :param non_terminal_label_mode: 0, based on label order; 1, empty label ; 2, random label
16 | '''
17 | self.label_vocab = label_vocab
18 | if RBT_order_before_label == 'none':
19 | self.RBT_order_before_idx = -1
20 | elif RBT_order_before_label == 'start':
21 | self.RBT_order_before_idx = 23
22 | else:
23 | self.RBT_order_before_idx = self.get_label_order(RBT_order_before_label)
24 | self.non_terminal_label_mode = non_terminal_label_mode
25 | self.label_size = 21  # number of terminal labels in data/labels.txt
26 |
27 |
28 | def get_label_order(self, label_str):
29 | if label_str.endswith("'"):
30 | label_str = label_str[:-1]
31 |
32 | label = (label_str,)
33 | label_order = self.label_vocab.size - self.label_vocab.index(label)
34 | return label_order
35 |
36 |
37 | def non_terminal_label(self, label):
38 | if label[-1] == "'":
39 | return label
40 | else:
41 | return label + "'"
42 |
43 | def terminal_label(self, label):
44 | if label[-1] == "'":
45 | return label[:-1]
46 | else:
47 | return label
48 |
49 |
50 | def get_parent_label(self, child1 : str, child2 : str, order_type = 0):
51 | parent = None
52 |
53 | # if order_type == 1:
54 | # return get_parent_label_reverse(child1, child2)
55 | #
56 | # if order_type == 2:
57 | # return get_parent_label_order(child1, child2)
58 | #
59 | # if order_type == 3:
60 | # return get_parent_label_empty(child1, child2)
61 | #
62 | # if order_type == 4:
63 | # return get_parent_label_semi(child1, child2)
64 |
65 | if child1.startswith(REDUNDANT) or child2.startswith(REDUNDANT):
66 | if child1.startswith(REDUNDANT) and child2.startswith(REDUNDANT):
67 | parent = self.non_terminal_label(REDUNDANT)
68 | elif child1.startswith(REDUNDANT):
69 | parent = self.non_terminal_label(child2)
70 | else:
71 | parent = self.non_terminal_label(child1)
72 |
73 | elif child1.startswith(ASSIST) or child2.startswith(ASSIST):
74 | if child1.startswith(ASSIST) and child2.startswith(ASSIST):
75 | parent = self.non_terminal_label(ASSIST)
76 | elif child1.startswith(ASSIST):
77 | parent = self.non_terminal_label(child2)
78 | else:
79 | parent = self.non_terminal_label(child1)
80 |
81 |
82 | else:
83 |
84 | last1_priority_id = self.get_label_order(self.terminal_label(child1))
85 | last2_priority_id = self.get_label_order(self.terminal_label(child2))
86 |
87 |
88 | if last1_priority_id >= last2_priority_id:
89 | parent = self.non_terminal_label(child1)
90 | else:
91 | parent = self.non_terminal_label(child2)
92 |
93 |
94 |
95 | return parent
96 |
97 |
98 |
99 | def build_latent_tree_str(self, x, chunks):
100 |
101 | RBT_order_before_idx = self.RBT_order_before_idx
102 | pos_POI = -1
103 | for i in range(len(chunks)):
104 | if chunks[i][0] == 'poi':
105 | pos_POI = i
106 |
107 | if pos_POI == -1:
108 | RBT_order_before_idx = 0
109 |
110 |
111 | idx = pos_POI
112 | chunk_aug = [(chunk[0], chunk[1], chunk[2],'(' + chunk[0] + ' ' + ' '.join(['(XX ' + item + ')' for item in x[chunk[1]:chunk[2]]]) + ' )') for chunk in chunks]
113 |
114 |
115 | def create_parent_node(left, right):
116 | parent_label = self.get_parent_label(left[0], right[0])
117 | parent_left_boundary = left[1]
118 | parent_right_boundary = right[2]
119 | parent_str = '(' + parent_label + ' ' + left[3] + ' ' + right[3] + ' )'
120 |
121 | parent = (parent_label, parent_left_boundary, parent_right_boundary, parent_str)
122 | return parent
123 |
124 | #Build Random Tree
125 | while len(chunk_aug) > 1 and idx >= 0:
126 | label_order = self.get_label_order(chunk_aug[idx][0])
127 |
128 | options = []
129 | if label_order < RBT_order_before_idx:
130 | if idx + 1 < len(chunk_aug):
131 | label_order_next = self.get_label_order(chunk_aug[idx + 1][0])
132 | if label_order_next < RBT_order_before_idx:
133 | options.append(idx + 1)
134 |
135 | if idx - 1 >= 0:
136 | label_order_prev = self.get_label_order(chunk_aug[idx - 1][0])
137 | if label_order_prev < RBT_order_before_idx:
138 | options.append(idx - 1)
139 |
140 | if len(options) == 0:
141 | break
142 |
143 |
144 | if len(options) == 1:
145 | option = options[0]
146 | else:
147 | p = random.random()
148 | option = options[0 if p > 0.5 else 1]
149 |
150 |
151 | if option == idx - 1:
152 | left = chunk_aug[idx - 1]
153 | right = chunk_aug[idx]
154 |
155 | parent = create_parent_node(left, right)
156 |
157 | chunk_aug[idx - 1] = parent
158 | chunk_aug.remove(right)
159 |
160 | idx = idx - 1
161 | else:
162 | left = chunk_aug[idx]
163 | right = chunk_aug[idx + 1]
164 |
165 | parent = create_parent_node(left, right)
166 |
167 | chunk_aug[idx] = parent
168 | chunk_aug.remove(right)
169 |
170 | idx = idx  # keep idx pointing at the newly merged node
171 |
172 |
173 | #Build RBT tree for the rest
174 | while len(chunk_aug) > 1:
175 | idx = len(chunk_aug) - 1
176 |
177 | left = chunk_aug[idx - 1]
178 | right = chunk_aug[idx]
179 |
180 | parent = create_parent_node(left, right)
181 |
182 | chunk_aug[idx - 1] = parent
183 | chunk_aug.remove(right)
184 |
185 |
186 | return chunk_aug[0][3]
187 |
188 | def build_latent_tree(self, x, chunks):
189 | import util
190 | tree_str = self.build_latent_tree_str(x, chunks)
191 | tree = util.load_trees_from_str(tree_str, 0)
192 | return tree
193 |
194 | def build_latent_trees(self, insts):
195 | import util
196 | trees_str = ''
197 | for x, chunks in insts:
198 | tree_str = self.build_latent_tree_str(x, chunks)
199 | trees_str += tree_str + '\n'
200 | trees = util.load_trees_from_str(trees_str, 0)
201 | return trees
202 |
203 |
204 |
205 |
206 | def build_dynamicRBT_tree(self, x, chunks):
207 |
208 | from trees import InternalTreebankNode, LeafTreebankNode, InternalUncompletedTreebankNode, InternalParseChunkNode, InternalTreebankChunkNode
209 | import parse
210 |
211 | cut_off_point = -1
212 | for i in reversed(range(len(chunks))):
213 | label_order = self.get_label_order(chunks[i][0])
214 | if label_order >= self.RBT_order_before_idx:
215 | cut_off_point = i
216 | break
217 |
218 |
219 | #Build RBT from [0, i]
220 |
221 | latentscope = (chunks[cut_off_point + 1][1] if cut_off_point + 1 < len(chunks) else len(x), len(x))
222 |
223 | chunks_in_scope = chunks[cut_off_point+1:]
224 | chunks_in_scope = [(label, s, e, x[s:e]) for label, s, e in chunks_in_scope]
225 | if len(chunks_in_scope) > 0:
226 | chunkleaves = []
227 | for label, s, e, text in chunks_in_scope:
228 | leaves = []
229 | for ch in text:
230 | leaf = LeafTreebankNode(parse.XX, ch)
231 | leaves.append(leaf)
232 |
233 | chunk_leaf = InternalTreebankNode(label, leaves) #InternalTreebankChunkNode
234 | chunkleaves.append(chunk_leaf)
235 |
236 | if self.non_terminal_label_mode == 0 or self.non_terminal_label_mode == 3:
237 | label = max([chunk[0] for chunk in chunks_in_scope], key=lambda l: self.get_label_order(l))
238 | label = self.non_terminal_label(label)
239 | elif self.non_terminal_label_mode == 1:
240 | label = parse.EMPTY
241 | else: #self.non_terminal_label_mode == 2:
242 | import random
243 | label_id = random.randint(1 + self.label_size + 0, 1 + self.label_size + self.label_size - 1)  # random non-terminal label id
244 | label = self.label_vocab.value(label_id)
245 |
246 | latent_area = [InternalUncompletedTreebankNode(label, chunkleaves, chunks_in_scope, self)]
247 | else:
248 | latent_area = []
249 |
250 | RBT_chunks = list(chunks[:cut_off_point+1]) + latent_area #(parse.EMPTY, chunks[i+1][1], chunks[-1][2])
251 |
252 | if len(latent_area) == 0:
253 | label, s, e = chunks[-1]
254 | text = x[s:e]
255 | leaves = []
256 | for ch in text:
257 | leaf = LeafTreebankNode(parse.XX, ch)
258 | leaves.append(leaf)
259 |
260 | RBT_chunks[-1] = InternalTreebankNode(label, leaves)
261 |
262 | while len(RBT_chunks) > 1:
263 |
264 | second_last_chunk = RBT_chunks[-2]
265 |
266 | second_last_children = []
267 | for pos in range(second_last_chunk[1], second_last_chunk[2]):
268 | second_last_children.append(LeafTreebankNode(parse.XX, x[pos]))
269 |
270 | second_last_node = InternalTreebankNode(second_last_chunk[0], second_last_children) #InternalTreebankChunkNode
271 |
272 | last_node = RBT_chunks[-1]
273 |
274 | if self.non_terminal_label_mode == 0 or self.non_terminal_label_mode == 3:
275 | parent_label = self.get_parent_label(second_last_chunk[0], last_node.label)
276 | elif self.non_terminal_label_mode == 1:
277 | parent_label = parse.EMPTY
278 | else: #self.non_terminal_label_mode == 2:
279 | import random
280 | label_id = random.randint(1 + self.label_size + 0, 1 + self.label_size + self.label_size - 1)  # random non-terminal label id
281 | parent_label = self.label_vocab.value(label_id)
282 |
283 | parent_node = InternalTreebankNode(parent_label, [second_last_node, last_node])
284 |
285 |
286 |
287 | RBT_chunks[-2] = parent_node
288 | RBT_chunks.remove(last_node)
289 |
290 | tree = RBT_chunks[0]
291 | return x, tree, chunks, latentscope
292 |
293 | def build_dynamicRBT_trees(self, insts):
294 | trees = []
295 | for x, chunks in insts:
296 | x, tree, chunks, latentscope = self.build_dynamicRBT_tree(x, chunks)
297 | trees.append((x, tree, chunks, latentscope))
298 | return trees
299 |
300 | def main_test():
301 | import util
302 | import vocabulary
303 | import parse
304 | label_list = util.load_label_list('data/labels.txt')
305 | label_vocab = vocabulary.Vocabulary()
306 |
307 | label_vocab.index(())
308 |
309 |
310 | for item in label_list:
311 | label_vocab.index((item,))
312 |
313 | for item in label_list:
314 | label_vocab.index((item + "'",))
315 |
316 | label_vocab.index((parse.EMPTY,))
317 |
318 | label_vocab.freeze()
319 |
320 | latent = latent_tree_builder(label_vocab, 'city')
321 |
322 | insts = util.read_chunks('data/trial.txt')
323 |
324 |
325 | # for k in range(3):
326 | # trees = latent.build_latent_trees(insts)
327 | # for tree in trees:
328 | # print(tree.linearize())
329 | # print()
330 |
331 |
332 | trees = latent.build_dynamicRBT_trees(insts)
333 | for x, tree, chunks, latentscope in trees:
334 | print(tree.linearize())
335 | tree = tree.convert()
336 | print()
337 | tree = tree.convert()
338 | print()
339 |
340 |
341 |
342 | #main_test()
--------------------------------------------------------------------------------
/src/latenttrees.py:
--------------------------------------------------------------------------------
1 | import collections.abc
2 | import parse
3 | import util
4 |
5 | class TreebankNode(object):
6 | pass
7 |
8 | class InternalTreebankNode(TreebankNode):
9 | def __init__(self, label, children):
10 | assert isinstance(label, str)
11 | self.label = label
12 |
13 | assert isinstance(children, collections.abc.Sequence)
14 | assert all(isinstance(child, TreebankNode) for child in children)
15 | #assert children
16 | self.children = tuple(children)
17 |
18 | def to_chunks(self):
19 | raw_chunks = []
20 | self.chunk_helper(raw_chunks)
21 |
22 | chunks = []
23 | p = 0
24 | for label, text_list in raw_chunks:
25 | if label.endswith("'"):
26 | label = label[:-1]
27 | chunks.append((label, p, p + len(text_list), text_list))
28 | p += len(text_list)
29 |
30 | return chunks
31 |
32 |
33 | def chunk_helper(self, chunks):
34 |
35 | children_status = [isinstance(child, LeafTreebankNode) for child in self.children]
36 |
37 | if all(children_status):
38 | chunks.append((self.label, [child.word for child in self.children]))
39 | elif any(children_status):
40 | char_list = []
41 |
42 | for child in self.children:
43 | if isinstance(child, InternalTreebankNode):
44 | char_list += child.get_word_list()
45 | else:
46 | char_list.append(child.word)
47 |
48 |
49 | chunk = (self.label, char_list)
50 | chunks.append(chunk)
51 | else:
52 | for child in self.children:
53 | if isinstance(child, InternalTreebankNode):
54 | child.chunk_helper(chunks)
55 |
56 |
57 | def get_word_list(self):
58 | word_list = []
59 | for child in self.children:
60 | if isinstance(child, InternalTreebankNode):
61 | word_list += child.get_word_list()
62 | else:
63 | word_list += [child.word]
64 | return word_list
65 |
66 | def linearize(self):
67 | return "({} {})".format(
68 | self.label, " ".join(child.linearize() for child in self.children))
69 |
70 | def leaves(self):
71 | for child in self.children:
72 | yield from child.leaves()
73 |
74 | def convert(self, index=0):
75 | tree = self
76 | sublabels = [self.label]
77 |
78 | while len(tree.children) == 1 and isinstance(tree.children[0], InternalTreebankNode):
79 | tree = tree.children[0]
80 | sublabels.append(tree.label)
81 |
82 | children = []
83 | for child in tree.children:
84 | children.append(child.convert(index=index))
85 | index = children[-1].right
86 |
87 | return InternalParseNode(tuple(sublabels), children)
88 |
89 |
90 | class LeafTreebankNode(TreebankNode):
91 | def __init__(self, tag, word):
92 | assert isinstance(tag, str)
93 | self.tag = tag
94 |
95 | assert isinstance(word, str)
96 | self.word = word
97 |
98 | def linearize(self):
99 | return "({} {})".format(self.tag, self.word)
100 |
101 | def leaves(self):
102 | yield self
103 |
104 | def convert(self, index=0):
105 | return LeafParseNode(index, self.tag, self.word)
106 |
107 |
108 |
109 | class InternalUncompletedTreebankNode(TreebankNode):
110 | def __init__(self, label, chunkleaves, chunks_in_scope, latent):
111 | assert isinstance(label, str)
112 | self.label = label
113 | self.latent = latent
114 | #assert isinstance(children, collections.abc.Sequence)
115 | #assert all(isinstance(child, TreebankNode) for child in children)
116 | #assert children
117 | self.children = () #tuple(children)
118 | self.chunkleaves = chunkleaves
119 | self.chunks_in_scope = chunks_in_scope
120 |
121 |
122 | def linearize(self):
123 |
124 | #chunks_str_list = [ '(' + label + ' ' + ''.join( '(XX ' + item + ')' for item in text) + ')' for label, s, e, text in self.chunks_in_scope]
125 | return "({} leaves: {})".format(self.label, " ".join(leaf.linearize() for leaf in self.chunkleaves))
126 |
127 | def leaves(self):
128 | return self.chunkleaves
129 |
130 | def convert(self, index=0):
131 | tree = self
132 | sublabels = [self.label]
133 |
134 | # while len(tree.children) == 1 and isinstance(tree.children[0], InternalTreebankNode):
135 | # tree = tree.children[0]
136 | # sublabels.append(tree.label)
137 |
138 | #children = []
139 | # for child in tree.children:
140 | # children.append(child.convert(index=index))
141 | # index = children[-1].right
142 | #index = self.chunks_in_scope[0][1]
143 | chunkleaves = []
144 | for chunkleaf in self.chunkleaves:
145 | chunkleaves.append(chunkleaf.convert(index=index))
146 | index = chunkleaves[-1].right
147 | # for i in range(len(self.chunkleaves)):
148 | # chunk = self.chunks_in_scope[i]
149 | # chunkleaf = self.chunkleaves[i]
150 | # chunkleaf.convert(index=chunk[1])
151 | # chunkleaves.append(chunkleaf)
152 |
153 | return InternalUncompletedParseNode(tuple(sublabels), chunkleaves, self.chunks_in_scope, self.latent)
154 |
155 |
156 |
157 | class ParseNode(object):
158 | pass
159 |
160 | class InternalParseNode(ParseNode):
161 | def __init__(self, label, children):  # call sites in this file pass (label, children); the extra left/right parameters would make every call raise TypeError
162 | assert isinstance(label, tuple)
163 | assert all(isinstance(sublabel, str) for sublabel in label)
164 | assert label
165 | self.label = label
166 |
167 | assert isinstance(children, collections.abc.Sequence)
168 | assert all(isinstance(child, ParseNode) for child in children)
169 | assert children
170 | assert len(children) > 1 or isinstance(children[0], LeafParseNode)
171 | assert all(
172 | left.right == right.left
173 | for left, right in zip(children, children[1:]))
174 | self.children = tuple(children)
175 |
176 | self.left = children[0].left
177 | self.right = children[-1].right
178 |
179 | def leaves(self):
180 | for child in self.children:
181 | yield from child.leaves()
182 |
183 | def convert(self):
184 | children = [child.convert() for child in self.children]
185 | tree = InternalTreebankNode(self.label[-1], children)
186 | for sublabel in reversed(self.label[:-1]):
187 | tree = InternalTreebankNode(sublabel, [tree])
188 | return tree
189 |
190 | def enclosing(self, left, right):
191 | assert self.left <= left < right <= self.right
192 | for child in self.children:
193 | if isinstance(child, LeafParseNode):
194 | continue
195 | if child.left <= left < right <= child.right:
196 | return child.enclosing(left, right)
197 | return self
198 |
199 | def oracle_label(self, left, right):
200 | enclosing = self.enclosing(left, right)
201 | if enclosing.left == left and enclosing.right == right:
202 | return enclosing.label
203 | return ()
204 |
205 | def oracle_splits(self, left, right):
206 | # return [
207 | # child.left
208 | # for child in self.enclosing(left, right).children
209 | # if left < child.left < right
210 | # ]
211 | enclosing = self.enclosing(left, right)
212 | return [
213 | child.left
214 | for child in enclosing.children
215 | if left < child.left < right
216 | ]
217 | # if isinstance(enclosing, InternalUncompletedParseNode):
218 | # return [
219 | # child.left
220 | # for child in enclosing.chunkleaves
221 | # if left < child.left < right
222 | # ]
223 | # else:
224 | # return [
225 | # child.left
226 | # for child in enclosing.children
227 | # if left < child.left < right
228 | # ]
229 |
230 |
231 |
232 | class InternalUncompletedParseNode(InternalParseNode):
233 | def __init__(self, label, chunkleaves, chunks_in_scope: list, latent):
234 | assert isinstance(label, tuple)
235 | assert all(isinstance(sublabel, str) for sublabel in label)
236 | assert label
237 | self.label = label
238 | self.latent = latent
239 |
240 | #assert isinstance(children, collections.abc.Sequence)
241 | #assert all(isinstance(child, ParseNode) for child in children)
242 | #assert children
243 | #assert len(children) > 1 or isinstance(children[0], LeafParseNode)
244 | # assert all(
245 | # left.right == right.left
246 | # for left, right in zip(children, children[1:]))
247 | self.children = [] #tuple(children)
248 | self.chunkleaves = chunkleaves
249 |
250 | #self.left = children[0].left
251 | #self.right = children[-1].right
252 | self.left = chunkleaves[0].left #chunks_in_scope[0][1]
253 | self.right = chunkleaves[-1].right #chunks_in_scope[-1][2]
254 | self.chunks_in_scope = chunks_in_scope
255 | self.splits = [chunk[1] for chunk in self.chunks_in_scope] + [self.chunks_in_scope[-1][2]]
256 |
257 | def leaves(self):
258 | return self.chunkleaves
259 |
260 | def convert(self):
261 | # children = [child.convert() for child in self.children]
262 | # tree = InternalUncompletedTreebankNode(self.label[-1], children)
263 | # for sublabel in reversed(self.label[:-1]):
264 | # tree = InternalUncompletedTreebankNode(sublabel, [tree])
265 |
266 | chunkleaves = [chunkleaf.convert() for chunkleaf in self.chunkleaves]
267 | tree = InternalUncompletedTreebankNode(self.label[-1], chunkleaves, (self.left, self.right), self.latent)
268 | return tree
269 |
270 | def enclosing(self, left, right):
271 | assert self.left <= left < right <= self.right
272 | for chunkleaf in self.chunkleaves:
273 | if isinstance(chunkleaf, LeafParseNode):
274 | continue
275 | if chunkleaf.left <= left < right <= chunkleaf.right:
276 | return chunkleaf.enclosing(left, right)
277 |
278 | # if left in self.splits and right in self.splits:
279 | # children = [chunkleaf for chunkleaf in self.chunkleaves if left <= chunkleaf.left and chunkleaf.right <= right]
280 | # # label = max([child.label for child in children],key=lambda l:self.latent.get_label_order(l[0]))
281 | # # label = (self.latent.non_terminal_label(label[0]),)
282 | # label = self.label
283 | # return InternalParseNode(label, children)
284 |
285 | children = [chunkleaf for chunkleaf in self.chunkleaves if left < chunkleaf.right]
286 | children = [chunkleaf for chunkleaf in children if right > chunkleaf.left]
287 |
288 | if self.latent.non_terminal_label_mode == 0:
289 | label = max([child.label for child in children], key=lambda l: self.latent.get_label_order(l[0]))
290 | label = (self.latent.non_terminal_label(label[0]),)
291 | elif self.latent.non_terminal_label_mode == 1:
292 | label = self.label
293 | else: #self.latent.non_terminal_label_mode == 2:
294 | import random
295 | label_id = random.randint(self.latent.label_size + 1, 2 * self.latent.label_size)  # a random id from the primed (latent) label range
296 | label = self.latent.label_vocab.value(label_id)
297 | return InternalParseNode(label, children)
298 |
299 | # self.children = self.chunkleaves
300 | # return self
301 |
302 | def oracle_label(self, left, right):
303 | # enclosing = self.enclosing(left, right)
304 | # if enclosing.left == left and enclosing.right == right:
305 | # return enclosing.label
306 |
307 | for chunk in self.chunks_in_scope:
308 | if chunk[1] == left and chunk[2] == right:
309 | return (chunk[0],)
310 |
311 |
312 | if left in self.splits and right in self.splits:
313 | return self.label
314 |
315 | return ()
316 |
317 | def oracle_splits(self, left, right):
318 |
319 |
320 | ret = [p for p in self.splits if left < p and p < right]
321 | if len(ret) == 0:
322 | return [
323 | child.left
324 | for child in self.enclosing(left, right).children
325 | if left < child.left < right
326 | ]
327 |
328 |
329 | return ret
330 |
331 |
332 |
333 | class LeafParseNode(ParseNode):
334 | def __init__(self, index, tag, word):
335 | assert isinstance(index, int)
336 | assert index >= 0
337 | self.left = index
338 | self.right = index + 1
339 |
340 | assert isinstance(tag, str)
341 | self.tag = tag
342 |
343 | assert isinstance(word, str)
344 | self.word = word
345 |
346 | def leaves(self):
347 | yield self
348 |
349 | def convert(self):
350 | return LeafTreebankNode(self.tag, self.word)
351 |
352 |
353 | def load_trees(path, normal, strip_top=True):
354 | with open(path, 'r', encoding='utf-8') as infile:
355 | tokens = infile.read().replace("(", " ( ").replace(")", " ) ").split()
356 |
357 | def helper(index):
358 | trees = []
359 |
360 | while index < len(tokens) and tokens[index] == "(":
361 | paren_count = 0
362 | while tokens[index] == "(":
363 | index += 1
364 | paren_count += 1
365 |
366 | label = tokens[index]
367 | index += 1
368 |
369 | if tokens[index] == "(":
370 | children, index = helper(index)
371 | trees.append(InternalTreebankNode(label, children))
372 | else:
373 | word = tokens[index]
374 | if normal == 1:
375 | newword = ''
376 | for c in word:
377 | if util.is_digit(c):
378 | newword += '0'
379 | else:
380 | newword += c
381 | else:
382 | newword = word
383 | index += 1
384 | trees.append(LeafTreebankNode(label, newword))
385 |
386 | while paren_count > 0:
387 | assert tokens[index] == ")"
388 | index += 1
389 | paren_count -= 1
390 |
391 | return trees, index
392 |
393 | trees, index = helper(0)
394 | assert index == len(tokens)
395 |
396 | if strip_top:
397 | for i, tree in enumerate(trees):
398 | if tree.label == "TOP":
399 | assert len(tree.children) == 1
400 | trees[i] = tree.children[0]
401 |
402 | return trees
403 |
--------------------------------------------------------------------------------
/src/main_dyRBT.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import itertools
3 | import os.path
4 | import time
5 |
6 | seed = 3986067777
7 | import dynet_config
8 | dynet_config.set(random_seed=seed)
9 | import dynet as dy
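# Note: dynet_config.set() only takes effect when it runs before
# `import dynet`, which is why the random seed is fixed above the import.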
10 |
11 | import numpy as np
12 |
13 | import evaluate
14 | import parse
15 | import trees
16 | import vocabulary
17 |
18 | import latent
19 | import util
20 |
21 |
22 | def format_elapsed(start_time):
23 | elapsed_time = int(time.time() - start_time)
24 | minutes, seconds = divmod(elapsed_time, 60)
25 | hours, minutes = divmod(minutes, 60)
26 | days, hours = divmod(hours, 24)
27 | elapsed_string = "{}h{:02}m{:02}s".format(hours, minutes, seconds)
28 | if days > 0:
29 | elapsed_string = "{}d{}".format(days, elapsed_string)
30 | return elapsed_string
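# Example: 90061 elapsed seconds (1 day, 1 hour, 1 minute, 1 second)
# formats as "1d1h01m01s".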
31 |
32 |
33 |
34 |
35 | def run_train(args):
36 |
37 | if args.numpy_seed is None:  # respect an explicit --numpy-seed; otherwise fall back to the fixed seed above
38 | args.numpy_seed = seed
39 | print("Setting numpy random seed to {}...".format(args.numpy_seed))
40 | np.random.seed(args.numpy_seed)
41 |
42 |
43 | if args.trial == 1:
44 | args.train_path = 'data/trial.txt'
45 | args.dev_path = 'data/trial.txt'
46 | args.test_path = 'data/trial.txt'
47 |
48 | # args.train_path = args.train_path.replace('[*]', args.treetype)
49 | # args.dev_path = args.dev_path.replace('[*]', args.treetype)
50 | # args.test_path = args.test_path.replace('[*]', args.treetype)
51 |
52 | print("Loading training trees from {}...".format(args.train_path))
53 | train_chunk_insts = util.read_chunks(args.train_path, args.normal)
54 | print("Loaded {:,} training examples.".format(len(train_chunk_insts)))
55 |
56 | print("Loading development trees from {}...".format(args.dev_path))
57 | dev_chunk_insts = util.read_chunks(args.dev_path, args.normal)
58 | print("Loaded {:,} development examples.".format(len(dev_chunk_insts)))
59 |
60 | print("Loading test trees from {}...".format(args.test_path))
61 | test_chunk_insts = util.read_chunks(args.test_path, args.normal)
62 | print("Loaded {:,} test examples.".format(len(test_chunk_insts)))
63 |
64 | # print("Processing trees for training...")
65 | # train_parse = [tree.convert() for tree in train_treebank]
66 |
67 | print("Constructing vocabularies...")
68 |
69 | tag_vocab = vocabulary.Vocabulary()
70 | tag_vocab.index(parse.START)
71 | tag_vocab.index(parse.STOP)
72 | tag_vocab.index(parse.XX)
73 |
74 |
75 | word_vocab = vocabulary.Vocabulary()
76 | word_vocab.index(parse.START)
77 | word_vocab.index(parse.STOP)
78 | word_vocab.index(parse.UNK)
79 | word_vocab.index(parse.NUM)
80 |
81 | for x, chunks in train_chunk_insts + dev_chunk_insts + test_chunk_insts:
82 | for ch in x:
83 | word_vocab.index(ch)
84 |
85 | label_vocab = vocabulary.Vocabulary()
86 | label_vocab.index(())
87 |
88 |
89 | label_list = util.load_label_list(args.labellist_path) #'data/labels.txt')
90 | for item in label_list:
91 | label_vocab.index((item, ))
92 |
93 | if args.nontlabelstyle != 1:
94 | for item in label_list:
95 | label_vocab.index((item + "'",))
96 |
97 | if args.nontlabelstyle == 1:
98 | label_vocab.index((parse.EMPTY,))
99 |
100 | tag_vocab.freeze()
101 | word_vocab.freeze()
102 | label_vocab.freeze()
103 |
104 | latent_tree = latent.latent_tree_builder(label_vocab, args.RBTlabel, args.nontlabelstyle)
105 |
106 | def print_vocabulary(name, vocab):
107 | special = {parse.START, parse.STOP, parse.UNK}
108 | print("{} ({:,}): {}".format(
109 | name, vocab.size,
110 | sorted(value for value in vocab.values if value in special) +
111 | sorted(value for value in vocab.values if value not in special)))
112 |
113 | if args.print_vocabs:
114 | print_vocabulary("Tag", tag_vocab)
115 | print_vocabulary("Word", word_vocab)
116 | print_vocabulary("Label", label_vocab)
117 |
118 | print("Initializing model...")
119 |
120 | pretrain = {'giga':'data/giga.vec100', 'none':'none'}
121 | pretrainemb = util.load_pretrain(pretrain[args.pretrainemb], args.word_embedding_dim, word_vocab)
122 |
123 | model = dy.ParameterCollection()
124 | if args.parser_type == "chartdyRBTC":
125 | parser = parse.ChartDynamicRBTConstraintParser(
126 | model,
127 | tag_vocab,
128 | word_vocab,
129 | label_vocab,
130 | args.tag_embedding_dim,
131 | args.word_embedding_dim,
132 | args.lstm_layers,
133 | args.lstm_dim,
134 | args.label_hidden_dim,
135 | args.dropout,
136 | (args.pretrainemb, pretrainemb),
137 | args.chunkencoding,
138 | args.trainc == 1,
139 | True,
140 | (args.zerocostchunk == 1),
141 | )
142 |
143 |
144 | else:
145 | print('Model is not valid!')
146 | exit()
147 |
148 | if args.loadmodel != 'none':
149 | tmp = dy.load(args.loadmodel, model)
150 | parser = tmp[0]
151 | print('Model is loaded from ', args.loadmodel)
152 |
153 | trainer = dy.AdamTrainer(model)
154 |
155 | total_processed = 0
156 | current_processed = 0
157 | check_every = len(train_chunk_insts) / args.checks_per_epoch
158 | best_dev_fscore = -np.inf
159 | best_dev_model_path = None
160 |
161 | start_time = time.time()
162 |
163 | def check_dev():
164 | nonlocal best_dev_fscore
165 | nonlocal best_dev_model_path
166 |
167 | dev_start_time = time.time()
168 |
169 | dev_predicted = []
170 | #dev_gold = []
171 |
172 | #dev_gold = latent_tree.build_latent_trees(dev_chunk_insts)
173 | dev_gold = []
174 | for inst in dev_chunk_insts:
175 | chunks = util.inst2chunks(inst)
176 | dev_gold.append(chunks)
177 |
178 | for x, chunks in dev_chunk_insts:
179 | dy.renew_cg()
180 | #sentence = [(leaf.tag, leaf.word) for leaf in tree.leaves()]
181 | sentence = [(parse.XX, ch) for ch in x]
182 | predicted, _ = parser.parse(sentence)
183 | dev_predicted.append(predicted.convert().to_chunks())
184 |
185 |
186 | #dev_fscore = evaluate.evalb(args.evalb_dir, dev_gold, dev_predicted, args.expname + '.dev.') #evalb
187 | dev_fscore = evaluate.eval_chunks2(args.evalb_dir, dev_gold, dev_predicted, output_filename=args.expname + '.dev.txt') # evalb
188 |
189 |
190 | print(
191 | "dev-fscore {} "
192 | "dev-elapsed {} "
193 | "total-elapsed {}".format(
194 | dev_fscore,
195 | format_elapsed(dev_start_time),
196 | format_elapsed(start_time),
197 | )
198 | )
199 |
200 |
201 | if dev_fscore.fscore > best_dev_fscore:
202 | if best_dev_model_path is not None:
203 | for ext in [".data", ".meta"]:
204 | path = best_dev_model_path + ext
205 | if os.path.exists(path):
206 | print("Removing previous model file {}...".format(path))
207 | os.remove(path)
208 |
209 | best_dev_fscore = dev_fscore.fscore
210 | best_dev_model_path = "{}_dev={:.2f}".format(args.model_path_base + "_" + args.expname, dev_fscore.fscore)
211 | print("Saving new best model to {}...".format(best_dev_model_path))
212 | dy.save(best_dev_model_path, [parser])
213 |
214 | test_start_time = time.time()
215 | test_predicted = []
216 | #test_gold = latent_tree.build_latent_trees(test_chunk_insts)
217 | test_gold = []
218 | for inst in test_chunk_insts:
219 | chunks = util.inst2chunks(inst)
220 | test_gold.append(chunks)
221 |
222 | ftreelog = open(args.expname + '.test.predtree.txt', 'w', encoding='utf-8')
223 |
224 | for x, chunks in test_chunk_insts:
225 | dy.renew_cg()
226 | #sentence = [(leaf.tag, leaf.word) for leaf in tree.leaves()]
227 | sentence = [(parse.XX, ch) for ch in x]
228 | predicted, _ = parser.parse(sentence)
229 | pred_tree = predicted.convert()
230 | ftreelog.write(pred_tree.linearize() + '\n')
231 | test_predicted.append(pred_tree.to_chunks())
232 |
233 |
234 |
235 | ftreelog.close()
236 |
237 | #test_fscore = evaluate.evalb(args.evalb_dir, test_chunk_insts, test_predicted, args.expname + '.test.')
238 | test_fscore = evaluate.eval_chunks2(args.evalb_dir, test_gold, test_predicted, output_filename=args.expname + '.test.txt') # evalb
239 |
240 | print(
241 | "epoch {:,} "
242 | "test-fscore {} "
243 | "test-elapsed {} "
244 | "total-elapsed {}".format(
245 | epoch,
246 | test_fscore,
247 | format_elapsed(test_start_time),
248 | format_elapsed(start_time),
249 | )
250 | )
251 |
252 |
253 | train_trees = latent_tree.build_dynamicRBT_trees(train_chunk_insts)
254 | train_trees = [(x, tree.convert(), chunks, latentscope) for x, tree, chunks, latentscope in train_trees]
255 |
256 | for epoch in itertools.count(start=1):
257 | if args.epochs is not None and epoch > args.epochs:
258 | break
259 |
260 | np.random.shuffle(train_trees)  # shuffle the list that is actually batched below; shuffling train_chunk_insts left the training order fixed
261 | epoch_start_time = time.time()
262 |
263 | for start_index in range(0, len(train_chunk_insts), args.batch_size):
264 | dy.renew_cg()
265 | batch_losses = []
266 |
267 |
268 | for x, tree, chunks, latentscope in train_trees[start_index:start_index + args.batch_size]:
269 |
270 | discard = False
271 | for chunk in chunks:
272 | length = chunk[2] - chunk[1]
273 | if length > args.maxllimit:
274 | discard = True
275 | break
276 |
277 | if discard:
278 | print('discard')
279 | continue
280 |
281 |
282 | sentence = [(parse.XX, ch) for ch in x]
283 | if args.parser_type == "top-down":
284 | _, loss = parser.parse(sentence, tree, args.explore)
285 | else:
286 | _, loss = parser.parse(sentence, tree, chunks, latentscope)
287 | batch_losses.append(loss)
288 | total_processed += 1
289 | current_processed += 1
290 |
291 | if not batch_losses: continue  # every instance in this batch exceeded maxllimit and was discarded
292 | batch_loss = dy.average(batch_losses)
293 | batch_loss_value = batch_loss.scalar_value()
294 | batch_loss.backward()
295 | trainer.update()
296 |
297 | print(
298 | "Epoch {:,} "
299 | "batch {:,}/{:,} "
300 | "processed {:,} "
301 | "batch-loss {:.4f} "
302 | "epoch-elapsed {} "
303 | "total-elapsed {}".format(
304 | epoch,
305 | start_index // args.batch_size + 1,
306 | int(np.ceil(len(train_chunk_insts) / args.batch_size)),
307 | total_processed,
308 | batch_loss_value,
309 | format_elapsed(epoch_start_time),
310 | format_elapsed(start_time),
311 | ), flush=True
312 | )
313 |
314 | if current_processed >= check_every:
315 | current_processed -= check_every
316 | if epoch > 7:  # skip the (slow) dev/test evaluation during the first few epochs
317 | check_dev()
318 |
319 |
320 | def run_test(args):
321 | #args.test_path = args.test_path.replace('[*]', args.treetype)
322 | print("Loading test trees from {}...".format(args.test_path))
323 | test_treebank = util.read_chunks(args.test_path, args.normal)  # read chunk instances as in run_train; trees.load_trees returns tree objects that cannot be unpacked as (x, chunks) below
324 | print("Loaded {:,} test examples.".format(len(test_treebank)))
325 |
326 | print("Loading model from {}...".format(args.model_path_base))
327 | model = dy.ParameterCollection()
328 | [parser] = dy.load(args.model_path_base, model)
329 |
330 | label_vocab = vocabulary.Vocabulary()
331 |
332 | label_list = util.load_label_list('data/labels.txt')  # same path convention as run_train, assuming the script is launched from the repo root
333 | for item in label_list:
334 | label_vocab.index((item, ))
335 | label_vocab.index((parse.EMPTY,))
336 | for item in label_list:
337 | label_vocab.index((item + "'",))
338 |
339 | label_vocab.freeze()
340 | latent_tree = latent.latent_tree_builder(label_vocab, args.RBTlabel, args.nontlabelstyle)  # match the three-argument call used in run_train
341 |
342 |
343 | print("Parsing test sentences...")
344 |
345 | start_time = time.time()
346 |
347 | test_predicted = []
348 | test_gold = latent_tree.build_latent_trees(test_treebank)
349 | for x, chunks in test_treebank:
350 | dy.renew_cg()
351 | #sentence = [(leaf.tag, leaf.word) for leaf in tree.leaves()]
352 | sentence = [(parse.XX, ch) for ch in x]
353 | predicted, _ = parser.parse(sentence)
354 | test_predicted.append(predicted.convert())
355 |
356 | #test_fscore = evaluate.evalb(args.evalb_dir, test_treebank, test_predicted, args.expname + '.test.')
357 | test_fscore = evaluate.eval_chunks(args.evalb_dir, test_gold, test_predicted, output_filename=args.expname + '.finaltest.txt') # evalb
358 | print(
359 | "test-fscore {} "
360 | "test-elapsed {}".format(
361 | test_fscore,
362 | format_elapsed(start_time),
363 | )
364 | )
365 |
366 | def main():
367 | dynet_args = [
368 | "--dynet-mem",
369 | "--dynet-weight-decay",
370 | "--dynet-autobatch",
371 | "--dynet-gpus",
372 | "--dynet-gpu",
373 | "--dynet-devices",
374 | "--dynet-seed",
375 | ]
376 |
377 | parser = argparse.ArgumentParser()
378 | subparsers = parser.add_subparsers()
379 |
380 | subparser = subparsers.add_parser("train")
381 | subparser.set_defaults(callback=run_train)
382 | for arg in dynet_args:
383 | subparser.add_argument(arg)
384 | subparser.add_argument("--numpy-seed", type=int)
385 | subparser.add_argument("--parser-type", choices=["top-down", "chartdyRBT", "chartdyRBTchunk", "chartdyRBTC", "chartdyRBTCseg"], required=True)
386 | subparser.add_argument("--tag-embedding-dim", type=int, default=50)
387 | subparser.add_argument("--word-embedding-dim", type=int, default=100)
388 | subparser.add_argument("--lstm-layers", type=int, default=2)
389 | subparser.add_argument("--lstm-dim", type=int, default=250)
390 | subparser.add_argument("--label-hidden-dim", type=int, default=250)
391 | subparser.add_argument("--split-hidden-dim", type=int, default=250)
392 | subparser.add_argument("--dropout", type=float, default=0.4)
393 | subparser.add_argument("--explore", action="store_true")
394 | subparser.add_argument("--model-path-base", required=True)
395 | subparser.add_argument("--evalb-dir", default="EVALB/")
396 | subparser.add_argument("--train-path", default="data/train.txt")
397 | subparser.add_argument("--dev-path", default="data/dev.txt")
398 | subparser.add_argument("--labellist-path", default="data/labels.txt")
399 | subparser.add_argument("--batch-size", type=int, default=10)
400 | subparser.add_argument("--epochs", type=int)
401 | subparser.add_argument("--checks-per-epoch", type=int, default=4)
402 | subparser.add_argument("--print-vocabs", action="store_true")
403 | subparser.add_argument("--test-path", default="data/test.txt")
404 | subparser.add_argument("--pretrainemb", default="giga")
405 | subparser.add_argument("--treetype", default="NRBT")
406 | subparser.add_argument("--expname", default="default")
407 | subparser.add_argument("--chunkencoding", type=int, default=1)
408 | subparser.add_argument("--trial", type=int, default=0)
409 | subparser.add_argument("--normal", type=int, default=1)
410 | subparser.add_argument("--RBTlabel", type=str, default="city")
411 | subparser.add_argument("--nontlabelstyle", type=int, default=0)
412 | subparser.add_argument("--zerocostchunk", type=int, default=0)
413 | subparser.add_argument("--loadmodel", type=str, default="none")
414 | subparser.add_argument("--trainc", type=int, default=1)
415 | subparser.add_argument("--maxllimit", type=int, default=38)
416 |
417 |
418 | subparser = subparsers.add_parser("test")
419 | subparser.set_defaults(callback=run_test)
420 | for arg in dynet_args:
421 | subparser.add_argument(arg)
422 | subparser.add_argument("--model-path-base", required=True)
423 | subparser.add_argument("--evalb-dir", default="EVALB/")
424 | subparser.add_argument("--treetype", default="NRBT")
425 | subparser.add_argument("--test-path", default="data/test.txt")
426 | subparser.add_argument("--expname", default="default")
427 | subparser.add_argument("--normal", type=int, default=1)
428 |
429 |
430 | args = parser.parse_args()
431 | args.callback(args)
432 |
433 | if __name__ == "__main__":
434 | main()
435 |
--------------------------------------------------------------------------------
/src/parse.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import dynet as dy
4 | import numpy as np
5 |
6 | import trees
7 | import util
8 |
9 |
10 | START = "<START>"  # sentinel tokens; they must be distinct strings or the special vocabulary entries collide
11 | STOP = "<STOP>"
12 | UNK = "<UNK>"
13 | NUM = "<NUM>"
14 | XX = "XX"
15 | EMPTY = "<EMPTY>"
16 |
17 |
18 | def augment(scores, oracle_index):
19 | assert isinstance(scores, dy.Expression)
20 | shape = scores.dim()[0]
21 | assert len(shape) == 1
22 | increment = np.ones(shape)
23 | increment[oracle_index] = 0
24 | return scores + dy.inputVector(increment)
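# A hedged reading of augment(): this is the standard loss-augmented
# decoding trick for max-margin training. Every label except the oracle
# gets +1 added to its score, so the decoder must beat the gold label by
# a margin of 1 before the hinge loss (score - oracle_score in parse())
# reaches zero. E.g. scores [2.0, 1.5] with oracle_index 0 augment to
# [2.0, 2.5].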
25 |
26 |
27 |
28 | class Feedforward(object):
29 | def __init__(self, model, input_dim, hidden_dims, output_dim):
30 | self.spec = locals()
31 | self.spec.pop("self")
32 | self.spec.pop("model")
33 |
34 | self.model = model.add_subcollection("Feedforward")
35 |
36 | self.weights = []
37 | self.biases = []
38 | dims = [input_dim] + hidden_dims + [output_dim]
39 | for prev_dim, next_dim in zip(dims, dims[1:]):
40 | self.weights.append(self.model.add_parameters((next_dim, prev_dim)))
41 | self.biases.append(self.model.add_parameters(next_dim))
42 |
43 | def param_collection(self):
44 | return self.model
45 |
46 | @classmethod
47 | def from_spec(cls, spec, model):
48 | return cls(model, **spec)
49 |
50 | def __call__(self, x):
51 | for i, (weight, bias) in enumerate(zip(self.weights, self.biases)):
52 | weight = dy.parameter(weight)
53 | bias = dy.parameter(bias)
54 | x = dy.affine_transform([bias, weight, x])
55 | if i < len(self.weights) - 1:
56 | x = dy.rectify(x)
57 | return x
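# Usage sketch (hypothetical dimensions): Feedforward(model, 100, [50], 10)
# maps a 100-d dy.Expression to 10 scores through one 50-d hidden layer,
# applying a ReLU between layers but none after the output layer.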
58 |
59 |
60 | class ChartDynamicRBTConstraintParser(object):
61 | def __init__(
62 | self,
63 | model,
64 | tag_vocab,
65 | word_vocab,
66 | label_vocab,
67 | tag_embedding_dim,
68 | word_embedding_dim,
69 | lstm_layers,
70 | lstm_dim,
71 | label_hidden_dim,
72 | dropout,
73 | pretrainemb,
74 | chunk_encoding = 1,
75 | train_constraint = True,
76 | decode_constraint = True,
77 | zerocostchunk = 0,
78 | nontlabelstyle = 0,
79 | ):
80 | self.spec = locals()
81 | self.spec.pop("self")
82 | self.spec.pop("model")
83 |
84 | self.model = model.add_subcollection("Parser")
85 | self.tag_vocab = tag_vocab
86 | self.word_vocab = word_vocab
87 | self.label_vocab = label_vocab
88 | self.lstm_dim = lstm_dim
89 |
90 | self.tag_embeddings = self.model.add_lookup_parameters(
91 | (tag_vocab.size, tag_embedding_dim))
92 | self.word_embeddings = self.model.add_lookup_parameters(
93 | (word_vocab.size, word_embedding_dim))
94 |
95 | if pretrainemb[0] != 'none':
96 | print('Init Lookup table with pretrain emb')
97 | self.word_embeddings.init_from_array(pretrainemb[1])
98 |
99 | self.lstm = dy.BiRNNBuilder(
100 | lstm_layers,
101 | tag_embedding_dim + word_embedding_dim,
102 | 2 * lstm_dim,
103 | self.model,
104 | dy.VanillaLSTMBuilder)
105 |
106 | self.CompFwdRNN = dy.LSTMBuilder(lstm_layers, 2 * lstm_dim, 2 * lstm_dim, model)
107 | self.CompBwdRNN = dy.LSTMBuilder(lstm_layers, 2 * lstm_dim, 2 * lstm_dim, model)
108 |
109 | self.pW_comp = model.add_parameters((lstm_dim * 2, lstm_dim * 4))
110 | self.pb_comp = model.add_parameters((lstm_dim * 2,))
111 |
112 |
113 | self.f_label = Feedforward(
114 | self.model, 2 * lstm_dim, [label_hidden_dim], label_vocab.size - 1)
115 |
116 | self.dropout = dropout
117 |
118 | self.chunk_encoding = chunk_encoding
119 |
120 | self.train_constraint = train_constraint
121 | self.decode_constraint = decode_constraint
122 | self.zerocostchunk = zerocostchunk
123 | self.nontlabelstyle = nontlabelstyle
124 |
125 | def param_collection(self):
126 | return self.model
127 |
128 | @classmethod
129 | def from_spec(cls, spec, model):
130 | return cls(model, **spec)
131 |
132 |
133 |
134 |
135 | def parse(self, sentence, gold=None, gold_chunks=None, latentscope=None):
136 | is_train = gold is not None
137 |
138 | if is_train:
139 | self.lstm.set_dropout(self.dropout)
140 | else:
141 | self.lstm.disable_dropout()
142 |
143 | embeddings = []
144 | for tag, word in [(START, START)] + sentence + [(STOP, STOP)]:
145 | tag_embedding = self.tag_embeddings[self.tag_vocab.index(tag)]
146 | if word not in (START, STOP):
147 | count = self.word_vocab.count(word)
148 | if not count or (is_train and np.random.rand() < 1 / (1 + count)):
149 | word = UNK
150 | word_embedding = self.word_embeddings[self.word_vocab.index(word)]
151 | embeddings.append(dy.concatenate([tag_embedding, word_embedding]))
152 |
153 | lstm_outputs = self.lstm.transduce(embeddings)
154 |
155 | W_comp = dy.parameter(self.pW_comp)
156 | b_comp = dy.parameter(self.pb_comp)
157 |
158 | @functools.lru_cache(maxsize=None)
159 | def get_span_encoding(left, right):
160 | if self.chunk_encoding == 2:
161 | return get_span_encoding_chunk(left, right)
162 | forward = (
163 | lstm_outputs[right][:self.lstm_dim] -
164 | lstm_outputs[left][:self.lstm_dim])
165 | backward = (
166 | lstm_outputs[left + 1][self.lstm_dim:] -
167 | lstm_outputs[right + 1][self.lstm_dim:])
168 |
169 |
170 | bi = dy.concatenate([forward, backward])
171 |
172 | return bi
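# This is the span-difference encoding of the minimal span parser this
# repository builds on: a span (left, right] is represented by the
# difference of forward LSTM states at its endpoints, concatenated with
# the difference of the backward states, giving one 2 * lstm_dim vector
# per span in O(1) after the single LSTM pass over the sentence.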
173 |
174 | @functools.lru_cache(maxsize=None)
175 | def get_span_encoding_chunk(left, right):
176 |
177 | fw_init = self.CompFwdRNN.initial_state()
178 | bw_init = self.CompBwdRNN.initial_state()
179 |
180 | fwd_exp = fw_init.transduce(lstm_outputs[left:right])
181 | bwd_exp = bw_init.transduce(reversed(lstm_outputs[left:right]))
182 |
183 | bi = dy.concatenate([fwd_exp[-1], bwd_exp[-1]])
184 | chunk_rep = dy.rectify(dy.affine_transform([b_comp, W_comp, bi]))
185 |
186 | return chunk_rep
187 |
188 | @functools.lru_cache(maxsize=None)
189 | def get_label_scores(left, right):
190 | non_empty_label_scores = self.f_label(get_span_encoding(left, right))
191 | return dy.concatenate([dy.zeros(1), non_empty_label_scores])
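# The empty label () is pinned to a fixed score of 0 by prepending
# dy.zeros(1); f_label therefore only outputs label_vocab.size - 1 scores.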
192 |
193 |
194 |
195 | def helper(force_gold):
196 | if force_gold:
197 | assert is_train
198 |
199 | chart = {}
200 | label_scores_span_max = {}
201 |
202 | for length in range(1, len(sentence) + 1):
203 | for left in range(0, len(sentence) + 1 - length):
204 | right = left + length
205 |
206 | label_scores = get_label_scores(left, right)
207 |
208 | if is_train:
209 | oracle_label = gold.oracle_label(left, right)
210 | oracle_label_index = self.label_vocab.index(oracle_label)
211 |
212 |
213 |
214 | if force_gold:
215 | label = oracle_label
216 | label_index = oracle_label_index
217 | label_score = label_scores[label_index]
218 |
219 | if self.nontlabelstyle == 3:
220 | label_scores_np = label_scores.npvalue()
221 | argmax_label_index = int(
222 | label_scores_np.argmax() if length < len(sentence) else
223 | label_scores_np[1:].argmax() + 1)
224 | argmax_label = self.label_vocab.value(argmax_label_index)
225 | label = argmax_label
226 | label_score = label_scores[argmax_label_index]
227 |
228 | else:
229 | if is_train:
230 | label_scores = augment(label_scores, oracle_label_index)
231 | label_scores_np = label_scores.npvalue()
232 | #argmax_score = dy.argmax(label_scores, gradient_mode="straight_through_gradient")
233 | #dy.dot_product()
234 | argmax_label_index = int(
235 | label_scores_np.argmax() if length < len(sentence) else
236 | label_scores_np[1:].argmax() + 1)
237 | argmax_label = self.label_vocab.value(argmax_label_index)
238 | label = argmax_label
239 | label_score = label_scores[argmax_label_index]
240 |
241 | if length == 1:
242 | tag, word = sentence[left]
243 | tree = trees.LeafParseNode(left, tag, word)
244 | if label:
245 | tree = trees.InternalParseNode(label, [tree])
246 | chart[left, right] = [tree], label_score
247 | label_scores_span_max[left, right] = [tree], label_score
248 | continue
249 |
250 | if force_gold:
251 | oracle_splits = gold.oracle_splits(left, right)
252 |
253 | if (len(label) > 0 and (label[0].endswith("'") or label[0] == EMPTY)) and latentscope[0] <= left <= right <= latentscope[1]: # if label == (EMPTY,) and latentscope[0] <= left <= right <= latentscope[1]: # and label != ():
254 | # Latent during Training
255 | #if self.train_constraint:
256 |
257 | # if self.train_constraint:
258 | # oracle_splits = [(oracle_splits[0], 0), (oracle_splits[-1], 1)]
259 | # else:
260 | # #oracle_splits = [(p, 0) for p in oracle_splits] + [(p, 1) for p in oracle_splits] # it is not correct
261 | # pass
262 |
263 | oracle_splits = [(oracle_splits[0], 0), (oracle_splits[-1], 1)]
264 | best_split = max(oracle_splits,
265 | key=lambda sb : #(split, branching) #branching == 0: right branching; 1: left branching
266 | label_scores_span_max[left, sb[0]][1].value() + chart[sb[0], right][1].value() if sb[1] == 0 else
267 | chart[left, sb[0]][1].value() + label_scores_span_max[sb[0], right][1].value()
268 | )
269 | else:
270 |
271 | best_split = (min(oracle_splits), 0)  # default to right branching
272 |
273 |
274 | else:
275 | pred_range = range(left + 1, right)
276 | pred_splits = [(p, 0) for p in pred_range] + [(p, 1) for p in pred_range]
277 | best_split = max(pred_splits,
278 | key=lambda sb: # (split, branching) #branching == 0: right branching; 1: left branching
279 | label_scores_span_max[left, sb[0]][1].value() + chart[sb[0], right][1].value() if sb[1] == 0 else
280 | chart[left, sb[0]][1].value() + label_scores_span_max[sb[0], right][1].value()
281 | )
282 |
283 | children_leaf = [trees.LeafParseNode(pos, sentence[pos][0], sentence[pos][1]) for pos in range(left, right)]
284 | label_scores_span_max[left, right] = children_leaf, label_score
285 |
286 |
287 |
288 | if best_split[1] == 0:  # right branching
289 | left_trees, left_score = label_scores_span_max[left, best_split[0]]
290 | right_trees, right_score = chart[best_split[0], right]
291 | else:  # left branching
292 | left_trees, left_score = chart[left, best_split[0]]
293 | right_trees, right_score = label_scores_span_max[best_split[0], right]
294 |
295 |
296 | children = left_trees + right_trees
297 |
298 | if label:
299 | children = [trees.InternalParseNode(label, children)]
300 | if not label[0].endswith("'"):
301 | children_leaf = [trees.InternalParseNode(label, children_leaf)]
302 | label_scores_span_max[left, right] = children_leaf, label_score
303 |
304 | chart[left, right] = (children, label_score + left_score + right_score)
305 |
306 |
307 |
308 | children, score = chart[0, len(sentence)]
309 | assert len(children) == 1
310 | return children[0], score
311 |
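# A hedged summary of helper(): it fills a CKY-style chart over all
# spans, but at each split point scores only two hypotheses -- branching
# 0 treats [left, split) as one flat chunk (label_scores_span_max) and
# recurses on [split, right), while branching 1 is the mirror image --
# which restricts the search to the right/left-branching tree family
# this parser is designed around.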
312 | tree, score = helper(False)
313 | if is_train:
314 | oracle_tree, oracle_score = helper(True)
315 | #assert oracle_tree.convert().linearize() == gold.convert().linearize()
316 | #correct = tree.convert().linearize() == gold.convert().linearize()
317 |
318 | if self.zerocostchunk:
319 | pred_chunks = tree.convert().to_chunks()
320 | correct = (gold_chunks == pred_chunks)
321 | else:
322 | correct = False #(gold_chunks == pred_chunks)
323 | loss = dy.zeros(1) if correct else score - oracle_score
324 | #loss = score - oracle_score
325 | return tree, loss
326 | else:
327 | return tree, score
328 |
329 |
--------------------------------------------------------------------------------
/src/trees.py:
--------------------------------------------------------------------------------
1 | import collections.abc
2 | import parse
3 | import util
4 |
5 | class TreebankNode(object):
6 | pass
7 |
8 | class InternalTreebankNode(TreebankNode):
9 | def __init__(self, label, children):
10 | assert isinstance(label, str)
11 | self.label = label
12 |
13 | assert isinstance(children, collections.abc.Sequence)
14 | assert all(isinstance(child, TreebankNode) for child in children)
15 | #assert children
16 | self.children = tuple(children)
17 |
18 | def to_chunks(self):
19 | raw_chunks = []
20 | self.chunk_helper(raw_chunks)
21 |
22 | chunks = []
23 | p = 0
24 | for label, text_list in raw_chunks:
25 | if label.endswith("'"):
26 | label = label[:-1]
27 | chunks.append((label, p, p + len(text_list), text_list))
28 | p += len(text_list)
29 |
30 | return chunks
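# to_chunks flattens a tree into (label, start, end, text_list) tuples
# with half-open character offsets; the trailing apostrophe that marks
# latent non-terminal labels is stripped before a chunk is emitted.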
31 |
32 |
33 | def chunk_helper(self, chunks):
34 |
35 | children_status = [isinstance(child, LeafTreebankNode) for child in self.children]
36 |
37 | if all(children_status):
38 | chunks.append((self.label, [child.word for child in self.children]))
39 | elif any(children_status):
40 | char_list = []
41 |
42 | for child in self.children:
43 | if isinstance(child, InternalTreebankNode):
44 | char_list += child.get_word_list()
45 | else:
46 | char_list.append(child.word)
47 |
48 |
49 | chunk = (self.label, char_list)
50 | chunks.append(chunk)
51 | else:
52 | for child in self.children:
53 | if isinstance(child, InternalTreebankNode):
54 | child.chunk_helper(chunks)
55 |
56 |
57 | def get_word_list(self):
58 | word_list = []
59 | for child in self.children:
60 | if isinstance(child, InternalTreebankNode):
61 | word_list += child.get_word_list()
62 | else:
63 | word_list += [child.word]
64 | return word_list
65 |
66 | def linearize(self):
67 | return "({} {})".format(
68 | self.label, " ".join(child.linearize() for child in self.children))
69 |
70 | def leaves(self):
71 | for child in self.children:
72 | yield from child.leaves()
73 |
74 | def convert(self, index=0):
75 | tree = self
76 | sublabels = [self.label]
77 |
78 | while len(tree.children) == 1 and isinstance(tree.children[0], InternalTreebankNode):
79 | tree = tree.children[0]
80 | sublabels.append(tree.label)
81 |
82 | children = []
83 | for child in tree.children:
84 | children.append(child.convert(index=index))
85 | index = children[-1].right
86 |
87 | return InternalParseNode(tuple(sublabels), children)
88 |
89 |
90 |
91 | class InternalTreebankChunkNode(InternalTreebankNode):
92 | def __init__(self, label, children):
93 | super(InternalTreebankChunkNode, self).__init__(label, children)
94 | self.is_chunk_node = True
95 |
96 | def convert(self, index=0):
97 | tree = self
98 | sublabels = [self.label]
99 |
100 | while len(tree.children) == 1 and isinstance(tree.children[0], InternalTreebankNode):
101 | tree = tree.children[0]
102 | sublabels.append(tree.label)
103 |
104 | children = []
105 | for child in tree.children:
106 | children.append(child.convert(index=index))
107 | index = children[-1].right
108 |
109 | return InternalParseChunkNode(tuple(sublabels), children)
110 |
111 |
112 | class LeafTreebankNode(TreebankNode):
113 | def __init__(self, tag, word):
114 | assert isinstance(tag, str)
115 | self.tag = tag
116 |
117 | assert isinstance(word, str)
118 | self.word = word
119 |
120 | def linearize(self):
121 | return "({} {})".format(self.tag, self.word)
122 |
123 | def leaves(self):
124 | yield self
125 |
126 | def convert(self, index=0):
127 | return LeafParseNode(index, self.tag, self.word)
128 |
129 |
130 |
131 | class InternalUncompletedTreebankNode(TreebankNode):
132 | def __init__(self, label, chunkleaves, chunks_in_scope, latent):
133 | assert isinstance(label, str)
134 | self.label = label
135 | self.latent = latent
136 | #assert isinstance(children, collections.abc.Sequence)
137 | #assert all(isinstance(child, TreebankNode) for child in children)
138 | #assert children
139 | self.children = () #tuple(children)
140 | self.chunkleaves = chunkleaves
141 | self.chunks_in_scope = chunks_in_scope
142 |
143 |
144 | def linearize(self):
145 |
146 | #chunks_str_list = [ '(' + label + ' ' + ''.join( '(XX ' + item + ')' for item in text) + ')' for label, s, e, text in self.chunks_in_scope]
147 | return "({} leaves: {})".format(self.label, " ".join(leaf.linearize() for leaf in self.chunkleaves))
148 |
149 | def leaves(self):
150 | return self.chunkleaves
151 |
152 | def convert(self, index=0):
153 | tree = self
154 | sublabels = [self.label]
155 |
156 | # while len(tree.children) == 1 and isinstance(tree.children[0], InternalTreebankNode):
157 | # tree = tree.children[0]
158 | # sublabels.append(tree.label)
159 |
160 | #children = []
161 | # for child in tree.children:
162 | # children.append(child.convert(index=index))
163 | # index = children[-1].right
164 | #index = self.chunks_in_scope[0][1]
165 | chunkleaves = []
166 | for chunkleaf in self.chunkleaves:
167 | chunkleaves.append(chunkleaf.convert(index=index))
168 | index = chunkleaves[-1].right
169 | # for i in range(len(self.chunkleaves)):
170 | # chunk = self.chunks_in_scope[i]
171 | # chunkleaf = self.chunkleaves[i]
172 | # chunkleaf.convert(index=chunk[1])
173 | # chunkleaves.append(chunkleaf)
174 |
175 | return InternalUncompletedParseNode(tuple(sublabels), chunkleaves, self.chunks_in_scope, self.latent)
176 |
177 |
178 |
179 | class ParseNode(object):
180 | pass
181 |
182 | class InternalParseNode(ParseNode):
183 | def __init__(self, label, children):
184 | assert isinstance(label, tuple)
185 | assert all(isinstance(sublabel, str) for sublabel in label)
186 | assert label
187 | self.label = label
188 |
189 | assert isinstance(children, collections.abc.Sequence)
190 | assert all(isinstance(child, ParseNode) for child in children)
191 | assert children
192 | assert len(children) > 1 or isinstance(children[0], LeafParseNode)
193 | assert all(
194 | left.right == right.left
195 | for left, right in zip(children, children[1:]))
196 | self.children = tuple(children)
197 |
198 | self.left = children[0].left
199 | self.right = children[-1].right
200 |
201 | def leaves(self):
202 | for child in self.children:
203 | yield from child.leaves()
204 |
205 | def convert(self):
206 | children = [child.convert() for child in self.children]
207 | tree = InternalTreebankNode(self.label[-1], children)
208 | for sublabel in reversed(self.label[:-1]):
209 | tree = InternalTreebankNode(sublabel, [tree])
210 | return tree
211 |
212 | def enclosing(self, left, right):
213 | assert self.left <= left < right <= self.right
214 | for child in self.children:
215 | if isinstance(child, LeafParseNode):
216 | continue
217 | if child.left <= left < right <= child.right:
218 | return child.enclosing(left, right)
219 | return self
220 |
221 | def oracle_label(self, left, right):
222 | enclosing = self.enclosing(left, right)
223 | if enclosing.left == left and enclosing.right == right:
224 | return enclosing.label
225 | return ()
226 |
227 | def oracle_splits(self, left, right):
228 | # return [
229 | # child.left
230 | # for child in self.enclosing(left, right).children
231 | # if left < child.left < right
232 | # ]
233 | enclosing = self.enclosing(left, right)
234 | return [
235 | child.left
236 | for child in enclosing.children
237 | if left < child.left < right
238 | ]
239 | # if isinstance(enclosing, InternalUncompletedParseNode):
240 | # return [
241 | # child.left
242 | # for child in enclosing.chunkleaves
243 | # if left < child.left < right
244 | # ]
245 | # else:
246 | # return [
247 | # child.left
248 | # for child in enclosing.children
249 | # if left < child.left < right
250 | # ]
251 |
252 |
253 | def oracle_splits2(self, left, right):
254 | # return [
255 | # child.left
256 | # for child in self.enclosing(left, right).children
257 | # if left < child.left < right
258 | # ]
259 |
260 | enclosing = self.enclosing(left, right)
261 | if enclosing.left == left and enclosing.right == right:
262 | splits = [child.left for child in enclosing.children]  # the old `not isinstance(child, LeafTreebankNode)` filter could never exclude anything, since children are ParseNodes
263 | splits = splits[1:]  # drop children[0].left, which equals `left` itself
264 | return splits
265 | else:
266 | return []
267 |
268 |
269 | # if isinstance(self, InternalParseChunkNode):
270 | # return []
271 | # elif isinstance(self, InternalUncompletedParseNode):
272 | # return self.oracle_splits(left, right)
273 | # else:
274 | # #enclosing = self.enclosing(left, right)
275 | # if self.left == left and self.right == right:
276 | # splits = [child.left for child in self.children if not isinstance(child, LeafTreebankNode)]
277 | # splits = splits[1:]
278 | # return splits
279 | #
280 | #
281 | # for child in self.children:
282 | # if not isinstance(child, LeafParseNode):
283 | # if isinstance(child, InternalUncompletedParseNode):
284 | # if child.left <= left <= right <= child.right:
285 | # return child.oracle_splits(left, right)
286 | # else:
287 | # if child.left == left and child.right == right:
288 | # return child.oracle_splits2(left, right)
289 | #
290 | # return []
291 |
292 |
293 |
294 | class InternalParseChunkNode(InternalParseNode):
295 | def __init__(self, label, children):
296 | super(InternalParseChunkNode, self).__init__(label, children)
297 | self.is_chunk_node = True
298 | self._chunknode = None
299 |
300 | def convert(self):
301 | children = [child.convert() for child in self.children]
302 | tree = InternalTreebankChunkNode(self.label[-1], children)
303 | for sublabel in reversed(self.label[:-1]):
304 | tree = InternalTreebankChunkNode(sublabel, [tree])
305 | return tree
306 |
307 | def enclosing(self, left, right):
308 | assert self.left <= left < right <= self.right
309 | for child in self.children:
310 | if isinstance(child, LeafParseNode):
311 | continue
312 | if child.left <= left < right <= child.right:
313 | return child.enclosing(left, right)
314 |
315 | if self._chunknode is None:
316 | self._chunknode = InternalParseChunkNode(self.label, self.children)
317 | self._chunknode.children = []
318 | return self._chunknode
319 |
320 | def oracle_splits(self, left, right):
321 | # return [
322 | # child.left
323 | # for child in self.enclosing(left, right).children
324 | # if left < child.left < right
325 | # ]
326 | enclosing = self.enclosing(left, right)
327 | return [
328 | child.left
329 | for child in enclosing.children
330 | if left < child.left < right
331 | ]
332 |
333 |
334 | class InternalUncompletedParseNode(InternalParseNode):
335 | def __init__(self, label, chunkleaves, chunks_in_scope: list, latent):
336 | assert isinstance(label, tuple)
337 | assert all(isinstance(sublabel, str) for sublabel in label)
338 | assert label
339 | self.label = label
340 | self.latent = latent
341 |
342 | #assert isinstance(children, collections.abc.Sequence)
343 | #assert all(isinstance(child, ParseNode) for child in children)
344 | #assert children
345 | #assert len(children) > 1 or isinstance(children[0], LeafParseNode)
346 | # assert all(
347 | # left.right == right.left
348 | # for left, right in zip(children, children[1:]))
349 | self.children = [] #tuple(children)
350 | self.chunkleaves = chunkleaves
351 |
352 | #self.left = children[0].left
353 | #self.right = children[-1].right
354 | self.left = chunkleaves[0].left #chunks_in_scope[0][1]
355 | self.right = chunkleaves[-1].right #chunks_in_scope[-1][2]
356 | self.chunks_in_scope = chunks_in_scope
357 | self.splits = [chunk[1] for chunk in self.chunks_in_scope] + [self.chunks_in_scope[-1][2]]
358 |
359 | def leaves(self):
360 | return self.chunkleaves
361 |
362 | def convert(self):
363 | # children = [child.convert() for child in self.children]
364 | # tree = InternalUncompletedTreebankNode(self.label[-1], children)
365 | # for sublabel in reversed(self.label[:-1]):
366 | # tree = InternalUncompletedTreebankNode(sublabel, [tree])
367 |
368 | chunkleaves = [chunkleaf.convert() for chunkleaf in self.chunkleaves]
369 | tree = InternalUncompletedTreebankNode(self.label[-1], chunkleaves, (self.left, self.right), self.latent)
370 | return tree
371 |
372 | def enclosing(self, left, right):
373 | assert self.left <= left < right <= self.right
374 | for chunkleaf in self.chunkleaves:
375 | if isinstance(chunkleaf, LeafParseNode):
376 | continue
377 | if chunkleaf.left <= left < right <= chunkleaf.right:
378 | return chunkleaf.enclosing(left, right)
379 |
380 | # if left in self.splits and right in self.splits:
381 | # children = [chunkleaf for chunkleaf in self.chunkleaves if left <= chunkleaf.left and chunkleaf.right <= right]
382 | # # label = max([child.label for child in children],key=lambda l:self.latent.get_label_order(l[0]))
383 | # # label = (self.latent.non_terminal_label(label[0]),)
384 | # label = self.label
385 | # return InternalParseNode(label, children)
386 |
387 | children = [chunkleaf for chunkleaf in self.chunkleaves if left < chunkleaf.right]
388 | children = [chunkleaf for chunkleaf in children if right > chunkleaf.left]
389 |
390 | if self.latent.non_terminal_label_mode == 0 or self.latent.non_terminal_label_mode == 3:
391 | label = max([child.label for child in children], key=lambda l: self.latent.get_label_order(l[0]))
392 | label = (self.latent.non_terminal_label(label[0]),)
393 | elif self.latent.non_terminal_label_mode == 1:
394 | label = self.label
395 | else: #self.latent.non_terminal_label_mode == 2:
396 | import random
397 | label_id = random.randint(self.latent.label_size + 1, 2 * self.latent.label_size)  # a random id from the primed (latent) label range
398 | label = self.latent.label_vocab.value(label_id)
399 | return InternalParseNode(label, children)
400 |
401 | # self.children = self.chunkleaves
402 | # return self
403 |
404 | def oracle_label(self, left, right):
405 | # enclosing = self.enclosing(left, right)
406 | # if enclosing.left == left and enclosing.right == right:
407 | # return enclosing.label
408 |
409 | for chunk in self.chunks_in_scope:
410 | if chunk[1] == left and chunk[2] == right:
411 | return (chunk[0],)
412 |
413 |
414 | if left in self.splits and right in self.splits:
415 | return self.label
416 |
417 | return ()
418 |
419 | def oracle_splits(self, left, right):
420 |
421 |
422 | ret = [p for p in self.splits if left < p and p < right]
423 | if len(ret) == 0:
424 | return [
425 | child.left
426 | for child in self.enclosing(left, right).children
427 | if left < child.left < right
428 | ]
429 |
430 |
431 | return ret
432 |
433 |
434 |
435 | class LeafParseNode(ParseNode):
436 | def __init__(self, index, tag, word):
437 | assert isinstance(index, int)
438 | assert index >= 0
439 | self.left = index
440 | self.right = index + 1
441 |
442 | assert isinstance(tag, str)
443 | self.tag = tag
444 |
445 | assert isinstance(word, str)
446 | self.word = word
447 |
448 | def leaves(self):
449 | yield self
450 |
451 | def convert(self):
452 | return LeafTreebankNode(self.tag, self.word)
453 |
454 |
455 | def load_trees(path, normal, strip_top=True):
456 | with open(path, 'r', encoding='utf-8') as infile:
457 | tokens = infile.read().replace("(", " ( ").replace(")", " ) ").split()
458 |
459 | def helper(index):
460 | trees = []
461 |
462 | while index < len(tokens) and tokens[index] == "(":
463 | paren_count = 0
464 | while tokens[index] == "(":
465 | index += 1
466 | paren_count += 1
467 |
468 | label = tokens[index]
469 | index += 1
470 |
471 | if tokens[index] == "(":
472 | children, index = helper(index)
473 | trees.append(InternalTreebankNode(label, children))
474 | else:
475 | word = tokens[index]
476 | if normal == 1:
477 | newword = ''
478 | for c in word:
479 | if util.is_digit(c):
480 | newword += '0'
481 | else:
482 | newword += c
483 | else:
484 | newword = word
485 | index += 1
486 | trees.append(LeafTreebankNode(label, newword))
487 |
488 | while paren_count > 0:
489 | assert tokens[index] == ")"
490 | index += 1
491 | paren_count -= 1
492 |
493 | return trees, index
494 |
495 | trees, index = helper(0)
496 | assert index == len(tokens)
497 |
498 | if strip_top:
499 | for i, tree in enumerate(trees):
500 | if tree.label == "TOP":
501 | assert len(tree.children) == 1
502 | trees[i] = tree.children[0]
503 |
504 | return trees
505 |
--------------------------------------------------------------------------------
/src/util.py:
--------------------------------------------------------------------------------
1 | import vocabulary
2 |
3 |
4 |
5 | chinese_digit = '一二三四五六七八九十百千万'
6 |
7 | def is_digit(tok: str):
8 |
9 | tok = tok.strip()
10 |
11 | if tok.isdigit():
12 | return True
13 |
14 | if tok in chinese_digit:
15 | return True
16 |
17 | return False
18 |
19 |
20 | def seq2chunk(seq):
21 | chunks = []
22 | label = None
23 | last_label = None
24 | start_idx = 0
25 | for i in range(len(seq)):
26 | tok = seq[i]
27 | label = tok[2:] if tok.startswith('B') or tok.startswith('I') else tok
28 | if tok.startswith('B') or last_label != label:
29 | if last_label is None:
30 | start_idx = i
31 | else:
32 | chunks.append((last_label, start_idx, i))
33 | start_idx = i
34 |
35 | last_label = label
36 |
37 | chunks.append((label, start_idx, len(seq)))
38 |
39 | return chunks
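# Worked example (hedged):
#     >>> seq2chunk(['B-city', 'I-city', 'B-road'])
#     [('city', 0, 2), ('road', 2, 3)]
# Spans are half-open (label, start, end) offsets over the character sequence.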
40 |
41 | def read_chunks(filename, normal = 1):
42 | f = open(filename, 'r', encoding='utf-8')
43 | insts = []
44 | inst = list()
45 | num_inst = 0
46 | max_chunk_length_limit = 36
47 | max_chunk_length = 0
48 | max_char_length = 0
49 | for line in f:
50 | line = line.strip()
51 | if line == "":
52 | if inst:  # also skips consecutive blank lines; an empty inst would crash on tmp[0] below
53 |
54 | inst = [tuple(x) for x in inst]
55 |
56 | tmp = list(zip(*inst))
57 |
58 | x = tmp[0]
59 | new_x = []
60 | for word in x:
61 | if normal == 1:
62 | newword = ''
63 | for c in word:
64 | if is_digit(c):
65 | newword += '0'
66 | else:
67 | newword += c
68 | else:
69 | newword = word
70 |
71 | new_x.append(newword)
72 |
73 |
74 | y = [t[:2] + t[2:].lower() for t in tmp[1]]  # lower-case the label suffix after the B-/I- prefix; avoid shadowing x
75 |
76 | chunks = seq2chunk(y)
77 |
78 | for chunk in chunks:
79 | if max_chunk_length < chunk[2] - chunk[1]:
80 | max_chunk_length = chunk[2] - chunk[1]
81 |
82 | if max_char_length < len(x):
83 | max_char_length = len(x)
84 |
85 | insts.append((new_x, chunks))
86 |
87 | inst = list()
88 | else:
89 | inst.append(line.split())
90 | f.close()
91 | print(filename + ' is loaded.', '\t', 'max_chunk_length=', max_chunk_length, '\tmax_char_length:', max_char_length)
92 | return insts
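# read_chunks expects CoNLL-style input: one "char label" pair per line,
# with a blank line between addresses. The label suffix after the B-/I-
# prefix is lower-cased, digits are normalized to '0' when normal == 1,
# and each instance becomes (char_list, seq2chunk(labels)).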
93 |
94 |
95 | def load_trees_from_str(tokens, normal = 1, strip_top=True):
96 | from trees import InternalTreebankNode, LeafTreebankNode
97 |
98 | tokens = tokens.replace("(", " ( ").replace(")", " ) ").split()
99 |
100 | def helper(index):
101 | trees = []
102 |
103 | while index < len(tokens) and tokens[index] == "(":
104 | paren_count = 0
105 | while tokens[index] == "(":
106 | index += 1
107 | paren_count += 1
108 |
109 | label = tokens[index]
110 | index += 1
111 |
112 | if tokens[index] == "(":
113 | children, index = helper(index)
114 | trees.append(InternalTreebankNode(label, children))
115 | else:
116 | word = tokens[index]
117 | if normal == 1:
118 | newword = ''
119 | for c in word:
120 | if is_digit(c):
121 | newword += '0'
122 | else:
123 | newword += c
124 | else:
125 | newword = word
126 | index += 1
127 | trees.append(LeafTreebankNode(label, newword))
128 |
129 | while paren_count > 0:
130 | assert tokens[index] == ")"
131 | index += 1
132 | paren_count -= 1
133 |
134 | return trees, index
135 |
136 | trees, index = helper(0)
137 | assert index == len(tokens)
138 |
139 | if strip_top:
140 | for i, tree in enumerate(trees):
141 | if tree.label == "TOP":
142 | assert len(tree.children) == 1
143 | trees[i] = tree.children[0]
144 |
145 | return trees
146 |
147 |
148 |
149 | def load_label_list(path):
150 | f = open(path, 'r', encoding='utf-8')
151 | label_list = [line.strip() for line in f]
152 | f.close()
153 | return label_list
154 |
155 |
156 | def load_pretrain(filename: str, WORD_DIM, word_vocab: vocabulary.Vocabulary, saveemb=True):
157 | import numpy as np
158 | import parse
159 | import pickle
160 |
161 | if filename == 'none':
162 | print("Do not use pretrain embedding...")
163 | return None
164 |
165 | print('Loading Pretrained Embedding from ', filename,' ...')
166 | vocab_dic = {}
167 | with open(filename, encoding='utf-8') as f:
168 | for line in f:
169 | s_s = line.split()
170 | if s_s[0] in word_vocab.counts:
171 | vocab_dic[s_s[0]] = np.array([float(x) for x in s_s[1:]])
172 | # vocab_dic[s_s[0]] = [float(x) for x in s_s[1:]]
173 |
174 | unknowns = np.random.uniform(-0.01, 0.01, WORD_DIM).astype("float32")
175 | numbers = np.random.uniform(-0.01, 0.01, WORD_DIM).astype("float32")
176 |
177 | vocab_dic[parse.UNK] = unknowns
178 | vocab_dic[parse.NUM] = numbers
179 |
180 |
181 |
182 | ret_mat = np.zeros((word_vocab.size, WORD_DIM))
183 | unk_counter = 0
184 | for token_id in range(word_vocab.size):
185 | token = word_vocab.value(token_id)
186 | if token in vocab_dic:
187 | # ret_mat.append(vocab_dic[token])
188 | ret_mat[token_id] = vocab_dic[token]
189 | # elif parse.is_digit(token) or token == '':
190 | # ret_mat[token_id] = numbers
191 | else:
192 | # ret_mat.append(unknowns)
193 | ret_mat[token_id] = unknowns
194 | # print "Unknown token:", token
195 | unk_counter += 1
196 | #print('unk:', token)
197 | ret_mat = np.array(ret_mat)
198 |
199 | print('ret_mat shape:', ret_mat.shape)
200 |
201 | if saveemb:
202 | with open('giga.emb', "wb") as f:
203 | pickle.dump(ret_mat, f)
204 |
205 | print("{0} unk out of {1} vocab".format(unk_counter, word_vocab.size))
206 | print('Pretrained embedding is loaded.')  # the vectors come from giga.vec100, not GloVe
207 | return ret_mat
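# Words missing from the pretrained file all share one random "unknowns"
# vector. The finished matrix is also pickled to giga.emb, presumably so
# a later run can reuse it without re-reading the text embeddings.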
208 |
209 |
210 | def inst2chunks(inst):
211 | chunks = []
212 | x, x_chunks = inst
213 | for x_chunk in x_chunks:
214 | chunk = (x_chunk[0], x_chunk[1], x_chunk[2], x[x_chunk[1]:x_chunk[2]])
215 | chunks.append(chunk)
216 |
217 | return chunks
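# Worked example (hedged):
#     >>> inst2chunks((['A', 'B', 'C'], [('road', 0, 2)]))
#     [('road', 0, 2, ['A', 'B'])]
# i.e. each gold (label, start, end) chunk is augmented with its text slice.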
--------------------------------------------------------------------------------
/src/vocabulary.py:
--------------------------------------------------------------------------------
1 | import collections
2 |
3 | class Vocabulary(object):
4 | def __init__(self):
5 | self.frozen = False
6 | self.values = []
7 | self.indices = {}
8 | self.counts = collections.defaultdict(int)
9 |
10 | @property
11 | def size(self):
12 | return len(self.values)
13 |
14 | def value(self, index):
15 | assert 0 <= index < len(self.values)
16 | return self.values[index]
17 |
18 | def index(self, value):
19 | if not self.frozen:
20 | self.counts[value] += 1
21 |
22 | if value in self.indices:
23 | return self.indices[value]
24 |
25 | elif not self.frozen:
26 | self.values.append(value)
27 | self.indices[value] = len(self.values) - 1
28 | return self.indices[value]
29 |
30 | else:
31 | raise ValueError("Unknown value: {}".format(value))
32 |
33 | def count(self, value):
34 | return self.counts[value]
35 |
36 | def freeze(self):
37 | self.frozen = True
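# Usage sketch (doctest-style, hedged):
#     >>> v = Vocabulary()
#     >>> i = v.index('city')   # adds the value and counts it while unfrozen
#     >>> v.freeze()
#     >>> v.index('city') == i  # lookup still works after freezing
#     True
#     >>> v.index('unseen')     # unseen values now raise ValueError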
38 |
--------------------------------------------------------------------------------