├── .idea
│   ├── encodings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── neural-chinese-address-parsing.iml
│   └── vcs.xml
├── README.md
├── conlleval.pl
├── data
│   ├── anno
│   │   ├── anno-cn.md
│   │   └── anno-en.md
│   ├── dev.txt
│   ├── giga.vec100
│   ├── labels.txt
│   ├── test.txt
│   └── train.txt
├── exp_dytree.sh
├── giga.emb
├── log.jpg
└── src
    ├── evaluate.py
    ├── latent.py
    ├── latenttrees.py
    ├── main_dyRBT.py
    ├── parse.py
    ├── trees.py
    ├── util.py
    └── vocabulary.py
--------------------------------------------------------------------------------
/.idea/encodings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/neural-chinese-address-parsing.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Neural Chinese Address Parsing
2 | This page contains the code used in the work ["Neural Chinese Address Parsing"](http://statnlp.org/research/sp) published at [NAACL 2019](https://naacl2019.org/program/accepted/).
3 | 
4 | 
5 | ## Contents
6 | 1. [Usage](#usage)
7 | 2. [SourceCode](#sourcecode)
8 | 3. [Data](#data)
9 | 4. [Citation](#citation)
10 | 5. [Credits](#credits)
11 | 
12 | 
13 | ## Usage
14 | 
15 | Prerequisites: Python (3.5 or later), DyNet (2.0 or later)
16 | 
17 | Run the following command to try out the APLT(sp=7) model in the paper.
18 | ```sh
19 | ./exp_dytree.sh
20 | ```
21 | After the training is complete, run the following command to display the results on the test data. The performance output by conlleval.pl is shown below.
22 | ```sh
23 | perl conlleval.pl < addr_dytree_giga_0.4_200_1_chardyRBTC_dytree_1_houseno_0_0.test.txt
24 | ```
25 | 
26 | ![alt text](log.jpg)
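The same numbers can also be obtained programmatically: src/evaluate.py invokes conlleval.pl as a subprocess and parses the overall scores out of its "accuracy:" summary line. The sketch below mirrors that logic and is illustrative only; the helper name `conlleval_scores` is ours, and it assumes the repository root as the working directory and the output file produced by `exp_dytree.sh` above.

```python
# Illustrative sketch (not part of the released code): run conlleval.pl on a
# prediction file and parse precision/recall/F1 from its summary line, the
# same way src/evaluate.py does.
import subprocess

def conlleval_scores(filename):
    with open(filename, encoding='utf-8') as f:
        output = subprocess.run(
            ['perl', 'conlleval.pl'], stdin=f,
            stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
        ).stdout.decode('utf-8')
    p = r = f1 = 0.0
    for line in output.split('\n'):
        # Summary line format: "accuracy: ...%; precision: ...%; recall: ...%; FB1: ..."
        if line.startswith('accuracy'):
            for item in line.strip().split(';'):
                key, _, value = item.strip().partition(':')
                if key.startswith('precision'):
                    p = float(value.strip()[:-1]) / 100
                elif key.startswith('recall'):
                    r = float(value.strip()[:-1]) / 100
                elif key.startswith('FB1'):
                    f1 = float(value.strip()) / 100
    return p, r, f1

print(conlleval_scores('addr_dytree_giga_0.4_200_1_chardyRBTC_dytree_1_houseno_0_0.test.txt'))
```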
27 | 
28 | ## SourceCode
29 | 
30 | The source code is written in Python with the DyNet toolkit and can be found under the "src" folder.
31 | 
32 | 
33 | ## Data
34 | 
35 | The **data** is stored in the "data" folder, which contains "train.txt", "dev.txt" and "test.txt". The embedding file "giga.vec100" is also located in the "data" folder.
36 | 
37 | **The annotation guidelines** are in the folder ["data/anno"](https://github.com/leodotnet/neural-chinese-address-parsing/blob/master/data/anno). Both [Chinese](https://github.com/leodotnet/neural-chinese-address-parsing/blob/master/data/anno/anno-cn.md) and [English](https://github.com/leodotnet/neural-chinese-address-parsing/blob/master/data/anno/anno-en.md) versions are available.
38 | 
39 | ## Citation
40 | If you use this software for research, please cite our paper as follows:
41 | 
42 | ```
43 | @InProceedings{chineseaddressparsing19li,
44 |   author = "Li, Hao and Lu, Wei and Xie, Pengjun and Li, Linlin",
45 |   title = "Neural Chinese Address Parsing",
46 |   booktitle = "Proc. of NAACL",
47 |   year = "2019",
48 | }
49 | ```
50 | 
51 | 
52 | ## Credits
53 | The code in this repository is based on https://github.com/mitchellstern/minimal-span-parser
54 | 
55 | Email [hao_li@mymail.sutd.edu.sg](mailto:hao_li@mymail.sutd.edu.sg) for any inquiries.
--------------------------------------------------------------------------------
/conlleval.pl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/perl -w
2 | # conlleval: evaluate result of processing CoNLL-2000 shared task
3 | # usage: conlleval [-l] [-r] [-d delimiterTag] [-o oTag] < file
4 | # README: http://cnts.uia.ac.be/conll2000/chunking/output.html
5 | # options: l: generate LaTeX output for tables like in
6 | # http://cnts.uia.ac.be/conll2003/ner/example.tex
7 | # r: accept raw result tags (without B- and I- prefix;
8 | # assumes one word per chunk)
9 | # d: alternative delimiter tag (default is single space)
10 | # o: alternative outside tag (default is O)
11 | # note: the file should contain lines with items separated
12 | # by $delimiter characters (default space). The final
13 | # two items should contain the correct tag and the
14 | # guessed tag in that order. Sentences should be
15 | # separated from each other by empty lines or lines
16 | # with $boundary fields (default -X-).
17 | # url: http://lcg-www.uia.ac.be/conll2000/chunking/
18 | # started: 1998-09-25
19 | # version: 2004-01-26
20 | # author: Erik Tjong Kim Sang
21 | 
22 | use strict;
23 | 
24 | my $false = 0;
25 | my $true = 42;
26 | 
27 | my $boundary = "-X-"; # sentence boundary
28 | my $correct; # current corpus chunk tag (I,O,B)
29 | my $correctChunk = 0; # number of correctly identified chunks
30 | my $correctTags = 0; # number of correct chunk tags
31 | my $correctType; # type of current corpus chunk tag (NP,VP,etc.)
32 | my $delimiter = "\t"; # field delimiter 33 | my $FB1 = 0.0; # FB1 score (Van Rijsbergen 1979) 34 | my $firstItem; # first feature (for sentence boundary checks) 35 | my $foundCorrect = 0; # number of chunks in corpus 36 | my $foundGuessed = 0; # number of identified chunks 37 | my $guessed; # current guessed chunk tag 38 | my $guessedType; # type of current guessed chunk tag 39 | my $i; # miscellaneous counter 40 | my $inCorrect = $false; # currently processed chunk is correct until now 41 | my $lastCorrect = "O"; # previous chunk tag in corpus 42 | my $latex = 0; # generate LaTeX formatted output 43 | my $lastCorrectType = ""; # type of previously identified chunk tag 44 | my $lastGuessed = "O"; # previously identified chunk tag 45 | my $lastGuessedType = ""; # type of previous chunk tag in corpus 46 | my $lastType; # temporary storage for detecting duplicates 47 | my $line; # line 48 | my $nbrOfFeatures = -1; # number of features per line 49 | my $precision = 0.0; # precision score 50 | my $oTag = "O"; # outside tag, default O 51 | my $raw = 0; # raw input: add B to every token 52 | my $recall = 0.0; # recall score 53 | my $tokenCounter = 0; # token counter (ignores sentence breaks) 54 | 55 | my %correctChunk = (); # number of correctly identified chunks per type 56 | my %foundCorrect = (); # number of chunks in corpus per type 57 | my %foundGuessed = (); # number of identified chunks per type 58 | 59 | my @features; # features on line 60 | my @sortedTypes; # sorted list of chunk type names 61 | 62 | # sanity check 63 | while (@ARGV and $ARGV[0] =~ /^-/) { 64 | if ($ARGV[0] eq "-l") { $latex = 1; shift(@ARGV); } 65 | elsif ($ARGV[0] eq "-r") { $raw = 1; shift(@ARGV); } 66 | elsif ($ARGV[0] eq "-d") { 67 | shift(@ARGV); 68 | if (not defined $ARGV[0]) { 69 | die "conlleval: -d requires delimiter character"; 70 | } 71 | $delimiter = shift(@ARGV); 72 | } elsif ($ARGV[0] eq "-o") { 73 | shift(@ARGV); 74 | if (not defined $ARGV[0]) { 75 | die "conlleval: -o requires delimiter character"; 76 | } 77 | $oTag = shift(@ARGV); 78 | } else { die "conlleval: unknown argument $ARGV[0]\n"; } 79 | } 80 | if (@ARGV) { die "conlleval: unexpected command line argument\n"; } 81 | # process input 82 | while () { 83 | chomp($line = $_); 84 | @features = split(/$delimiter/,$line); 85 | if ($nbrOfFeatures < 0) { $nbrOfFeatures = $#features; } 86 | elsif ($nbrOfFeatures != $#features and @features != 0) { 87 | printf STDERR "unexpected number of features: %d (%d)\n", 88 | $#features+1,$nbrOfFeatures+1; 89 | exit(1); 90 | } 91 | if (@features == 0 or 92 | $features[0] eq $boundary) { @features = ($boundary,"O","O"); } 93 | if (@features < 2) { 94 | die "conlleval: unexpected number of features in line $line\n"; 95 | } 96 | if ($raw) { 97 | if ($features[$#features] eq $oTag) { $features[$#features] = "O"; } 98 | if ($features[$#features-1] eq $oTag) { $features[$#features-1] = "O"; } 99 | if ($features[$#features] ne "O") { 100 | $features[$#features] = "B-$features[$#features]"; 101 | } 102 | if ($features[$#features-1] ne "O") { 103 | $features[$#features-1] = "B-$features[$#features-1]"; 104 | } 105 | } 106 | # 20040126 ET code which allows hyphens in the types 107 | if ($features[$#features] =~ /^([^-]*)-(.*)$/) { 108 | $guessed = $1; 109 | $guessedType = $2; 110 | } else { 111 | $guessed = $features[$#features]; 112 | $guessedType = ""; 113 | } 114 | pop(@features); 115 | if ($features[$#features] =~ /^([^-]*)-(.*)$/) { 116 | $correct = $1; 117 | $correctType = $2; 118 | } else { 119 | $correct = 
$features[$#features]; 120 | $correctType = ""; 121 | } 122 | pop(@features); 123 | # ($guessed,$guessedType) = split(/-/,pop(@features)); 124 | # ($correct,$correctType) = split(/-/,pop(@features)); 125 | $guessedType = $guessedType ? $guessedType : ""; 126 | $correctType = $correctType ? $correctType : ""; 127 | $firstItem = shift(@features); 128 | 129 | # 1999-06-26 sentence breaks should always be counted as out of chunk 130 | if ( $firstItem eq $boundary ) { $guessed = "O"; } 131 | 132 | if ($inCorrect) { 133 | if ( &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and 134 | &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and 135 | $lastGuessedType eq $lastCorrectType) { 136 | $inCorrect=$false; 137 | $correctChunk++; 138 | $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ? 139 | $correctChunk{$lastCorrectType}+1 : 1; 140 | } elsif ( 141 | &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) != 142 | &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) or 143 | $guessedType ne $correctType ) { 144 | $inCorrect=$false; 145 | } 146 | } 147 | 148 | if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and 149 | &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and 150 | $guessedType eq $correctType) { $inCorrect = $true; } 151 | 152 | if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) ) { 153 | $foundCorrect++; 154 | $foundCorrect{$correctType} = $foundCorrect{$correctType} ? 155 | $foundCorrect{$correctType}+1 : 1; 156 | } 157 | if ( &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) ) { 158 | $foundGuessed++; 159 | $foundGuessed{$guessedType} = $foundGuessed{$guessedType} ? 160 | $foundGuessed{$guessedType}+1 : 1; 161 | } 162 | if ( $firstItem ne $boundary ) { 163 | if ( $correct eq $guessed and $guessedType eq $correctType ) { 164 | $correctTags++; 165 | } 166 | $tokenCounter++; 167 | } 168 | 169 | $lastGuessed = $guessed; 170 | $lastCorrect = $correct; 171 | $lastGuessedType = $guessedType; 172 | $lastCorrectType = $correctType; 173 | } 174 | if ($inCorrect) { 175 | $correctChunk++; 176 | $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ? 177 | $correctChunk{$lastCorrectType}+1 : 1; 178 | } 179 | 180 | if (not $latex) { 181 | # compute overall precision, recall and FB1 (default values are 0.0) 182 | $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0); 183 | $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0); 184 | $FB1 = 2*$precision*$recall/($precision+$recall) 185 | if ($precision+$recall > 0); 186 | 187 | # print overall performance 188 | printf "processed $tokenCounter tokens with $foundCorrect phrases; "; 189 | printf "found: $foundGuessed phrases; correct: $correctChunk.\n"; 190 | if ($tokenCounter>0) { 191 | printf "accuracy: %6.2f%%; ",100*$correctTags/$tokenCounter; 192 | printf "precision: %6.2f%%; ",$precision; 193 | printf "recall: %6.2f%%; ",$recall; 194 | printf "FB1: %6.2f\n",$FB1; 195 | } 196 | } 197 | 198 | # sort chunk type names 199 | undef($lastType); 200 | @sortedTypes = (); 201 | foreach $i (sort (keys %foundCorrect,keys %foundGuessed)) { 202 | if (not($lastType) or $lastType ne $i) { 203 | push(@sortedTypes,($i)); 204 | } 205 | $lastType = $i; 206 | } 207 | # print performance per chunk type 208 | if (not $latex) { 209 | for $i (@sortedTypes) { 210 | $correctChunk{$i} = $correctChunk{$i} ? 
$correctChunk{$i} : 0; 211 | if (not($foundGuessed{$i})) { $foundGuessed{$i} = 0; $precision = 0.0; } 212 | else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; } 213 | if (not($foundCorrect{$i})) { $recall = 0.0; } 214 | else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; } 215 | if ($precision+$recall == 0.0) { $FB1 = 0.0; } 216 | else { $FB1 = 2*$precision*$recall/($precision+$recall); } 217 | printf "%17s: ",$i; 218 | printf "precision: %6.2f%%; ",$precision; 219 | printf "recall: %6.2f%%; ",$recall; 220 | printf "FB1: %6.2f %d\n",$FB1,$foundGuessed{$i}; 221 | } 222 | } else { 223 | print " & Precision & Recall & F\$_{\\beta=1} \\\\\\hline"; 224 | for $i (@sortedTypes) { 225 | $correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0; 226 | if (not($foundGuessed{$i})) { $precision = 0.0; } 227 | else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; } 228 | if (not($foundCorrect{$i})) { $recall = 0.0; } 229 | else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; } 230 | if ($precision+$recall == 0.0) { $FB1 = 0.0; } 231 | else { $FB1 = 2*$precision*$recall/($precision+$recall); } 232 | printf "\n%-7s & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\", 233 | $i,$precision,$recall,$FB1; 234 | } 235 | print "\\hline\n"; 236 | $precision = 0.0; 237 | $recall = 0; 238 | $FB1 = 0.0; 239 | $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0); 240 | $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0); 241 | $FB1 = 2*$precision*$recall/($precision+$recall) 242 | if ($precision+$recall > 0); 243 | printf "Overall & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\\\hline\n", 244 | $precision,$recall,$FB1; 245 | } 246 | 247 | exit 0; 248 | 249 | # endOfChunk: checks if a chunk ended between the previous and current word 250 | # arguments: previous and current chunk tags, previous and current types 251 | # note: this code is capable of handling other chunk representations 252 | # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong 253 | # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006 254 | 255 | sub endOfChunk { 256 | my $prevTag = shift(@_); 257 | my $tag = shift(@_); 258 | my $prevType = shift(@_); 259 | my $type = shift(@_); 260 | my $chunkEnd = $false; 261 | 262 | if ( $prevTag eq "B" and $tag eq "B" ) { $chunkEnd = $true; } 263 | if ( $prevTag eq "B" and $tag eq "O" ) { $chunkEnd = $true; } 264 | if ( $prevTag eq "I" and $tag eq "B" ) { $chunkEnd = $true; } 265 | if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; } 266 | 267 | if ( $prevTag eq "E" and $tag eq "E" ) { $chunkEnd = $true; } 268 | if ( $prevTag eq "E" and $tag eq "I" ) { $chunkEnd = $true; } 269 | if ( $prevTag eq "E" and $tag eq "O" ) { $chunkEnd = $true; } 270 | if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; } 271 | 272 | if ($prevTag ne "O" and $prevTag ne "." 
and $prevType ne $type) { 273 | $chunkEnd = $true; 274 | } 275 | 276 | # corrected 1998-12-22: these chunks are assumed to have length 1 277 | if ( $prevTag eq "]" ) { $chunkEnd = $true; } 278 | if ( $prevTag eq "[" ) { $chunkEnd = $true; } 279 | 280 | return($chunkEnd); 281 | } 282 | 283 | # startOfChunk: checks if a chunk started between the previous and current word 284 | # arguments: previous and current chunk tags, previous and current types 285 | # note: this code is capable of handling other chunk representations 286 | # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong 287 | # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006 288 | 289 | sub startOfChunk { 290 | my $prevTag = shift(@_); 291 | my $tag = shift(@_); 292 | my $prevType = shift(@_); 293 | my $type = shift(@_); 294 | my $chunkStart = $false; 295 | 296 | if ( $prevTag eq "B" and $tag eq "B" ) { $chunkStart = $true; } 297 | if ( $prevTag eq "I" and $tag eq "B" ) { $chunkStart = $true; } 298 | if ( $prevTag eq "O" and $tag eq "B" ) { $chunkStart = $true; } 299 | if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; } 300 | 301 | if ( $prevTag eq "E" and $tag eq "E" ) { $chunkStart = $true; } 302 | if ( $prevTag eq "E" and $tag eq "I" ) { $chunkStart = $true; } 303 | if ( $prevTag eq "O" and $tag eq "E" ) { $chunkStart = $true; } 304 | if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; } 305 | 306 | if ($tag ne "O" and $tag ne "." and $prevType ne $type) { 307 | $chunkStart = $true; 308 | } 309 | 310 | # corrected 1998-12-22: these chunks are assumed to have length 1 311 | if ( $tag eq "[" ) { $chunkStart = $true; } 312 | if ( $tag eq "]" ) { $chunkStart = $true; } 313 | 314 | return($chunkStart); 315 | } 316 | -------------------------------------------------------------------------------- /data/anno/anno-cn.md: -------------------------------------------------------------------------------- 1 | # 地址标准化标注规范 2 | ## 标注目标 3 | 地址结构化:识别地址中的行政区划、其它地址要素、辅助定位词,并打上英文标注。标注包括每个要素的范围和标签。 4 | 5 | 6 | ## 标注示例 7 | ####【prov】 8 | 1. 解释:省级行政区划,省、自治区; 9 | 2. 示例1: 10 | 1. Query:内蒙古赤峰市锦山镇 11 | 2. 结构化:prov=内蒙古 city=赤峰市 town=锦山镇 12 | 13 | ####【city】 14 | 1. 解释:地级行政区划,地级市、直辖市、地区、自治州等; 15 | 2. 示例1: 16 | 1. Query:杭州市富阳区戴家墩路91号东阳诚心木线(富阳店) 17 | 2. 结构化:city=杭州市 district=富阳区 road=戴家墩路 roadno=91号 poi=东阳诚心木线(富阳店) 18 | 3. 示例2: 19 | 1. Query:西藏自治区日喀则地区定日县柑碑村712号 20 | 2. 结构化:prov=西藏自治区 city=日喀则地区 district=定日县 community=柑碑村 roadno=712号 21 | 22 | ####【district】 23 | 1. 解释:县级行政区划,市辖区、县级市、县等; 24 | 2. 示例1: 25 | 1. Query:东城区福新东路245号 26 | 2. 结构化:district=东城区 road=福新东路 roadno=245号 27 | 28 | ####【devzone】 29 | 1. 解释:非正式行政区划,级别在district和town之间,或者town之后,特指经济技术开发区; 30 | 2. 示例1: 31 | 1. Query:内蒙古自治区呼和浩特市土默特左旗金川开发区公元仰山9号楼2单元202 32 | 2. 结构化:prov=内蒙古自治区 city=呼和浩特市 district=土默特左旗 town=察素齐镇 devzone=金川开发区 poi=公元仰山 houseno=9号楼 cellno=2单元 roomno=202室 33 | 3. 示例2: 34 | 1. Query:南宁市青秀区仙葫经济开发区开泰路148号广西警察学院仙葫校区 35 | 2. 结构化:city=南宁市 district=青秀区 devzone=仙葫经济开发区 road=开泰路 roadno=148号 poi=广西警察学院仙葫校区 36 | 37 | ####【town】 38 | 1. 解释:乡级行政区划,镇、街道、乡等。 39 | 2. 示例1: 40 | 1. Query:上海市 静安区 共和新路街道 柳营路669弄14号1102 41 | 2. 结构化:city=上海市 district=静安区 town=共和新路街道 road=柳营路 roadno=669弄 houseno=14号 roomno=1102室 42 | 3. 示例2: 43 | 1. Query:五常街道顾家桥社区河西北9号衣服鞋子店 44 | 2. 结构化:town=五常街道 community=顾家桥社区 road=河西北 roadno=9号 poi=衣服鞋子店 45 | 46 | ####【community】 47 | 1. 解释:社区、自然村; 48 | 2. 示例1: 49 | 1. Query:张庆乡北胡乔村 50 | 2. 结构化:town=张庆乡 community=北胡乔村 51 | 3. 示例2: 52 | 1. Query:五常街道顾家桥社区河西北9号衣服鞋子店 53 | 2. 
结构化:town=五常街道 community=顾家桥社区 road=河西北 roadno=9号 poi=衣服鞋子店 54 | 55 | ####【road】 56 | 1. 解释:道路、组; 57 | 2. 示例1: 58 | 1. Query:静安区江场三路238号1613室 59 | 2. 结构化:district=静安区 road=江场三路 roadno=238号 roomno=1613室 60 | 3. 示例2: 61 | 1. Query:沿山村5组 62 | 2. 结构化:community=沿山村 road=5组 63 | 4. 示例2: 64 | 1. Query:江宁区江宁滨江开发区中环大道10号环宇人力行政部 65 | 2. 结构化:district=江宁区 town=江宁街道 devzone=江宁滨江开发区 road=中环大道 roadno=10号 poi=环宇人力行政部 66 | 67 | ####【roadno】 68 | 1. 解释:门牌号、主路号、组号、组号附号; 69 | 2. 示例1: 70 | 1. Query:江宁区江宁滨江开发区中环大道10号环宇人力行政部 71 | 2. 结构化:district=江宁区 town=江宁街道 devzone=江宁滨江开发区 road=中环大道 roadno=10号 poi=环宇人力行政部 72 | 3. 示例2: 73 | 1. Query:沿山村5组6号 74 | 2. 结构化:community=沿山村 roadno=5组6号 75 | 4. 注意:出现“X组X号”,“X组X号附X号”,“X号附X号” 这三种情况,统一标为roadno,不拆分road和roadno。 76 | 77 | ####【subroad】 78 | 1. 解释:子路、支路、辅路; 79 | 2. 示例1: 80 | 1. Query:浙江省台州市临海市江南大道创业大道288号 81 | 2. 结构化:prov=浙江省 city=台州市 district=临海市 town=江南街道 road=江南大道 subroad=创业大道 subroadno=288号 82 | 83 | 84 | ####【subroadno】 85 | 1. 解释:(子路、支路、辅路)门牌号、路号、组号、组号附号; 86 | 2. 示例1:见上一条。 87 | 88 | ####【poi】 89 | 1. 解释:兴趣点; 90 | 2. 示例1: 91 | 1. Query:浙江省杭州市余杭区五常街道文一西路969号阿里巴巴西溪园区 92 | 2. 结构化:prov=浙江省 city=杭州市 district=余杭区 town=五常街道 road=文一西路 roadno=969号 poi=阿里巴巴西溪园区 93 | 94 | ####【subpoi】 95 | 1. 解释:子兴趣点; 96 | 2. 示例1: 97 | 1. Query:新疆维吾尔自治区 昌吉回族自治州 昌吉市 延安北路街道 延安南路石油小区东门 98 | 2. 结构化:prov=新疆维吾尔自治区 city=昌吉回族自治州 district=昌吉市 town=延安北路街道 road=延安南路 poi=石油小区 subpoi=东门 99 | 3. 示例2: 100 | 1. Query:西湖区新金都城市花园西雅园10幢3底层 101 | 2. 结构化:district=西湖区 town=文新街道 road=古墩路 roadno=415号 poi=新金都城市花园 subpoi=西雅园 houseno=10幢 floorno=3底层 102 | 4. 示例3: 103 | 1. Query:广宁伯街2号金泽大厦东区15层 104 | 2. 结构化:road=广宁伯街 roadno=2号 poi=金泽大厦 subpoi=东区 floorno=15层 105 | 5. 注意:poi后面还有一个poi,而且地域上属于之前poi的范围内或附近,可以认为是subpoi,通常见于“XX小区【东门】”、“XX大厦【收发室】”、“阿里巴巴西溪园区【全家】”,“西溪花园【蒹葭苑】”,“西溪北苑【东区】”;特别需要注意“XXX小区一期”、“XXX小区二期”这种,整体是一个POI,“一期”不是一个subpoi。 106 | 107 | ####【houseno】 108 | 1. 解释:楼栋号; 109 | 2. 示例1: 110 | 1. Query:阿里巴巴西溪园区6号楼小邮局 111 | 2. 结构化:poi=阿里巴巴西溪园区 houseno=6号楼 person=小邮局 112 | 3. 示例2: 113 | 1. Query:四川省 成都市 金牛区 沙河源街道 金牛区九里堤街道 金府机电城A区3栋16号 114 | 2. 结构化:prov=四川省 city=成都市 district=金牛区 town=沙河源街道 road=金府路 poi=金府机电城 subpo=A区 houseno=3栋 cellno=16号 115 | 4. 示例3: 116 | 1. Query:竹海水韵春风里12-3-1001 117 | 2. 结构化:poi=竹海水韵 subpoi=春风里 houseno=12幢 cellno=3单元 roomno= 1001室 118 | 5. 注意:一般跟在poi之后,常见于“智慧产业园二期【一号楼】”、“春风里【7】-2-801”。 119 | 120 | ####【cellno】 121 | 1. 解释:单元号; 122 | 2. 示例1: 123 | 1. Query:竹海水韵春风里12-3-1001 124 | 2. 结构化:poi=竹海水韵 subpoi=春风里 houseno=12幢 cellno=3单元 roomno= 1001室 125 | 3. 示例2: 126 | 1. Query:蒋村花园新达苑18幢二单元101 127 | 2. 结构化:town=蒋村街道 community=蒋村花园社区 road=晴川街 poi=蒋村花园 subpoi=新达苑 houseno=18幢 cellno=二单元 roomno=101室 128 | 129 | ####【floorno】 130 | 1. 解释:楼层号; 131 | 2. 示例1: 132 | 1. Query:北京市东城区东中街29号东环广场B座5层信达资本 133 | 2. 结构化:city=北京市 district=东城区 town=东直门街道 road=东中街 roadno=29号 poi=东环广场 houseno=B座 floorno=5层 person=信达资本 134 | 135 | ####【roomno】 136 | 1. 解释:房间号、户号; 137 | 2. 示例1:见cellno的示例 138 | 139 | ####【person】 140 | 1. 解释:楼栋内的“POI”,企业、法人、商铺名等。 141 | 2. 示例1: 142 | 1. Query:北京 北京市 西城区 广安门外街道 马连道马正和大厦3层我的未来网总部 143 | 2. 结构化:city=北京市 district=西城区 town=广安门外街道 road=马连道 poi=马正和大厦 floorno=3层 person=我的未来网总部 144 | 3. 示例2: 145 | 1. Query:浙江省 杭州市 余杭区 良渚街道沈港路11号2楼 常春藤公司 146 | 2. 结构化:prov=浙江省 city=杭州市 district=余杭区 town=良渚街道 road=沈港路 roadno=11号 floorno=2楼 person=常春藤公司 147 | 148 | 149 | ####【assist】 150 | 1. 解释:普通辅助定位词,比如门口、旁边、附近、楼下、边上等等较为一般的或者模糊的定位词; 151 | 2. 示例1: 152 | 1. Query:广西柳州市城中区潭中东路勿忘我网吧门口 153 | 2. 结构化:prov=广西壮族自治区 city=柳州市 district=城中区 town=潭中街道 road=潭中东路 poi=勿忘我网吧 assist=门口 154 | 155 | ####【redundant】 156 | 1. 解释:无意义词,冗余等对地址无帮助的信息; 157 | 2. 
示例1: 158 | 1. Query:浙江省 杭州市 滨江区 浙江省 杭州市 滨江区 六和路东50米 六和路东信大道口自行车租赁点 159 | 2. 结构化:prov=浙江省 city=杭州市 district=滨江区 redundant=东 浙江省杭州市滨江区 town=浦沿街道 road=六和路 subroad=东信大道 intersection=口 poi=自行车租赁点 160 | 161 | 3. 示例2: 162 | 1. Query:浙江省 杭州市 滨江区 六和路 ---- 东信大道口自行车租赁点 163 | 2. 结构化:prov=浙江省 city=杭州市 district=滨江区 town=浦沿街道 road=六和路 redundant=---- subroad=东信大道 intersection=口 poi=自行车租赁点 164 | 165 | ####【otherinfo】 166 | 1. 解释:其他无法分类的信息。 167 | 168 | ## 标注篇序关系 169 | label之间存在较强的篇序关系,比如说city不能跑到prov前面去,具体有如下几种偏序关系: 170 | 171 | #### 1. prov > city > district > town > comm > road > roadno > poi > houseno > cellno > floorno > roomno 172 | #### 2. district > devzone 173 | #### 3. devzone > comm 174 | #### 4. road > subroad 175 | #### 5. poi > subpoi 176 | 177 | -------------------------------------------------------------------------------- /data/anno/anno-en.md: -------------------------------------------------------------------------------- 1 | # Chinese Address Parsing Annotation Guideline 2 | ## Goal 3 | Recognize all the chunks (e.g. city, road, etc.) in the given Chinese address. For each chunk, the boundary and the label are required to be annotated. 4 | 5 | ####【prov】 6 | 1. Interpretation: province, autonomous region 7 | 2. Example 1: 8 | 1. Query: 内蒙古赤峰市锦山镇 9 | 2. Annotation: prov=内蒙古 city=赤峰市 district=喀喇沁旗 town=锦山镇 10 | 3. Example 2: 11 | 1. Query: 渭南市大荔县户家乡边章营村 12 | 2. Annotation: prov=陕西省 city=渭南市 district=大荔县 town=户家乡 community=边章营村 13 | 14 | ####【city】 15 | 1. Interpretation: city,municipality,autonomous District 16 | 2. Example 1: 17 | 1. Query: 杭州市富阳区戴家墩路91号东阳诚心木线(富阳店) 18 | 2. Annotation: city=杭州市 district=富阳区 road=戴家墩路 roadno=91号 poi=东阳诚心木线(富阳店) 19 | 3. Example 2: 20 | 1. Query: 西藏自治区日喀则地区定日县柑碑村712号 21 | 2. Annotation: prov=西藏自治区 city=日喀则地区 district=定日县 community=柑碑村 roadno=712号 22 | 23 | ####【district】 24 | 1. Interpretation: district 25 | 2. Example 1: 26 | 1. Query: 东城区福新东路245号 27 | 2. Annotation: district=东城区 road=福新东路 roadno=245号 28 | 29 | ####【devzone】 30 | 1. Interpretation: economical development zone 31 | 2. Example 1: 32 | 1. Query: 内蒙古自治区呼和浩特市土默特左旗金川开发区公元仰山9号楼2单元202 33 | 2. Annotation: prov=内蒙古自治区 city=呼和浩特市 district=土默特左旗 town=察素齐镇 devzone=金川开发区 poi=公元仰山 houseno=9号楼 cellno=2单元 roomno=202室 34 | 3. Example 2: 35 | 1. Query: 南宁市青秀区仙葫经济开发区开泰路148号广西警察学院仙葫校区 36 | 2. Annotation: city=南宁市 district=青秀区 devzone=仙葫经济开发区 road=开泰路 roadno=148号 poi=广西警察学院仙葫校区 37 | 38 | ####【town】 39 | 1. Interpretation: town, administrative street; 40 | 2. Example 1: 41 | 1. Query: 上海市 静安区 共和新路街道 柳营路669弄14号1102 42 | 2. Annotation: city=上海市 district=静安区 town=共和新路街道 road=柳营路 roadno=669弄 houseno=14号 roomno=1102室 43 | 3. Example 2: 44 | 1. Query: 五常街道顾家桥社区河西北9号衣服鞋子店 45 | 2. Annotation: town=五常街道 community=顾家桥社区 road=河西北 roadno=9号 poi=衣服鞋子店 46 | 47 | ####【community】 48 | 1. Interpretation: community 49 | 2. Example 1: 50 | 1. Query: 张庆乡北胡乔村 51 | 2. Annotation: town=张庆乡 community=北胡乔村 52 | 3. Example 2: 53 | 1. Query: 五常街道顾家桥社区河西北9号衣服鞋子店 54 | 2. Annotation: town=五常街道 community=顾家桥社区 road=河西北 roadno=9号 poi=衣服鞋子店 55 | 56 | ####【road】 57 | 1. Interpretation: road 58 | 2. Example 1: 59 | 1. Query: 静安区江场三路238号1613室 60 | 2. Annotation: district=静安区 road=江场三路 roadno=238号 roomno=1613室 61 | 3. Example 2: 62 | 1. Query: 沿山村5组 63 | 2. Annotation: community=沿山村 road=5组 64 | 4. Example 2: 65 | 1. Query: 江宁区江宁滨江开发区中环大道10号环宇人力行政部 66 | 2. Annotation: district=江宁区 town=江宁街道 devzone=江宁滨江开发区 road=中环大道 roadno=10号 poi=环宇人力行政部 67 | 68 | ####【roadno】 69 | 1. Interpretation: road number 70 | 2. Example 1: 71 | 1. 
Query: 江宁区江宁滨江开发区中环大道10号环宇人力行政部
72 | 2. Annotation: district=江宁区 town=江宁街道 devzone=江宁滨江开发区 road=中环大道 roadno=10号 poi=环宇人力行政部
73 | 3. Example 2:
74 | 1. Query: 沿山村5组6号
75 | 2. Annotation: community=沿山村 roadno=5组6号
76 | 
77 | ####【subroad】
78 | 1. Interpretation: subroad
79 | 2. Example 1:
80 | 1. Query: 浙江省台州市临海市江南大道创业大道288号
81 | 2. Annotation: prov=浙江省 city=台州市 district=临海市 town=江南街道 road=江南大道 subroad=创业大道 subroadno=288号
82 | 
83 | 
84 | ####【subroadno】
85 | 1. Interpretation: subroad number
86 | 2. Example 1: See the previous example.
87 | 
88 | ####【poi】
89 | 1. Interpretation: point of interest;
90 | 2. Example 1:
91 | 1. Query: 浙江省杭州市余杭区五常街道文一西路969号阿里巴巴西溪园区
92 | 2. Annotation: prov=浙江省 city=杭州市 district=余杭区 town=五常街道 road=文一西路 roadno=969号 poi=阿里巴巴西溪园区
93 | 
94 | ####【subpoi】
95 | 1. Interpretation: sub-poi;
96 | 2. Example 1:
97 | 1. Query: 新疆维吾尔自治区 昌吉回族自治州 昌吉市 延安北路街道 延安南路石油小区东门
98 | 2. Annotation: prov=新疆维吾尔自治区 city=昌吉回族自治州 district=昌吉市 town=延安北路街道 road=延安南路 poi=石油小区 subpoi=东门
99 | 3. Example 2:
100 | 1. Query: 西湖区新金都城市花园西雅园10幢3底层
101 | 2. Annotation: district=西湖区 town=文新街道 road=古墩路 roadno=415号 poi=新金都城市花园 subpoi=西雅园 houseno=10幢 floorno=3底层
102 | 4. Example 3:
103 | 1. Query: 广宁伯街2号金泽大厦东区15层
104 | 2. Annotation: road=广宁伯街 roadno=2号 poi=金泽大厦 subpoi=东区 floorno=15层
105 | 5. Comment: We regard a second poi appearing after the first one as subpoi if they are located in the same region.
106 | 
107 | ####【houseno】
108 | 1. Interpretation: house number;
109 | 2. Example 1:
110 | 1. Query: 阿里巴巴西溪园区6号楼小邮局
111 | 2. Annotation: poi=阿里巴巴西溪园区 houseno=6号楼 person=小邮局
112 | 3. Example 2:
113 | 1. Query: 四川省 成都市 金牛区 沙河源街道 金牛区九里堤街道 金府机电城A区3栋16号
114 | 2. Annotation: prov=四川省 city=成都市 district=金牛区 town=沙河源街道 road=金府路 poi=金府机电城 subpoi=A区 houseno=3栋 cellno=16号
115 | 4. Example 3:
116 | 1. Query: 竹海水韵春风里12-3-1001
117 | 2. Annotation: poi=竹海水韵 subpoi=春风里 houseno=12幢 cellno=3单元 roomno=1001室
118 | 5. Comment: It usually appears after poi.
119 | 
120 | ####【cellno】
121 | 1. Interpretation: cell number;
122 | 2. Example 1:
123 | 1. Query: 竹海水韵春风里12-3-1001
124 | 2. Annotation: poi=竹海水韵 subpoi=春风里 houseno=12幢 cellno=3单元 roomno=1001室
125 | 3. Example 2:
126 | 1. Query: 蒋村花园新达苑18幢二单元101
127 | 2. Annotation: town=蒋村街道 community=蒋村花园社区 road=晴川街 poi=蒋村花园 subpoi=新达苑 houseno=18幢 cellno=二单元 roomno=101室
128 | 
129 | ####【floorno】
130 | 1. Interpretation: floor number;
131 | 2. Example 1:
132 | 1. Query: 北京市东城区东中街29号东环广场B座5层信达资本
133 | 2. Annotation: city=北京市 district=东城区 town=东直门街道 road=东中街 roadno=29号 poi=东环广场 houseno=B座 floorno=5层 person=信达资本
134 | 
135 | ####【roomno】
136 | 1. Interpretation: room number;
137 | 2. Example 1: See the example under cellno
138 | 
139 | ####【person】
140 | 1. Interpretation: name of a company or company representative
141 | 2. Example 1:
142 | 1. Query: 北京 北京市 西城区 广安门外街道 马连道马正和大厦3层我的未来网总部
143 | 2. Annotation: city=北京市 district=西城区 town=广安门外街道 road=马连道 poi=马正和大厦 floorno=3层 person=我的未来网总部
144 | 3. Example 2:
145 | 1. Query: 浙江省 杭州市 余杭区 良渚街道沈港路11号2楼 常春藤公司
146 | 2. Annotation: prov=浙江省 city=杭州市 district=余杭区 town=良渚街道 road=沈港路 roadno=11号 floorno=2楼 person=常春藤公司
147 | 
148 | ####【assist】
149 | 1. Interpretation: auxiliary words that help pinpoint a location, e.g. 旁边 (beside), 对面 (opposite)
150 | 2. Example 1:
151 | 1. Query: 广西柳州市城中区潭中东路勿忘我网吧门口
152 | 2. Annotation: prov=广西壮族自治区 city=柳州市 district=城中区 town=潭中街道 road=潭中东路 poi=勿忘我网吧 assist=门口
153 | 
154 | ####【redundant】
155 | 1. Interpretation: useless and redundant words
156 | 2. Example 1:
157 | 1. Query: 浙江省 杭州市 滨江区 浙江省 杭州市 滨江区 六和路东信大道自行车租赁点
158 | 2. Annotation: prov=浙江省 city=杭州市 district=滨江区 redundant=浙江省杭州市滨江区 road=六和路 subroad=东信大道 poi=自行车租赁点
159 | 
160 | 3. Example 2:
161 | 1. Query: 浙江省 杭州市 滨江区 六和路 ---- 东信大道自行车租赁点
162 | 2. Annotation: prov=浙江省 city=杭州市 district=滨江区 road=六和路 redundant=---- subroad=东信大道 poi=自行车租赁点
163 | 
164 | ####【otherinfo】
165 | 1. Interpretation: other information that cannot be classified
166 | 
167 | ## Partial Order
168 | Partial orders exist among the labels; for example, city always appears after prov in an address. We summarize the following partial orders (comm abbreviates community; a mechanical check is sketched after the list).
169 | 
170 | #### 1. prov > city > district > town > comm > road > roadno > poi > houseno > cellno > floorno > roomno
171 | #### 2. district > devzone
172 | #### 3. devzone > comm
173 | #### 4. road > subroad
174 | #### 5. poi > subpoi
175 | 
176 | 
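These orders can be checked mechanically against a sequence of predicted chunk labels. The sketch below is illustrative only: it is not part of the released code, `PARTIAL_ORDERS` and `violates_partial_order` are names introduced here, and label spellings follow data/labels.txt (so comm above is written as community).

```python
# Illustrative check (not part of the released code): flag a chunk label
# sequence that puts any two labels in an order contradicting one of the
# five partial-order chains listed above.
PARTIAL_ORDERS = [
    ['prov', 'city', 'district', 'town', 'community', 'road', 'roadno',
     'poi', 'houseno', 'cellno', 'floorno', 'roomno'],
    ['district', 'devzone'],
    ['devzone', 'community'],
    ['road', 'subroad'],
    ['poi', 'subpoi'],
]
# Position of each label within each chain, for constant-time comparisons.
RANKS = [{label: i for i, label in enumerate(chain)} for chain in PARTIAL_ORDERS]

def violates_partial_order(labels):
    """Return True if some later label should have preceded an earlier one."""
    for i, a in enumerate(labels):
        for b in labels[i + 1:]:
            if any(a in rank and b in rank and rank[a] > rank[b] for rank in RANKS):
                return True
    return False

assert violates_partial_order(['city', 'prov'])  # city before prov breaks chain 1
assert not violates_partial_order(['prov', 'city', 'road', 'roadno'])
```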
--------------------------------------------------------------------------------
/data/labels.txt:
--------------------------------------------------------------------------------
1 | country
2 | prov
3 | city
4 | district
5 | devzone
6 | town
7 | community
8 | road
9 | subroad
10 | roadno
11 | subroadno
12 | poi
13 | subpoi
14 | houseno
15 | cellno
16 | floorno
17 | roomno
18 | person
19 | assist
20 | redundant
21 | otherinfo
22 | 
--------------------------------------------------------------------------------
/exp_dytree.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | model="chartdyRBTC"
3 | treetype="dytree"
4 | 
5 | for pretrain in giga
6 | do
7 | 
8 | for dropout in 0.4 #0.1 0.2 0.3 #0.4
9 | do
10 | 
11 | for lstmdim in 200
12 | do
13 | 
14 | for batchsize in 1 #10 #20
15 | do
16 | 
17 | for normal in 1
18 | do
19 | 
20 | for RBTlabel in houseno
21 | do
22 | #prov #city district devzone town community road subroad roadno subroadno poi subpoi #houseno cellno floorno roomno person otherinfo assist redundant #country #
23 | 
24 | 
25 | for nontlabelstyle in 0 #0 #1 3 #1 2
26 | do
27 | 
28 | for zerocostchunk in 0 #0 1
29 | do
30 | 
31 | log="addr_"$treetype"_"$pretrain"_"$dropout"_"$lstmdim"_"$batchsize"_"$model"_"$treetype"_"$normal"_"$RBTlabel"_"$nontlabelstyle"_"$zerocostchunk
32 | echo $log".log"
33 | 
34 | nohup python3 src/main_dyRBT.py train --parser-type $model --model-path-base models/$model-model --lstm-dim $lstmdim --label-hidden-dim $lstmdim --split-hidden-dim $lstmdim --pretrainemb $pretrain --batch-size $batchsize --epochs 30 --treetype $treetype --expname $log --normal $normal --checks-per-epoch 4 --RBTlabel $RBTlabel --nontlabelstyle $nontlabelstyle --dropout $dropout --zerocostchunk $zerocostchunk --loadmodel none >> $log".log" 2>&1 &
35 | 
36 | #
37 | 
38 | done
39 | done
40 | done
41 | done
42 | done
43 | done
44 | done
45 | done
46 | 
47 | 
48 | 
--------------------------------------------------------------------------------
/giga.emb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leodotnet/neural-chinese-address-parsing/54cf6bde941a152cc096f54294254fc4ee288fa7/giga.emb
--------------------------------------------------------------------------------
/log.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leodotnet/neural-chinese-address-parsing/54cf6bde941a152cc096f54294254fc4ee288fa7/log.jpg
--------------------------------------------------------------------------------
/src/evaluate.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os.path
3 | 
import re 4 | import subprocess 5 | import tempfile 6 | 7 | import trees 8 | 9 | 10 | class FScore(object): 11 | def __init__(self, recall, precision, fscore): 12 | self.recall = recall 13 | self.precision = precision 14 | self.fscore = fscore 15 | 16 | def __str__(self): 17 | return "(Recall={:.2f}%, Precision={:.2f}%, FScore={:.2f}%)".format( 18 | self.recall * 100, self.precision * 100, self.fscore * 100) 19 | 20 | def evalb(evalb_dir, gold_trees, predicted_trees, expname = "default"): 21 | assert os.path.exists(evalb_dir) 22 | evalb_program_path = os.path.join(evalb_dir, "evalb") 23 | evalb_param_path = os.path.join(evalb_dir, "COLLINS.prm") 24 | assert os.path.exists(evalb_program_path) 25 | assert os.path.exists(evalb_param_path) 26 | 27 | assert len(gold_trees) == len(predicted_trees) 28 | for gold_tree, predicted_tree in zip(gold_trees, predicted_trees): 29 | assert isinstance(gold_tree, trees.TreebankNode) 30 | assert isinstance(predicted_tree, trees.TreebankNode) 31 | gold_leaves = list(gold_tree.leaves()) 32 | predicted_leaves = list(predicted_tree.leaves()) 33 | assert len(gold_leaves) == len(predicted_leaves) 34 | assert all( 35 | gold_leaf.word == predicted_leaf.word 36 | for gold_leaf, predicted_leaf in zip(gold_leaves, predicted_leaves)) 37 | 38 | temp_dir = tempfile.TemporaryDirectory(prefix="evalb-") 39 | temp_dir = 'EVALB/' 40 | gold_path = os.path.join(temp_dir, expname + ".gold.txt") 41 | predicted_path = os.path.join(temp_dir, expname + ".predicted.txt") 42 | output_path = os.path.join(temp_dir, expname + ".output.txt") 43 | 44 | with open(gold_path, "w", encoding='utf-8') as outfile: 45 | for tree in gold_trees: 46 | outfile.write("{}\n".format(tree.linearize())) 47 | 48 | with open(predicted_path, "w", encoding='utf-8') as outfile: 49 | for tree in predicted_trees: 50 | outfile.write("{}\n".format(tree.linearize())) 51 | 52 | command = "{} -p {} {} {} > {}".format( 53 | evalb_program_path, 54 | evalb_param_path, 55 | gold_path, 56 | predicted_path, 57 | output_path, 58 | ) 59 | subprocess.run(command, shell=True) 60 | 61 | fscore = FScore(math.nan, math.nan, math.nan) 62 | with open(output_path) as infile: 63 | for line in infile: 64 | match = re.match(r"Bracketing Recall\s+=\s+(\d+\.\d+)", line) 65 | if match: 66 | fscore.recall = float(match.group(1)) 67 | match = re.match(r"Bracketing Precision\s+=\s+(\d+\.\d+)", line) 68 | if match: 69 | fscore.precision = float(match.group(1)) 70 | match = re.match(r"Bracketing FMeasure\s+=\s+(\d+\.\d+)", line) 71 | if match: 72 | fscore.fscore = float(match.group(1)) 73 | break 74 | 75 | success = ( 76 | not math.isnan(fscore.fscore) or 77 | fscore.recall == 0.0 or 78 | fscore.precision == 0.0) 79 | 80 | if success: 81 | pass 82 | #temp_dir.cleanup() 83 | else: 84 | print("Error reading EVALB results.") 85 | print("Gold path: {}".format(gold_path)) 86 | print("Predicted path: {}".format(predicted_path)) 87 | print("Output path: {}".format(output_path)) 88 | 89 | return fscore 90 | 91 | 92 | def count_common_chunks(chunk1, chunk2): 93 | common = 0 94 | for c1 in chunk1: 95 | for c2 in chunk2: 96 | if c1 == c2: 97 | common += 1 98 | 99 | return common 100 | 101 | 102 | def get_performance(match_num, gold_num, pred_num): 103 | p = (match_num + 0.0) / pred_num 104 | r = (match_num + 0.0) / gold_num 105 | 106 | try: 107 | f1 = 2 * p * r / (p + r) 108 | except ZeroDivisionError: 109 | f1 = 0.0 110 | 111 | return p, r, f1 112 | 113 | 114 | def get_text_from_chunks(chunks): 115 | text = [] 116 | for chunk in chunks: 117 | text += 
chunk[3] 118 | 119 | return text 120 | 121 | 122 | def chunk2seq(chunks:[]): 123 | seq = [] 124 | for label, start_pos, end_pos, text_list in chunks: 125 | seq.append('B-' + label) 126 | for i in range(start_pos, end_pos - 1): 127 | seq.append('I-' + label) 128 | 129 | return seq 130 | 131 | def chunks2str(chunks): 132 | #print(chunks) 133 | return ' '.join(["({} {})".format(label, ''.join(text_list)) for label, _, _, text_list in chunks]) 134 | 135 | def eval_chunks(evalb_dir, gold_trees, predicted_trees, output_filename = 'dev.out.txt'): 136 | match_num = 0 137 | gold_num = 0 138 | pred_num = 0 139 | 140 | 141 | 142 | fout = open(output_filename, 'w', encoding='utf-8') 143 | 144 | invalid_predicted_tree = 0 145 | 146 | for gold_tree, predicted_tree in zip(gold_trees, predicted_trees): 147 | 148 | # print(colored(gold_tree.linearize(), 'red')) 149 | # print(colored(predicted_tree.linearize(), 'yellow')) 150 | 151 | gold_chunks = gold_tree.to_chunks() 152 | predict_chunks = predicted_tree.to_chunks() 153 | input_seq = get_text_from_chunks(gold_chunks) 154 | gold_seq = chunk2seq(gold_chunks) 155 | predict_seq = chunk2seq(predict_chunks) 156 | 157 | if len(gold_seq) != len(predict_seq): 158 | invalid_predicted_tree += 1 159 | # print(colored('Error:', 'red')) 160 | # print(input_seq) 161 | # exit() 162 | 163 | o_list = zip(input_seq, gold_seq, predict_seq) 164 | fout.write('\n'.join(['\t'.join(x) for x in o_list])) 165 | fout.write('\n\n') 166 | 167 | 168 | # print(colored(chunks2str(gold_chunks), 'red')) 169 | # print(colored(chunks2str(predict_chunks), 'yellow')) 170 | # print() 171 | 172 | match_num += count_common_chunks(gold_chunks, predict_chunks) 173 | gold_num += len(gold_chunks) 174 | pred_num += len(predict_chunks) 175 | 176 | fout.close() 177 | 178 | #p, r, f1 = get_performance(match_num, gold_num, pred_num) 179 | # fscore = FScore(r, p, f1) 180 | # print('P,R,F: [{0:.2f}, {1:.2f}, {2:.2f}]'.format(p * 100, r * 100, f1 * 100), flush=True) 181 | print(output_filename) 182 | 183 | print('invalid_predicted_tree:',invalid_predicted_tree) 184 | 185 | cmdline = ["perl", "conlleval.pl"] 186 | cmd = subprocess.Popen(cmdline, stdin=open(output_filename, 'r', encoding='utf-8'), stdout=subprocess.PIPE, stderr=subprocess.STDOUT) 187 | stdout, stderr = cmd.communicate() 188 | lines = stdout.decode("utf-8") 189 | print(lines) 190 | for line in lines.split('\n'): 191 | if line.startswith('accuracy'): 192 | #print('line:', line) 193 | for item in line.strip().split(';'): 194 | items = item.strip().split(':') 195 | if items[0].startswith('precision'): 196 | p = float(items[1].strip()[:-1]) / 100 197 | elif items[0].startswith('recall'): 198 | r = float(items[1].strip()[:-1]) / 100 199 | elif items[0].startswith('FB1'): 200 | f1 = float(items[1].strip()) / 100 201 | 202 | fscore = FScore(r, p, f1) 203 | print('P,R,F: [{0:.2f}, {1:.2f}, {2:.2f}]'.format(p * 100, r * 100, f1 * 100), flush=True) 204 | return fscore 205 | 206 | 207 | 208 | 209 | 210 | 211 | def eval_chunks2(evalb_dir, gold_chunks_list, pred_chunk_list, output_filename = 'dev.out.txt'): 212 | match_num = 0 213 | gold_num = 0 214 | pred_num = 0 215 | 216 | 217 | fout = open(output_filename, 'w', encoding='utf-8') 218 | 219 | invalid_predicted_tree = 0 220 | 221 | for gold_chunks, predict_chunks in zip(gold_chunks_list, pred_chunk_list): 222 | 223 | # print(colored(gold_tree.linearize(), 'red')) 224 | # print(colored(predicted_tree.linearize(), 'yellow')) 225 | 226 | #gold_chunks = gold_tree.to_chunks() 227 | #predict_chunks = 
predicted_tree.to_chunks() 228 | input_seq = get_text_from_chunks(gold_chunks) 229 | gold_seq = chunk2seq(gold_chunks) 230 | predict_seq = chunk2seq(predict_chunks) 231 | 232 | if len(gold_seq) != len(predict_seq): 233 | invalid_predicted_tree += 1 234 | # print(colored('Error:', 'red')) 235 | # print(input_seq) 236 | # exit() 237 | 238 | o_list = zip(input_seq, gold_seq, predict_seq) 239 | fout.write('\n'.join(['\t'.join(x) for x in o_list])) 240 | fout.write('\n\n') 241 | 242 | 243 | # print(colored(chunks2str(gold_chunks), 'red')) 244 | # print(colored(chunks2str(predict_chunks), 'yellow')) 245 | # print() 246 | 247 | match_num += count_common_chunks(gold_chunks, predict_chunks) 248 | gold_num += len(gold_chunks) 249 | pred_num += len(predict_chunks) 250 | 251 | fout.close() 252 | 253 | #p, r, f1 = get_performance(match_num, gold_num, pred_num) 254 | # fscore = FScore(r, p, f1) 255 | # print('P,R,F: [{0:.2f}, {1:.2f}, {2:.2f}]'.format(p * 100, r * 100, f1 * 100), flush=True) 256 | print(output_filename) 257 | 258 | print('invalid_predicted_tree:',invalid_predicted_tree) 259 | 260 | cmdline = ["perl", "conlleval.pl"] 261 | cmd = subprocess.Popen(cmdline, stdin=open(output_filename, 'r', encoding='utf-8'), stdout=subprocess.PIPE, stderr=subprocess.STDOUT) 262 | stdout, stderr = cmd.communicate() 263 | lines = stdout.decode("utf-8") 264 | print(lines) 265 | for line in lines.split('\n'): 266 | if line.startswith('accuracy'): 267 | #print('line:', line) 268 | for item in line.strip().split(';'): 269 | items = item.strip().split(':') 270 | if items[0].startswith('precision'): 271 | p = float(items[1].strip()[:-1]) / 100 272 | elif items[0].startswith('recall'): 273 | r = float(items[1].strip()[:-1]) / 100 274 | elif items[0].startswith('FB1'): 275 | f1 = float(items[1].strip()) / 100 276 | 277 | fscore = FScore(r, p, f1) 278 | print('P,R,F: [{0:.2f}, {1:.2f}, {2:.2f}]'.format(p * 100, r * 100, f1 * 100), flush=True) 279 | return fscore 280 | -------------------------------------------------------------------------------- /src/latent.py: -------------------------------------------------------------------------------- 1 | import random 2 | #import vocabulary 3 | 4 | 5 | ASSIST = 'assist' 6 | REDUNDANT = 'redundant' 7 | NO = 'no' 8 | 9 | 10 | class latent_tree_builder: 11 | def __init__(self, label_vocab, RBT_order_before_label, non_terminal_label_mode = 0): 12 | ''' 13 | :param label_vocab: 14 | :param RBT_order_before_label: 15 | :param non_terminal_label_mode: 0, based on label order; 1, empty label ; 2, random label 16 | ''' 17 | self.label_vocab = label_vocab 18 | if RBT_order_before_label == 'none': 19 | self.RBT_order_before_idx = -1 20 | elif RBT_order_before_label == 'start': 21 | self.RBT_order_before_idx = 23 22 | else: 23 | self.RBT_order_before_idx = self.get_label_order(RBT_order_before_label) 24 | self.non_terminal_label_mode = non_terminal_label_mode 25 | self.label_size = 21 26 | 27 | 28 | def get_label_order(self, label_str): 29 | if label_str.endswith("'"): 30 | label_str = label_str[:-1] 31 | 32 | label = (label_str,) 33 | label_order = self.label_vocab.size - self.label_vocab.index(label) 34 | return label_order; 35 | 36 | 37 | def non_terminal_label(self, label): 38 | if label[-1] == "'": 39 | return label 40 | else: 41 | return label + "'" 42 | 43 | def terminal_label(self, label): 44 | if label[-1] == "'": 45 | return label[:-1] 46 | else: 47 | return label 48 | 49 | 50 | def get_parent_label(self, child1 : str, child2 : str, order_type = 0): 51 | parent = None 52 | 53 | 
# if order_type == 1: 54 | # return get_parent_label_reverse(child1, child2) 55 | # 56 | # if order_type == 2: 57 | # return get_parent_label_order(child1, child2) 58 | # 59 | # if order_type == 3: 60 | # return get_parent_label_empty(child1, child2) 61 | # 62 | # if order_type == 4: 63 | # return get_parent_label_semi(child1, child2) 64 | 65 | if child1.startswith(REDUNDANT) or child2.startswith(REDUNDANT): 66 | if child1.startswith(REDUNDANT) and child2.startswith(REDUNDANT): 67 | parent = self.non_terminal_label(REDUNDANT) 68 | elif child1.startswith(REDUNDANT): 69 | parent = self.non_terminal_label(child2) 70 | else: 71 | parent = self.non_terminal_label(child1) 72 | 73 | elif child1.startswith(ASSIST) or child2.startswith(ASSIST): 74 | if child1.startswith(ASSIST) and child2.startswith(ASSIST): 75 | parent = self.non_terminal_label(ASSIST) 76 | elif child1.startswith(ASSIST): 77 | parent = self.non_terminal_label(child2) 78 | else: 79 | parent = self.non_terminal_label(child1) 80 | 81 | 82 | else: 83 | 84 | last1_priority_id = self.get_label_order(self.terminal_label(child1)) 85 | last2_priority_id = self.get_label_order(self.terminal_label(child2)) 86 | 87 | 88 | if last1_priority_id >= last2_priority_id: 89 | parent = self.non_terminal_label(child1) 90 | else: 91 | parent = self.non_terminal_label(child2) 92 | 93 | 94 | 95 | return parent 96 | 97 | 98 | 99 | def build_latent_tree_str(self, x, chunks): 100 | 101 | RBT_order_before_idx = self.RBT_order_before_idx 102 | pos_POI = -1 103 | for i in range(len(chunks)): 104 | if chunks[i][0] == 'poi': 105 | pos_POI = i 106 | 107 | if pos_POI == -1: 108 | RBT_order_before_idx = 0 109 | 110 | 111 | idx = pos_POI 112 | chunk_aug = [(chunk[0], chunk[1], chunk[2],'(' + chunk[0] + ' ' + ' '.join(['(XX ' + item + ')' for item in x[chunk[1]:chunk[2]]]) + ' )') for chunk in chunks] 113 | 114 | 115 | def create_parent_node(left, right): 116 | parent_label = self.get_parent_label(left[0], right[0]) 117 | parent_left_boundary = left[1] 118 | parent_right_boundary = right[2] 119 | parent_str = '(' + parent_label + ' ' + left[3] + ' ' + right[3] + ' )' 120 | 121 | parent = (parent_label, parent_left_boundary, parent_right_boundary, parent_str) 122 | return parent 123 | 124 | #Build Random Tree 125 | while len(chunk_aug) > 1 and idx >= 0: 126 | label_order = self.get_label_order(chunk_aug[idx][0]) 127 | 128 | options = [] 129 | if label_order < RBT_order_before_idx: 130 | if idx + 1 < len(chunk_aug): 131 | label_order_next = self.get_label_order(chunk_aug[idx + 1][0]) 132 | if label_order_next < RBT_order_before_idx: 133 | options.append(idx + 1) 134 | 135 | if idx - 1 >= 0: 136 | label_order_prev = self.get_label_order(chunk_aug[idx - 1][0]) 137 | if label_order_prev < RBT_order_before_idx: 138 | options.append(idx - 1) 139 | 140 | if len(options) == 0: 141 | break 142 | 143 | 144 | if len(options) == 1: 145 | option = options[0] 146 | else: 147 | p = random.random() 148 | option = options[0 if p > 0.5 else 1] 149 | 150 | 151 | if option == idx - 1: 152 | left = chunk_aug[idx - 1] 153 | right = chunk_aug[idx] 154 | 155 | parent = create_parent_node(left, right) 156 | 157 | chunk_aug[idx - 1] = parent 158 | chunk_aug.remove(right) 159 | 160 | idx = idx - 1 161 | else: 162 | left = chunk_aug[idx] 163 | right = chunk_aug[idx + 1] 164 | 165 | parent = create_parent_node(left, right) 166 | 167 | chunk_aug[idx] = parent 168 | chunk_aug.remove(right) 169 | 170 | idx = idx 171 | 172 | 173 | #Build RBT tree for the rest 174 | while len(chunk_aug) > 1: 175 | idx 
= len(chunk_aug) - 1 176 | 177 | left = chunk_aug[idx - 1] 178 | right = chunk_aug[idx] 179 | 180 | parent = create_parent_node(left, right) 181 | 182 | chunk_aug[idx - 1] = parent 183 | chunk_aug.remove(right) 184 | 185 | 186 | return chunk_aug[0][3] 187 | 188 | def build_latent_tree(self, x, chunks): 189 | import util 190 | tree_str = self.build_latent_tree_str(x, chunks) 191 | tree = util.load_trees_from_str(tree_str, 0) 192 | return tree 193 | 194 | def build_latent_trees(self, insts): 195 | import util 196 | trees_str = '' 197 | for x, chunks in insts: 198 | tree_str = self.build_latent_tree_str(x, chunks) 199 | trees_str += tree_str + '\n' 200 | trees = util.load_trees_from_str(trees_str, 0) 201 | return trees 202 | 203 | 204 | 205 | 206 | def build_dynamicRBT_tree(self, x, chunks): 207 | 208 | from trees import InternalTreebankNode, LeafTreebankNode, InternalUncompletedTreebankNode, InternalParseChunkNode, InternalTreebankChunkNode 209 | import parse 210 | 211 | cut_off_point = -1 212 | for i in reversed(range(len(chunks))): 213 | label_order = self.get_label_order(chunks[i][0]) 214 | if label_order >= self.RBT_order_before_idx: 215 | cut_off_point = i 216 | break 217 | 218 | 219 | #Build RBT from [0, i] 220 | 221 | latentscope = (chunks[cut_off_point + 1][1] if cut_off_point + 1 < len(chunks) else len(x), len(x)) 222 | 223 | chunks_in_scope = chunks[cut_off_point+1:] 224 | chunks_in_scope = [(label, s, e, x[s:e]) for label, s, e in chunks_in_scope] 225 | if len(chunks_in_scope) > 0: 226 | chunkleaves = [] 227 | for label, s, e, text in chunks_in_scope: 228 | leaves = [] 229 | for ch in text: 230 | leaf = LeafTreebankNode(parse.XX, ch) 231 | leaves.append(leaf) 232 | 233 | chunk_leaf = InternalTreebankNode(label, leaves) #InternalTreebankChunkNode 234 | chunkleaves.append(chunk_leaf) 235 | 236 | if self.non_terminal_label_mode == 0 or self.non_terminal_label_mode == 3: 237 | label = max([chunk[0] for chunk in chunks_in_scope], key=lambda l: self.get_label_order(l)) 238 | label = self.non_terminal_label(label) 239 | elif self.non_terminal_label_mode == 1: 240 | label = parse.EMPTY 241 | else: #self.non_terminal_label_mode == 2: 242 | import random 243 | label_id = random.randint(1 + self.latent.label_size + 0, 1 + self.latent.label_size + self.latent.label_size - 1) 244 | label = self.label_vocab.value(label_id) 245 | 246 | latent_area = [InternalUncompletedTreebankNode(label, chunkleaves, chunks_in_scope, self)] 247 | else: 248 | latent_area = [] 249 | 250 | RBT_chunks = list(chunks[:cut_off_point+1]) + latent_area #(parse.EMPTY, chunks[i+1][1], chunks[-1][2]) 251 | 252 | if len(latent_area) == 0: 253 | label, s, e = chunks[-1] 254 | text = x[s:e] 255 | leaves = [] 256 | for ch in text: 257 | leaf = LeafTreebankNode(parse.XX, ch) 258 | leaves.append(leaf) 259 | 260 | RBT_chunks[-1] = InternalTreebankNode(label, leaves) 261 | 262 | while len(RBT_chunks) > 1: 263 | 264 | second_last_chunk = RBT_chunks[-2] 265 | 266 | second_last_children = [] 267 | for pos in range(second_last_chunk[1], second_last_chunk[2]): 268 | second_last_children.append(LeafTreebankNode(parse.XX, x[pos])) 269 | 270 | second_last_node = InternalTreebankNode(second_last_chunk[0], second_last_children) #InternalTreebankChunkNode 271 | 272 | last_node = RBT_chunks[-1] 273 | 274 | if self.non_terminal_label_mode == 0 or self.non_terminal_label_mode == 3: 275 | parent_label = self.get_parent_label(second_last_chunk[0], last_node.label) 276 | elif self.non_terminal_label_mode == 1: 277 | parent_label = parse.EMPTY 278 
| else: #self.non_terminal_label_mode == 2: 279 | import random 280 | label_id = random.randint(1 + self.latent.label_size + 0, 1 + self.latent.label_size + self.latent.label_size - 1) 281 | parent_label = self.label_vocab.value(label_id) 282 | 283 | parent_node = InternalTreebankNode(parent_label, [second_last_node, last_node]) 284 | 285 | 286 | 287 | RBT_chunks[-2] = parent_node 288 | RBT_chunks.remove(last_node) 289 | 290 | tree = RBT_chunks[0] 291 | return x, tree, chunks, latentscope 292 | 293 | def build_dynamicRBT_trees(self, insts): 294 | trees = [] 295 | for x, chunks in insts: 296 | x, tree, chunks, latentscope = self.build_dynamicRBT_tree(x, chunks) 297 | trees.append((x, tree, chunks, latentscope)) 298 | return trees 299 | 300 | def main_test(): 301 | import util 302 | import vocabulary 303 | import parse 304 | label_list = util.load_label_list('data/labels.txt') 305 | label_vocab = vocabulary.Vocabulary() 306 | 307 | label_vocab.index(()) 308 | 309 | 310 | for item in label_list: 311 | label_vocab.index((item,)) 312 | 313 | for item in label_list: 314 | label_vocab.index((item + "'",)) 315 | 316 | label_vocab.index((parse.EMPTY,)) 317 | 318 | label_vocab.freeze() 319 | 320 | latent = latent_tree_builder(label_vocab, 'city') 321 | 322 | insts = util.read_chunks('data/trial.txt') 323 | 324 | 325 | # for k in range(3): 326 | # trees = latent.build_latent_trees(insts) 327 | # for tree in trees: 328 | # print(tree.linearize()) 329 | # print() 330 | 331 | 332 | trees = latent.build_dynamicRBT_trees(insts) 333 | for x, tree, chunks, latentscope in trees: 334 | print(tree.linearize()) 335 | tree = tree.convert() 336 | print() 337 | tree = tree.convert() 338 | print() 339 | 340 | 341 | 342 | #main_test() -------------------------------------------------------------------------------- /src/latenttrees.py: -------------------------------------------------------------------------------- 1 | import collections.abc 2 | import parse 3 | import util 4 | 5 | class TreebankNode(object): 6 | pass 7 | 8 | class InternalTreebankNode(TreebankNode): 9 | def __init__(self, label, children): 10 | assert isinstance(label, str) 11 | self.label = label 12 | 13 | assert isinstance(children, collections.abc.Sequence) 14 | assert all(isinstance(child, TreebankNode) for child in children) 15 | #assert children 16 | self.children = tuple(children) 17 | 18 | def to_chunks(self): 19 | raw_chunks = [] 20 | self.chunk_helper(raw_chunks) 21 | 22 | chunks = [] 23 | p = 0 24 | for label, text_list in raw_chunks: 25 | if label.endswith("'"): 26 | label = label[:-1] 27 | chunks.append((label, p, p + len(text_list), text_list)) 28 | p += len(text_list) 29 | 30 | return chunks 31 | 32 | 33 | def chunk_helper(self, chunks): 34 | 35 | children_status = [isinstance(child, LeafTreebankNode) for child in self.children] 36 | 37 | if all(children_status): 38 | chunks.append((self.label, [child.word for child in self.children])) 39 | elif any(children_status): 40 | char_list = [] 41 | 42 | for child in self.children: 43 | if isinstance(child, InternalTreebankNode): 44 | char_list += child.get_word_list() 45 | else: 46 | char_list.append(child.word) 47 | 48 | 49 | chunk = (self.label, char_list) 50 | chunks.append(chunk) 51 | else: 52 | for child in self.children: 53 | if isinstance(child, InternalTreebankNode): 54 | child.chunk_helper(chunks) 55 | 56 | 57 | def get_word_list(self): 58 | word_list = [] 59 | for child in self.children: 60 | if isinstance(child, InternalTreebankNode): 61 | word_list += child.get_word_list() 62 | 
else: 63 | word_list += [child.word] 64 | return word_list 65 | 66 | def linearize(self): 67 | return "({} {})".format( 68 | self.label, " ".join(child.linearize() for child in self.children)) 69 | 70 | def leaves(self): 71 | for child in self.children: 72 | yield from child.leaves() 73 | 74 | def convert(self, index=0): 75 | tree = self 76 | sublabels = [self.label] 77 | 78 | while len(tree.children) == 1 and isinstance(tree.children[0], InternalTreebankNode): 79 | tree = tree.children[0] 80 | sublabels.append(tree.label) 81 | 82 | children = [] 83 | for child in tree.children: 84 | children.append(child.convert(index=index)) 85 | index = children[-1].right 86 | 87 | return InternalParseNode(tuple(sublabels), children) 88 | 89 | 90 | class LeafTreebankNode(TreebankNode): 91 | def __init__(self, tag, word): 92 | assert isinstance(tag, str) 93 | self.tag = tag 94 | 95 | assert isinstance(word, str) 96 | self.word = word 97 | 98 | def linearize(self): 99 | return "({} {})".format(self.tag, self.word) 100 | 101 | def leaves(self): 102 | yield self 103 | 104 | def convert(self, index=0): 105 | return LeafParseNode(index, self.tag, self.word) 106 | 107 | 108 | 109 | class InternalUncompletedTreebankNode(TreebankNode): 110 | def __init__(self, label, chunkleaves, chunks_in_scope, latent): 111 | assert isinstance(label, str) 112 | self.label = label 113 | self.latent = latent 114 | #assert isinstance(children, collections.abc.Sequence) 115 | #assert all(isinstance(child, TreebankNode) for child in children) 116 | #assert children 117 | self.children = () #tuple(children) 118 | self.chunkleaves = chunkleaves 119 | self.chunks_in_scope = chunks_in_scope 120 | 121 | 122 | def linearize(self): 123 | 124 | #chunks_str_list = [ '(' + label + ' ' + ''.join( '(XX ' + item + ')' for item in text) + ')' for label, s, e, text in self.chunks_in_scope] 125 | return "({} leaves: {})".format(self.label, " ".join(leaf.linearize() for leaf in self.chunkleaves)) 126 | 127 | def leaves(self): 128 | return self.chunkleaves 129 | 130 | def convert(self, index=0): 131 | tree = self 132 | sublabels = [self.label] 133 | 134 | # while len(tree.children) == 1 and isinstance(tree.children[0], InternalTreebankNode): 135 | # tree = tree.children[0] 136 | # sublabels.append(tree.label) 137 | 138 | #children = [] 139 | # for child in tree.children: 140 | # children.append(child.convert(index=index)) 141 | # index = children[-1].right 142 | #index = self.chunks_in_scope[0][1] 143 | chunkleaves = [] 144 | for chunkleaf in self.chunkleaves: 145 | chunkleaves.append(chunkleaf.convert(index=index)) 146 | index = chunkleaves[-1].right 147 | # for i in range(len(self.chunkleaves)): 148 | # chunk = self.chunks_in_scope[i] 149 | # chunkleaf = self.chunkleaves[i] 150 | # chunkleaf.convert(index=chunk[1]) 151 | # chunkleaves.append(chunkleaf) 152 | 153 | return InternalUncompletedParseNode(tuple(sublabels), chunkleaves, self.chunks_in_scope, self.latent) 154 | 155 | 156 | 157 | class ParseNode(object): 158 | pass 159 | 160 | class InternalParseNode(ParseNode): 161 | def __init__(self, label, left, right, children): 162 | assert isinstance(label, tuple) 163 | assert all(isinstance(sublabel, str) for sublabel in label) 164 | assert label 165 | self.label = label 166 | 167 | assert isinstance(children, collections.abc.Sequence) 168 | assert all(isinstance(child, ParseNode) for child in children) 169 | assert children 170 | assert len(children) > 1 or isinstance(children[0], LeafParseNode) 171 | assert all( 172 | left.right == right.left 173 
| for left, right in zip(children, children[1:])) 174 | self.children = tuple(children) 175 | 176 | self.left = children[0].left 177 | self.right = children[-1].right 178 | 179 | def leaves(self): 180 | for child in self.children: 181 | yield from child.leaves() 182 | 183 | def convert(self): 184 | children = [child.convert() for child in self.children] 185 | tree = InternalTreebankNode(self.label[-1], children) 186 | for sublabel in reversed(self.label[:-1]): 187 | tree = InternalTreebankNode(sublabel, [tree]) 188 | return tree 189 | 190 | def enclosing(self, left, right): 191 | assert self.left <= left < right <= self.right 192 | for child in self.children: 193 | if isinstance(child, LeafParseNode): 194 | continue 195 | if child.left <= left < right <= child.right: 196 | return child.enclosing(left, right) 197 | return self 198 | 199 | def oracle_label(self, left, right): 200 | enclosing = self.enclosing(left, right) 201 | if enclosing.left == left and enclosing.right == right: 202 | return enclosing.label 203 | return () 204 | 205 | def oracle_splits(self, left, right): 206 | # return [ 207 | # child.left 208 | # for child in self.enclosing(left, right).children 209 | # if left < child.left < right 210 | # ] 211 | enclosing = self.enclosing(left, right) 212 | return [ 213 | child.left 214 | for child in enclosing.children 215 | if left < child.left < right 216 | ] 217 | # if isinstance(enclosing, InternalUncompletedParseNode): 218 | # return [ 219 | # child.left 220 | # for child in enclosing.chunkleaves 221 | # if left < child.left < right 222 | # ] 223 | # else: 224 | # return [ 225 | # child.left 226 | # for child in enclosing.children 227 | # if left < child.left < right 228 | # ] 229 | 230 | 231 | 232 | class InternalUncompletedParseNode(InternalParseNode): 233 | def __init__(self, label, chunkleaves, chunks_in_scope:[], latent): 234 | assert isinstance(label, tuple) 235 | assert all(isinstance(sublabel, str) for sublabel in label) 236 | assert label 237 | self.label = label 238 | self.latent = latent 239 | 240 | #assert isinstance(children, collections.abc.Sequence) 241 | #assert all(isinstance(child, ParseNode) for child in children) 242 | #assert children 243 | #assert len(children) > 1 or isinstance(children[0], LeafParseNode) 244 | # assert all( 245 | # left.right == right.left 246 | # for left, right in zip(children, children[1:])) 247 | self.children = [] #tuple(children) 248 | self.chunkleaves = chunkleaves 249 | 250 | #self.left = children[0].left 251 | #self.right = children[-1].right 252 | self.left = chunkleaves[0].left #chunks_in_scope[0][1] 253 | self.right = chunkleaves[-1].right #chunks_in_scope[-1][2] 254 | self.chunks_in_scope = chunks_in_scope 255 | self.splits = [chunk[1] for chunk in self.chunks_in_scope] + [self.chunks_in_scope[-1][2]] 256 | 257 | def leaves(self): 258 | return self.chunkleaves 259 | 260 | def convert(self): 261 | # children = [child.convert() for child in self.children] 262 | # tree = InternalUncompletedTreebankNode(self.label[-1], children) 263 | # for sublabel in reversed(self.label[:-1]): 264 | # tree = InternalUncompletedTreebankNode(sublabel, [tree]) 265 | 266 | chunkleaves = [chunkleaf.convert() for chunkleaf in self.chunkleaves] 267 | tree = InternalUncompletedTreebankNode(self.label[-1], chunkleaves, (self.left, self.right), self.latent) 268 | return tree 269 | 270 | def enclosing(self, left, right): 271 | assert self.left <= left < right <= self.right 272 | for chunkleaf in self.chunkleaves: 273 | if isinstance(chunkleaf, 
LeafParseNode): 274 | continue 275 | if chunkleaf.left <= left < right <= chunkleaf.right: 276 | return chunkleaf.enclosing(left, right) 277 | 278 | # if left in self.splits and right in self.splits: 279 | # children = [chunkleaf for chunkleaf in self.chunkleaves if left <= chunkleaf.left and chunkleaf.right <= right] 280 | # # label = max([child.label for child in children],key=lambda l:self.latent.get_label_order(l[0])) 281 | # # label = (self.latent.non_terminal_label(label[0]),) 282 | # label = self.label 283 | # return InternalParseNode(label, children) 284 | 285 | children = [chunkleaf for chunkleaf in self.chunkleaves if left < chunkleaf.right] 286 | children = [chunkleaf for chunkleaf in children if right > chunkleaf.left] 287 | 288 | if self.latent.non_terminal_label_mode == 0: 289 | label = max([child.label for child in children], key=lambda l: self.latent.get_label_order(l[0])) 290 | label = (self.latent.non_terminal_label(label[0]),) 291 | elif self.latent.non_terminal_label_mode == 1: 292 | label = self.label 293 | else: #self.latent.non_terminal_label_mode == 2: 294 | import random 295 | label_id = random.randint(1 + self.latent.label_size + 0, 1 + self.latent.label_size + self.latent.label_size - 1) 296 | label = self.latent.label_vocab.value(label_id) 297 | return InternalParseNode(label, children) 298 | 299 | # self.children = self.chunkleaves 300 | # return self 301 | 302 | def oracle_label(self, left, right): 303 | # enclosing = self.enclosing(left, right) 304 | # if enclosing.left == left and enclosing.right == right: 305 | # return enclosing.label 306 | 307 | for chunk in self.chunks_in_scope: 308 | if chunk[1] == left and chunk[2] == right: 309 | return (chunk[0],) 310 | 311 | 312 | if left in self.splits and right in self.splits: 313 | return self.label 314 | 315 | return () 316 | 317 | def oracle_splits(self, left, right): 318 | 319 | 320 | ret = [p for p in self.splits if left < p and p < right] 321 | if len(ret) == 0: 322 | return [ 323 | child.left 324 | for child in self.enclosing(left, right).children 325 | if left < child.left < right 326 | ] 327 | 328 | 329 | return ret 330 | 331 | 332 | 333 | class LeafParseNode(ParseNode): 334 | def __init__(self, index, tag, word): 335 | assert isinstance(index, int) 336 | assert index >= 0 337 | self.left = index 338 | self.right = index + 1 339 | 340 | assert isinstance(tag, str) 341 | self.tag = tag 342 | 343 | assert isinstance(word, str) 344 | self.word = word 345 | 346 | def leaves(self): 347 | yield self 348 | 349 | def convert(self): 350 | return LeafTreebankNode(self.tag, self.word) 351 | 352 | 353 | def load_trees(path, normal, strip_top=True): 354 | with open(path, 'r', encoding='utf-8') as infile: 355 | tokens = infile.read().replace("(", " ( ").replace(")", " ) ").split() 356 | 357 | def helper(index): 358 | trees = [] 359 | 360 | while index < len(tokens) and tokens[index] == "(": 361 | paren_count = 0 362 | while tokens[index] == "(": 363 | index += 1 364 | paren_count += 1 365 | 366 | label = tokens[index] 367 | index += 1 368 | 369 | if tokens[index] == "(": 370 | children, index = helper(index) 371 | trees.append(InternalTreebankNode(label, children)) 372 | else: 373 | word = tokens[index] 374 | if normal == 1: 375 | newword = '' 376 | for c in word: 377 | if util.is_digit(c): 378 | newword += '0' 379 | else: 380 | newword += c 381 | else: 382 | newword = word 383 | index += 1 384 | trees.append(LeafTreebankNode(label, newword)) 385 | 386 | while paren_count > 0: 387 | assert tokens[index] == ")" 388 | 
index += 1 389 | paren_count -= 1 390 | 391 | return trees, index 392 | 393 | trees, index = helper(0) 394 | assert index == len(tokens) 395 | 396 | if strip_top: 397 | for i, tree in enumerate(trees): 398 | if tree.label == "TOP": 399 | assert len(tree.children) == 1 400 | trees[i] = tree.children[0] 401 | 402 | return trees 403 | -------------------------------------------------------------------------------- /src/main_dyRBT.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import os.path 4 | import time 5 | 6 | seed = 3986067777 7 | import dynet_config 8 | dynet_config.set(random_seed = seed) 9 | import dynet as dy 10 | 11 | import numpy as np 12 | 13 | import evaluate 14 | import parse 15 | import trees 16 | import vocabulary 17 | 18 | import latent 19 | import util 20 | 21 | 22 | def format_elapsed(start_time): 23 | elapsed_time = int(time.time() - start_time) 24 | minutes, seconds = divmod(elapsed_time, 60) 25 | hours, minutes = divmod(minutes, 60) 26 | days, hours = divmod(hours, 24) 27 | elapsed_string = "{}h{:02}m{:02}s".format(hours, minutes, seconds) 28 | if days > 0: 29 | elapsed_string = "{}d{}".format(days, elapsed_string) 30 | return elapsed_string 31 | 32 | 33 | 34 | 35 | def run_train(args): 36 | 37 | args.numpy_seed = seed 38 | if args.numpy_seed is not None: 39 | print("Setting numpy random seed to {}...".format(args.numpy_seed)) 40 | np.random.seed(args.numpy_seed) 41 | 42 | 43 | if args.trial == 1: 44 | args.train_path = 'data/trial.txt' 45 | args.dev_path = 'data/trial.txt' 46 | args.test_path = 'data/trial.txt' 47 | 48 | # args.train_path = args.train_path.replace('[*]', args.treetype) 49 | # args.dev_path = args.dev_path.replace('[*]', args.treetype) 50 | # args.test_path = args.test_path.replace('[*]', args.treetype) 51 | 52 | print("Loading training trees from {}...".format(args.train_path)) 53 | train_chunk_insts = util.read_chunks(args.train_path, args.normal) 54 | print("Loaded {:,} training examples.".format(len(train_chunk_insts))) 55 | 56 | print("Loading development trees from {}...".format(args.dev_path)) 57 | dev_chunk_insts = util.read_chunks(args.dev_path, args.normal) 58 | print("Loaded {:,} development examples.".format(len(dev_chunk_insts))) 59 | 60 | print("Loading test trees from {}...".format(args.test_path)) 61 | test_chunk_insts = util.read_chunks(args.test_path, args.normal) 62 | print("Loaded {:,} test examples.".format(len(test_chunk_insts))) 63 | 64 | # print("Processing trees for training...") 65 | # train_parse = [tree.convert() for tree in train_treebank] 66 | 67 | print("Constructing vocabularies...") 68 | 69 | tag_vocab = vocabulary.Vocabulary() 70 | tag_vocab.index(parse.START) 71 | tag_vocab.index(parse.STOP) 72 | tag_vocab.index(parse.XX) 73 | 74 | 75 | word_vocab = vocabulary.Vocabulary() 76 | word_vocab.index(parse.START) 77 | word_vocab.index(parse.STOP) 78 | word_vocab.index(parse.UNK) 79 | word_vocab.index(parse.NUM) 80 | 81 | for x, chunks in train_chunk_insts + dev_chunk_insts + test_chunk_insts: 82 | for ch in x: 83 | word_vocab.index(ch) 84 | 85 | label_vocab = vocabulary.Vocabulary() 86 | label_vocab.index(()) 87 | 88 | 89 | label_list = util.load_label_list(args.labellist_path) #'data/labels.txt') 90 | for item in label_list: 91 | label_vocab.index((item, )) 92 | 93 | if args.nontlabelstyle != 1: 94 | for item in label_list: 95 | label_vocab.index((item + "'",)) 96 | 97 | if args.nontlabelstyle == 1: 98 | label_vocab.index((parse.EMPTY,)) 99 | 
100 | tag_vocab.freeze() 101 | word_vocab.freeze() 102 | label_vocab.freeze() 103 | 104 | latent_tree = latent.latent_tree_builder(label_vocab, args.RBTlabel, args.nontlabelstyle) 105 | 106 | def print_vocabulary(name, vocab): 107 | special = {parse.START, parse.STOP, parse.UNK} 108 | print("{} ({:,}): {}".format( 109 | name, vocab.size, 110 | sorted(value for value in vocab.values if value in special) + 111 | sorted(value for value in vocab.values if value not in special))) 112 | 113 | if args.print_vocabs: 114 | print_vocabulary("Tag", tag_vocab) 115 | print_vocabulary("Word", word_vocab) 116 | print_vocabulary("Label", label_vocab) 117 | 118 | print("Initializing model...") 119 | 120 | pretrain = {'giga':'data/giga.vec100', 'none':'none'} 121 | pretrainemb = util.load_pretrain(pretrain[args.pretrainemb], args.word_embedding_dim, word_vocab) 122 | 123 | model = dy.ParameterCollection() 124 | if args.parser_type == "chartdyRBTC": 125 | parser = parse.ChartDynamicRBTConstraintParser( 126 | model, 127 | tag_vocab, 128 | word_vocab, 129 | label_vocab, 130 | args.tag_embedding_dim, 131 | args.word_embedding_dim, 132 | args.lstm_layers, 133 | args.lstm_dim, 134 | args.label_hidden_dim, 135 | args.dropout, 136 | (args.pretrainemb, pretrainemb), 137 | args.chunkencoding, 138 | args.trainc == 1, 139 | True, 140 | (args.zerocostchunk == 1), 141 | ) 142 | 143 | 144 | else: 145 | print('Model is not valid!') 146 | exit() 147 | 148 | if args.loadmodel != 'none': 149 | tmp = dy.load(args.loadmodel, model) 150 | parser = tmp[0] 151 | print('Model is loaded from ', args.loadmodel) 152 | 153 | trainer = dy.AdamTrainer(model) 154 | 155 | total_processed = 0 156 | current_processed = 0 157 | check_every = len(train_chunk_insts) / args.checks_per_epoch 158 | best_dev_fscore = -np.inf 159 | best_dev_model_path = None 160 | 161 | start_time = time.time() 162 | 163 | def check_dev(): 164 | nonlocal best_dev_fscore 165 | nonlocal best_dev_model_path 166 | 167 | dev_start_time = time.time() 168 | 169 | dev_predicted = [] 170 | #dev_gold = [] 171 | 172 | #dev_gold = latent_tree.build_latent_trees(dev_chunk_insts) 173 | dev_gold = [] 174 | for inst in dev_chunk_insts: 175 | chunks = util.inst2chunks(inst) 176 | dev_gold.append(chunks) 177 | 178 | for x, chunks in dev_chunk_insts: 179 | dy.renew_cg() 180 | #sentence = [(leaf.tag, leaf.word) for leaf in tree.leaves()] 181 | sentence = [(parse.XX, ch) for ch in x] 182 | predicted, _ = parser.parse(sentence) 183 | dev_predicted.append(predicted.convert().to_chunks()) 184 | 185 | 186 | #dev_fscore = evaluate.evalb(args.evalb_dir, dev_gold, dev_predicted, args.expname + '.dev.') #evalb 187 | dev_fscore = evaluate.eval_chunks2(args.evalb_dir, dev_gold, dev_predicted, output_filename=args.expname + '.dev.txt') # evalb 188 | 189 | 190 | print( 191 | "dev-fscore {} " 192 | "dev-elapsed {} " 193 | "total-elapsed {}".format( 194 | dev_fscore, 195 | format_elapsed(dev_start_time), 196 | format_elapsed(start_time), 197 | ) 198 | ) 199 | 200 | 201 | if dev_fscore.fscore > best_dev_fscore: 202 | if best_dev_model_path is not None: 203 | for ext in [".data", ".meta"]: 204 | path = best_dev_model_path + ext 205 | if os.path.exists(path): 206 | print("Removing previous model file {}...".format(path)) 207 | os.remove(path) 208 | 209 | best_dev_fscore = dev_fscore.fscore 210 | best_dev_model_path = "{}_dev={:.2f}".format(args.model_path_base + "_" + args.expname, dev_fscore.fscore) 211 | print("Saving new best model to {}...".format(best_dev_model_path)) 212 | 
dy.save(best_dev_model_path, [parser]) 213 | 214 | test_start_time = time.time() 215 | test_predicted = [] 216 | #test_gold = latent_tree.build_latent_trees(test_chunk_insts) 217 | test_gold = [] 218 | for inst in test_chunk_insts: 219 | chunks = util.inst2chunks(inst) 220 | test_gold.append(chunks) 221 | 222 | ftreelog = open(args.expname + '.test.predtree.txt', 'w', encoding='utf-8') 223 | 224 | for x, chunks in test_chunk_insts: 225 | dy.renew_cg() 226 | #sentence = [(leaf.tag, leaf.word) for leaf in tree.leaves()] 227 | sentence = [(parse.XX, ch) for ch in x] 228 | predicted, _ = parser.parse(sentence) 229 | pred_tree = predicted.convert() 230 | ftreelog.write(pred_tree.linearize() + '\n') 231 | test_predicted.append(pred_tree.to_chunks()) 232 | 233 | 234 | 235 | ftreelog.close() 236 | 237 | #test_fscore = evaluate.evalb(args.evalb_dir, test_chunk_insts, test_predicted, args.expname + '.test.') 238 | test_fscore = evaluate.eval_chunks2(args.evalb_dir, test_gold, test_predicted, output_filename=args.expname + '.test.txt') # evalb 239 | 240 | print( 241 | "epoch {:,} " 242 | "test-fscore {} " 243 | "test-elapsed {} " 244 | "total-elapsed {}".format( 245 | epoch, 246 | test_fscore, 247 | format_elapsed(test_start_time), 248 | format_elapsed(start_time), 249 | ) 250 | ) 251 | 252 | 253 | train_trees = latent_tree.build_dynamicRBT_trees(train_chunk_insts) 254 | train_trees = [(x, tree.convert(), chunks, latentscope) for x, tree, chunks, latentscope in train_trees] 255 | 256 | for epoch in itertools.count(start=1): 257 | if args.epochs is not None and epoch > args.epochs: 258 | break 259 | 260 | np.random.shuffle(train_chunk_insts) 261 | epoch_start_time = time.time() 262 | 263 | for start_index in range(0, len(train_chunk_insts), args.batch_size): 264 | dy.renew_cg() 265 | batch_losses = [] 266 | 267 | 268 | for x, tree, chunks, latentscope in train_trees[start_index:start_index + args.batch_size]: 269 | 270 | discard = False 271 | for chunk in chunks: 272 | length = chunk[2] - chunk[1] 273 | if length > args.maxllimit: 274 | discard = True 275 | break 276 | 277 | if discard: 278 | continue 279 | print('discard') 280 | 281 | 282 | sentence = [(parse.XX, ch) for ch in x] 283 | if args.parser_type == "top-down": 284 | _, loss = parser.parse(sentence, tree, args.explore) 285 | else: 286 | _, loss = parser.parse(sentence, tree, chunks, latentscope) 287 | batch_losses.append(loss) 288 | total_processed += 1 289 | current_processed += 1 290 | 291 | 292 | batch_loss = dy.average(batch_losses) 293 | batch_loss_value = batch_loss.scalar_value() 294 | batch_loss.backward() 295 | trainer.update() 296 | 297 | print( 298 | "Epoch {:,} " 299 | "batch {:,}/{:,} " 300 | "processed {:,} " 301 | "batch-loss {:.4f} " 302 | "epoch-elapsed {} " 303 | "total-elapsed {}".format( 304 | epoch, 305 | start_index // args.batch_size + 1, 306 | int(np.ceil(len(train_chunk_insts) / args.batch_size)), 307 | total_processed, 308 | batch_loss_value, 309 | format_elapsed(epoch_start_time), 310 | format_elapsed(start_time), 311 | ), flush=True 312 | ) 313 | 314 | if current_processed >= check_every: 315 | current_processed -= check_every 316 | if epoch > 7: 317 | check_dev() 318 | 319 | 320 | def run_test(args): 321 | #args.test_path = args.test_path.replace('[*]', args.treetype) 322 | print("Loading test trees from {}...".format(args.test_path)) 323 | test_treebank = trees.load_trees(args.test_path, args.normal) 324 | print("Loaded {:,} test examples.".format(len(test_treebank))) 325 | 326 | print("Loading model from 
{}...".format(args.model_path_base)) 327 | model = dy.ParameterCollection() 328 | [parser] = dy.load(args.model_path_base, model) 329 | 330 | label_vocab = vocabulary.Vocabulary() 331 | 332 | label_list = util.load_label_list('../data/labels.txt') 333 | for item in label_list: 334 | label_vocab.index((item, )) 335 | label_vocab.index((parse.EMPTY,)) 336 | for item in label_list: 337 | label_vocab.index((item + "'",)) 338 | 339 | label_vocab.freeze() 340 | latent_tree = latent.latent_tree_builder(label_vocab, args.RBTlabel) 341 | 342 | 343 | print("Parsing test sentences...") 344 | 345 | start_time = time.time() 346 | 347 | test_predicted = [] 348 | test_gold = latent_tree.build_latent_trees(test_treebank) 349 | for x, chunks in test_treebank: 350 | dy.renew_cg() 351 | #sentence = [(leaf.tag, leaf.word) for leaf in tree.leaves()] 352 | sentence = [(parse.XX, ch) for ch in x] 353 | predicted, _ = parser.parse(sentence) 354 | test_predicted.append(predicted.convert()) 355 | 356 | #test_fscore = evaluate.evalb(args.evalb_dir, test_treebank, test_predicted, args.expname + '.test.') 357 | test_fscore = evaluate.eval_chunks(args.evalb_dir, test_gold, test_predicted, output_filename=args.expname + '.finaltest.txt') # evalb 358 | print( 359 | "test-fscore {} " 360 | "test-elapsed {}".format( 361 | test_fscore, 362 | format_elapsed(start_time), 363 | ) 364 | ) 365 | 366 | def main(): 367 | dynet_args = [ 368 | "--dynet-mem", 369 | "--dynet-weight-decay", 370 | "--dynet-autobatch", 371 | "--dynet-gpus", 372 | "--dynet-gpu", 373 | "--dynet-devices", 374 | "--dynet-seed", 375 | ] 376 | 377 | parser = argparse.ArgumentParser() 378 | subparsers = parser.add_subparsers() 379 | 380 | subparser = subparsers.add_parser("train") 381 | subparser.set_defaults(callback=run_train) 382 | for arg in dynet_args: 383 | subparser.add_argument(arg) 384 | subparser.add_argument("--numpy-seed", type=int) 385 | subparser.add_argument("--parser-type", choices=["top-down", "chartdyRBT", "chartdyRBTchunk", "chartdyRBTC", "chartdyRBTCseg"], required=True) 386 | subparser.add_argument("--tag-embedding-dim", type=int, default=50) 387 | subparser.add_argument("--word-embedding-dim", type=int, default=100) 388 | subparser.add_argument("--lstm-layers", type=int, default=2) 389 | subparser.add_argument("--lstm-dim", type=int, default=250) 390 | subparser.add_argument("--label-hidden-dim", type=int, default=250) 391 | subparser.add_argument("--split-hidden-dim", type=int, default=250) 392 | subparser.add_argument("--dropout", type=float, default=0.4) 393 | subparser.add_argument("--explore", action="store_true") 394 | subparser.add_argument("--model-path-base", required=True) 395 | subparser.add_argument("--evalb-dir", default="EVALB/") 396 | subparser.add_argument("--train-path", default="data/train.txt") 397 | subparser.add_argument("--dev-path", default="data/dev.txt") 398 | subparser.add_argument("--labellist-path", default="data/labels.txt") 399 | subparser.add_argument("--batch-size", type=int, default=10) 400 | subparser.add_argument("--epochs", type=int) 401 | subparser.add_argument("--checks-per-epoch", type=int, default=4) 402 | subparser.add_argument("--print-vocabs", action="store_true") 403 | subparser.add_argument("--test-path", default="data/test.txt") 404 | subparser.add_argument("--pretrainemb", default="giga") 405 | subparser.add_argument("--treetype", default="NRBT") 406 | subparser.add_argument("--expname", default="default") 407 | subparser.add_argument("--chunkencoding", type=int, default=1) 408 | 
subparser.add_argument("--trial", type=int, default=0) 409 | subparser.add_argument("--normal", type=int, default=1) 410 | subparser.add_argument("--RBTlabel", type=str, default="city") 411 | subparser.add_argument("--nontlabelstyle", type=int, default=0) 412 | subparser.add_argument("--zerocostchunk", type=int, default=0) 413 | subparser.add_argument("--loadmodel", type=str, default="none") 414 | subparser.add_argument("--trainc", type=int, default=1) 415 | subparser.add_argument("--maxllimit", type=int, default=38) 416 | 417 | 418 | subparser = subparsers.add_parser("test") 419 | subparser.set_defaults(callback=run_test) 420 | for arg in dynet_args: 421 | subparser.add_argument(arg) 422 | subparser.add_argument("--model-path-base", required=True) 423 | subparser.add_argument("--evalb-dir", default="EVALB/") 424 | subparser.add_argument("--treetype", default="NRBT") 425 | subparser.add_argument("--test-path", default="data/test.txt") 426 | subparser.add_argument("--expname", default="default") 427 | subparser.add_argument("--normal", type=int, default=1) 428 | 429 | 430 | args = parser.parse_args() 431 | args.callback(args) 432 | 433 | if __name__ == "__main__": 434 | main() 435 | -------------------------------------------------------------------------------- /src/parse.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import dynet as dy 4 | import numpy as np 5 | 6 | import trees 7 | import util 8 | 9 | 10 | START = "" 11 | STOP = "" 12 | UNK = "" 13 | NUM = "" 14 | XX = "XX" 15 | EMPTY = "" 16 | 17 | 18 | def augment(scores, oracle_index): 19 | assert isinstance(scores, dy.Expression) 20 | shape = scores.dim()[0] 21 | assert len(shape) == 1 22 | increment = np.ones(shape) 23 | increment[oracle_index] = 0 24 | return scores + dy.inputVector(increment) 25 | 26 | 27 | 28 | class Feedforward(object): 29 | def __init__(self, model, input_dim, hidden_dims, output_dim): 30 | self.spec = locals() 31 | self.spec.pop("self") 32 | self.spec.pop("model") 33 | 34 | self.model = model.add_subcollection("Feedforward") 35 | 36 | self.weights = [] 37 | self.biases = [] 38 | dims = [input_dim] + hidden_dims + [output_dim] 39 | for prev_dim, next_dim in zip(dims, dims[1:]): 40 | self.weights.append(self.model.add_parameters((next_dim, prev_dim))) 41 | self.biases.append(self.model.add_parameters(next_dim)) 42 | 43 | def param_collection(self): 44 | return self.model 45 | 46 | @classmethod 47 | def from_spec(cls, spec, model): 48 | return cls(model, **spec) 49 | 50 | def __call__(self, x): 51 | for i, (weight, bias) in enumerate(zip(self.weights, self.biases)): 52 | weight = dy.parameter(weight) 53 | bias = dy.parameter(bias) 54 | x = dy.affine_transform([bias, weight, x]) 55 | if i < len(self.weights) - 1: 56 | x = dy.rectify(x) 57 | return x 58 | 59 | 60 | class ChartDynamicRBTConstraintParser(object): 61 | def __init__( 62 | self, 63 | model, 64 | tag_vocab, 65 | word_vocab, 66 | label_vocab, 67 | tag_embedding_dim, 68 | word_embedding_dim, 69 | lstm_layers, 70 | lstm_dim, 71 | label_hidden_dim, 72 | dropout, 73 | pretrainemb, 74 | chunk_encoding = 1, 75 | train_constraint = True, 76 | decode_constraint = True, 77 | zerocostchunk = 0, 78 | nontlabelstyle = 0, 79 | ): 80 | self.spec = locals() 81 | self.spec.pop("self") 82 | self.spec.pop("model") 83 | 84 | self.model = model.add_subcollection("Parser") 85 | self.tag_vocab = tag_vocab 86 | self.word_vocab = word_vocab 87 | self.label_vocab = label_vocab 88 | self.lstm_dim = lstm_dim 89 | 90 
| self.tag_embeddings = self.model.add_lookup_parameters( 91 | (tag_vocab.size, tag_embedding_dim)) 92 | self.word_embeddings = self.model.add_lookup_parameters( 93 | (word_vocab.size, word_embedding_dim)) 94 | 95 | if pretrainemb[0] != 'none': 96 | print('Init Lookup table with pretrain emb') 97 | self.word_embeddings.init_from_array(pretrainemb[1]) 98 | 99 | self.lstm = dy.BiRNNBuilder( 100 | lstm_layers, 101 | tag_embedding_dim + word_embedding_dim, 102 | 2 * lstm_dim, 103 | self.model, 104 | dy.VanillaLSTMBuilder) 105 | 106 | self.CompFwdRNN = dy.LSTMBuilder(lstm_layers, 2 * lstm_dim, 2 * lstm_dim, model) 107 | self.CompBwdRNN = dy.LSTMBuilder(lstm_layers, 2 * lstm_dim, 2 * lstm_dim, model) 108 | 109 | self.pW_comp = model.add_parameters((lstm_dim * 2, lstm_dim * 4)) 110 | self.pb_comp = model.add_parameters((lstm_dim * 2,)) 111 | 112 | 113 | self.f_label = Feedforward( 114 | self.model, 2 * lstm_dim, [label_hidden_dim], label_vocab.size - 1) 115 | 116 | self.dropout = dropout 117 | 118 | self.chunk_encoding = chunk_encoding 119 | 120 | self.train_constraint = train_constraint 121 | self.decode_constraint = decode_constraint 122 | self.zerocostchunk = zerocostchunk 123 | self.nontlabelstyle = nontlabelstyle 124 | 125 | def param_collection(self): 126 | return self.model 127 | 128 | @classmethod 129 | def from_spec(cls, spec, model): 130 | return cls(model, **spec) 131 | 132 | 133 | 134 | 135 | def parse(self, sentence, gold=None, gold_chunks=None, latentscope=None): 136 | is_train = gold is not None 137 | 138 | if is_train: 139 | self.lstm.set_dropout(self.dropout) 140 | else: 141 | self.lstm.disable_dropout() 142 | 143 | embeddings = [] 144 | for tag, word in [(START, START)] + sentence + [(STOP, STOP)]: 145 | tag_embedding = self.tag_embeddings[self.tag_vocab.index(tag)] 146 | if word not in (START, STOP): 147 | count = self.word_vocab.count(word) 148 | if not count or (is_train and np.random.rand() < 1 / (1 + count)): 149 | word = UNK 150 | word_embedding = self.word_embeddings[self.word_vocab.index(word)] 151 | embeddings.append(dy.concatenate([tag_embedding, word_embedding])) 152 | 153 | lstm_outputs = self.lstm.transduce(embeddings) 154 | 155 | W_comp = dy.parameter(self.pW_comp) 156 | b_comp = dy.parameter(self.pb_comp) 157 | 158 | @functools.lru_cache(maxsize=None) 159 | def get_span_encoding(left, right): 160 | if self.chunk_encoding == 2: 161 | return get_span_encoding_chunk(left, right) 162 | forward = ( 163 | lstm_outputs[right][:self.lstm_dim] - 164 | lstm_outputs[left][:self.lstm_dim]) 165 | backward = ( 166 | lstm_outputs[left + 1][self.lstm_dim:] - 167 | lstm_outputs[right + 1][self.lstm_dim:]) 168 | 169 | 170 | bi = dy.concatenate([forward, backward]) 171 | 172 | return bi 173 | 174 | @functools.lru_cache(maxsize=None) 175 | def get_span_encoding_chunk(left, right): 176 | 177 | fw_init = self.CompFwdRNN.initial_state() 178 | bw_init = self.CompBwdRNN.initial_state() 179 | 180 | fwd_exp = fw_init.transduce(lstm_outputs[left:right]) 181 | bwd_exp = bw_init.transduce(reversed(lstm_outputs[left:right])) 182 | 183 | bi = dy.concatenate([fwd_exp[-1], bwd_exp[-1]]) 184 | chunk_rep = dy.rectify(dy.affine_transform([b_comp, W_comp, bi])) 185 | 186 | return chunk_rep 187 | 188 | @functools.lru_cache(maxsize=None) 189 | def get_label_scores(left, right): 190 | non_empty_label_scores = self.f_label(get_span_encoding(left, right)) 191 | return dy.concatenate([dy.zeros(1), non_empty_label_scores]) 192 | 193 | 194 | 195 | def helper(force_gold): 196 | if force_gold: 197 | assert 
is_train 198 | 199 | chart = {} 200 | label_scores_span_max = {} 201 | 202 | for length in range(1, len(sentence) + 1): 203 | for left in range(0, len(sentence) + 1 - length): 204 | right = left + length 205 | 206 | label_scores = get_label_scores(left, right) 207 | 208 | if is_train: 209 | oracle_label = gold.oracle_label(left, right) 210 | oracle_label_index = self.label_vocab.index(oracle_label) 211 | 212 | 213 | 214 | if force_gold: 215 | label = oracle_label 216 | label_index = oracle_label_index 217 | label_score = label_scores[label_index] 218 | 219 | if self.nontlabelstyle == 3: 220 | label_scores_np = label_scores.npvalue() 221 | argmax_label_index = int( 222 | label_scores_np.argmax() if length < len(sentence) else 223 | label_scores_np[1:].argmax() + 1) 224 | argmax_label = self.label_vocab.value(argmax_label_index) 225 | label = argmax_label 226 | label_score = label_scores[argmax_label_index] 227 | 228 | else: 229 | if is_train: 230 | label_scores = augment(label_scores, oracle_label_index) 231 | label_scores_np = label_scores.npvalue() 232 | #argmax_score = dy.argmax(label_scores, gradient_mode="straight_through_gradient") 233 | #dy.dot_product() 234 | argmax_label_index = int( 235 | label_scores_np.argmax() if length < len(sentence) else 236 | label_scores_np[1:].argmax() + 1) 237 | argmax_label = self.label_vocab.value(argmax_label_index) 238 | label = argmax_label 239 | label_score = label_scores[argmax_label_index] 240 | 241 | if length == 1: 242 | tag, word = sentence[left] 243 | tree = trees.LeafParseNode(left, tag, word) 244 | if label: 245 | tree = trees.InternalParseNode(label, [tree]) 246 | chart[left, right] = [tree], label_score 247 | label_scores_span_max[left, right] = [tree], label_score 248 | continue 249 | 250 | if force_gold: 251 | oracle_splits = gold.oracle_splits(left, right) 252 | 253 | if (len(label) > 0 and (label[0].endswith("'") or label[0] == EMPTY)) and latentscope[0] <= left <= right <= latentscope[1]: # if label == (EMPTY,) and latentscope[0] <= left <= right <= latentscope[1]: # and label != (): 254 | # Latent during Training 255 | #if self.train_constraint: 256 | 257 | # if self.train_constraint: 258 | # oracle_splits = [(oracle_splits[0], 0), (oracle_splits[-1], 1)] 259 | # else: 260 | # #oracle_splits = [(p, 0) for p in oracle_splits] + [(p, 1) for p in oracle_splits] # it is not correct 261 | # pass 262 | 263 | oracle_splits = [(oracle_splits[0], 0), (oracle_splits[-1], 1)] 264 | best_split = max(oracle_splits, 265 | key=lambda sb : #(split, branching) #branching == 0: right branching; 1: left branching 266 | label_scores_span_max[left, sb[0]][1].value() + chart[sb[0], right][1].value() if sb[1] == 0 else 267 | chart[left, sb[0]][1].value() + label_scores_span_max[sb[0], right][1].value() 268 | ) 269 | else: 270 | 271 | best_split = (min(oracle_splits), 0) #by default right braching 272 | 273 | 274 | else: 275 | pred_range = range(left + 1, right) 276 | pred_splits = [(p, 0) for p in pred_range] + [(p, 1) for p in pred_range] 277 | best_split = max(pred_splits, 278 | key=lambda sb: # (split, branching) #branching == 0: right branching; 1: left branching 279 | label_scores_span_max[left, sb[0]][1].value() + chart[sb[0], right][1].value() if sb[1] == 0 else 280 | chart[left, sb[0]][1].value() + label_scores_span_max[sb[0], right][1].value() 281 | ) 282 | 283 | children_leaf = [trees.LeafParseNode(pos, sentence[pos][0], sentence[pos][1]) for pos in range(left, right)] 284 | label_scores_span_max[left, right] = children_leaf, label_score 285 | 
286 | 287 | 288 | if best_split[1] == 0:#Right Branching 289 | left_trees, left_score = label_scores_span_max[left, best_split[0]] 290 | right_trees, right_score = chart[best_split[0], right] 291 | else:#Left Branching 292 | left_trees, left_score = chart[left, best_split[0]] 293 | right_trees, right_score = label_scores_span_max[best_split[0], right] 294 | 295 | 296 | children = left_trees + right_trees 297 | 298 | if label: 299 | children = [trees.InternalParseNode(label, children)] 300 | if not label[0].endswith("'"): 301 | children_leaf = [trees.InternalParseNode(label, children_leaf)] 302 | label_scores_span_max[left, right] = children_leaf, label_score 303 | 304 | chart[left, right] = (children, label_score + left_score + right_score) 305 | 306 | 307 | 308 | children, score = chart[0, len(sentence)] 309 | assert len(children) == 1 310 | return children[0], score 311 | 312 | tree, score = helper(False) 313 | if is_train: 314 | oracle_tree, oracle_score = helper(True) 315 | #assert oracle_tree.convert().linearize() == gold.convert().linearize() 316 | #correct = tree.convert().linearize() == gold.convert().linearize() 317 | 318 | if self.zerocostchunk: 319 | pred_chunks = tree.convert().to_chunks() 320 | correct = (gold_chunks == pred_chunks) 321 | else: 322 | correct = False #(gold_chunks == pred_chunks) 323 | loss = dy.zeros(1) if correct else score - oracle_score 324 | #loss = score - oracle_score 325 | return tree, loss 326 | else: 327 | return tree, score 328 | 329 | -------------------------------------------------------------------------------- /src/trees.py: -------------------------------------------------------------------------------- 1 | import collections.abc 2 | import parse 3 | import util 4 | 5 | class TreebankNode(object): 6 | pass 7 | 8 | class InternalTreebankNode(TreebankNode): 9 | def __init__(self, label, children): 10 | assert isinstance(label, str) 11 | self.label = label 12 | 13 | assert isinstance(children, collections.abc.Sequence) 14 | assert all(isinstance(child, TreebankNode) for child in children) 15 | #assert children 16 | self.children = tuple(children) 17 | 18 | def to_chunks(self): 19 | raw_chunks = [] 20 | self.chunk_helper(raw_chunks) 21 | 22 | chunks = [] 23 | p = 0 24 | for label, text_list in raw_chunks: 25 | if label.endswith("'"): 26 | label = label[:-1] 27 | chunks.append((label, p, p + len(text_list), text_list)) 28 | p += len(text_list) 29 | 30 | return chunks 31 | 32 | 33 | def chunk_helper(self, chunks): 34 | 35 | children_status = [isinstance(child, LeafTreebankNode) for child in self.children] 36 | 37 | if all(children_status): 38 | chunks.append((self.label, [child.word for child in self.children])) 39 | elif any(children_status): 40 | char_list = [] 41 | 42 | for child in self.children: 43 | if isinstance(child, InternalTreebankNode): 44 | char_list += child.get_word_list() 45 | else: 46 | char_list.append(child.word) 47 | 48 | 49 | chunk = (self.label, char_list) 50 | chunks.append(chunk) 51 | else: 52 | for child in self.children: 53 | if isinstance(child, InternalTreebankNode): 54 | child.chunk_helper(chunks) 55 | 56 | 57 | def get_word_list(self): 58 | word_list = [] 59 | for child in self.children: 60 | if isinstance(child, InternalTreebankNode): 61 | word_list += child.get_word_list() 62 | else: 63 | word_list += [child.word] 64 | return word_list 65 | 66 | def linearize(self): 67 | return "({} {})".format( 68 | self.label, " ".join(child.linearize() for child in self.children)) 69 | 70 | def leaves(self): 71 | for child in 
self.children: 72 | yield from child.leaves() 73 | 74 | def convert(self, index=0): 75 | tree = self 76 | sublabels = [self.label] 77 | 78 | while len(tree.children) == 1 and isinstance(tree.children[0], InternalTreebankNode): 79 | tree = tree.children[0] 80 | sublabels.append(tree.label) 81 | 82 | children = [] 83 | for child in tree.children: 84 | children.append(child.convert(index=index)) 85 | index = children[-1].right 86 | 87 | return InternalParseNode(tuple(sublabels), children) 88 | 89 | 90 | 91 | class InternalTreebankChunkNode(InternalTreebankNode): 92 | def __init__(self, label, children): 93 | super(InternalTreebankChunkNode, self).__init__(label, children) 94 | self.is_chunk_node = True 95 | 96 | def convert(self, index=0): 97 | tree = self 98 | sublabels = [self.label] 99 | 100 | while len(tree.children) == 1 and isinstance(tree.children[0], InternalTreebankNode): 101 | tree = tree.children[0] 102 | sublabels.append(tree.label) 103 | 104 | children = [] 105 | for child in tree.children: 106 | children.append(child.convert(index=index)) 107 | index = children[-1].right 108 | 109 | return InternalParseChunkNode(tuple(sublabels), children) 110 | 111 | 112 | class LeafTreebankNode(TreebankNode): 113 | def __init__(self, tag, word): 114 | assert isinstance(tag, str) 115 | self.tag = tag 116 | 117 | assert isinstance(word, str) 118 | self.word = word 119 | 120 | def linearize(self): 121 | return "({} {})".format(self.tag, self.word) 122 | 123 | def leaves(self): 124 | yield self 125 | 126 | def convert(self, index=0): 127 | return LeafParseNode(index, self.tag, self.word) 128 | 129 | 130 | 131 | class InternalUncompletedTreebankNode(TreebankNode): 132 | def __init__(self, label, chunkleaves, chunks_in_scope, latent): 133 | assert isinstance(label, str) 134 | self.label = label 135 | self.latent = latent 136 | #assert isinstance(children, collections.abc.Sequence) 137 | #assert all(isinstance(child, TreebankNode) for child in children) 138 | #assert children 139 | self.children = () #tuple(children) 140 | self.chunkleaves = chunkleaves 141 | self.chunks_in_scope = chunks_in_scope 142 | 143 | 144 | def linearize(self): 145 | 146 | #chunks_str_list = [ '(' + label + ' ' + ''.join( '(XX ' + item + ')' for item in text) + ')' for label, s, e, text in self.chunks_in_scope] 147 | return "({} leaves: {})".format(self.label, " ".join(leaf.linearize() for leaf in self.chunkleaves)) 148 | 149 | def leaves(self): 150 | return self.chunkleaves 151 | 152 | def convert(self, index=0): 153 | tree = self 154 | sublabels = [self.label] 155 | 156 | # while len(tree.children) == 1 and isinstance(tree.children[0], InternalTreebankNode): 157 | # tree = tree.children[0] 158 | # sublabels.append(tree.label) 159 | 160 | #children = [] 161 | # for child in tree.children: 162 | # children.append(child.convert(index=index)) 163 | # index = children[-1].right 164 | #index = self.chunks_in_scope[0][1] 165 | chunkleaves = [] 166 | for chunkleaf in self.chunkleaves: 167 | chunkleaves.append(chunkleaf.convert(index=index)) 168 | index = chunkleaves[-1].right 169 | # for i in range(len(self.chunkleaves)): 170 | # chunk = self.chunks_in_scope[i] 171 | # chunkleaf = self.chunkleaves[i] 172 | # chunkleaf.convert(index=chunk[1]) 173 | # chunkleaves.append(chunkleaf) 174 | 175 | return InternalUncompletedParseNode(tuple(sublabels), chunkleaves, self.chunks_in_scope, self.latent) 176 | 177 | 178 | 179 | class ParseNode(object): 180 | pass 181 | 182 | class InternalParseNode(ParseNode): 183 | def __init__(self, label, 
children): 184 | assert isinstance(label, tuple) 185 | assert all(isinstance(sublabel, str) for sublabel in label) 186 | assert label 187 | self.label = label 188 | 189 | assert isinstance(children, collections.abc.Sequence) 190 | assert all(isinstance(child, ParseNode) for child in children) 191 | assert children 192 | assert len(children) > 1 or isinstance(children[0], LeafParseNode) 193 | assert all( 194 | left.right == right.left 195 | for left, right in zip(children, children[1:])) 196 | self.children = tuple(children) 197 | 198 | self.left = children[0].left 199 | self.right = children[-1].right 200 | 201 | def leaves(self): 202 | for child in self.children: 203 | yield from child.leaves() 204 | 205 | def convert(self): 206 | children = [child.convert() for child in self.children] 207 | tree = InternalTreebankNode(self.label[-1], children) 208 | for sublabel in reversed(self.label[:-1]): 209 | tree = InternalTreebankNode(sublabel, [tree]) 210 | return tree 211 | 212 | def enclosing(self, left, right): 213 | assert self.left <= left < right <= self.right 214 | for child in self.children: 215 | if isinstance(child, LeafParseNode): 216 | continue 217 | if child.left <= left < right <= child.right: 218 | return child.enclosing(left, right) 219 | return self 220 | 221 | def oracle_label(self, left, right): 222 | enclosing = self.enclosing(left, right) 223 | if enclosing.left == left and enclosing.right == right: 224 | return enclosing.label 225 | return () 226 | 227 | def oracle_splits(self, left, right): 228 | # return [ 229 | # child.left 230 | # for child in self.enclosing(left, right).children 231 | # if left < child.left < right 232 | # ] 233 | enclosing = self.enclosing(left, right) 234 | return [ 235 | child.left 236 | for child in enclosing.children 237 | if left < child.left < right 238 | ] 239 | # if isinstance(enclosing, InternalUncompletedParseNode): 240 | # return [ 241 | # child.left 242 | # for child in enclosing.chunkleaves 243 | # if left < child.left < right 244 | # ] 245 | # else: 246 | # return [ 247 | # child.left 248 | # for child in enclosing.children 249 | # if left < child.left < right 250 | # ] 251 | 252 | 253 | def oracle_splits2(self, left, right): 254 | # return [ 255 | # child.left 256 | # for child in self.enclosing(left, right).children 257 | # if left < child.left < right 258 | # ] 259 | 260 | enclosing = self.enclosing(left, right) 261 | if enclosing.left == left and enclosing.right == right: 262 | splits = [child.left for child in enclosing.children if not isinstance(child, LeafTreebankNode)] 263 | splits = splits[1:] 264 | return splits 265 | else: 266 | return [] 267 | 268 | 269 | # if isinstance(self, InternalParseChunkNode): 270 | # return [] 271 | # elif isinstance(self, InternalUncompletedParseNode): 272 | # return self.oracle_splits(left, right) 273 | # else: 274 | # #enclosing = self.enclosing(left, right) 275 | # if self.left == left and self.right == right: 276 | # splits = [child.left for child in self.children if not isinstance(child, LeafTreebankNode)] 277 | # splits = splits[1:] 278 | # return splits 279 | # 280 | # 281 | # for child in self.children: 282 | # if not isinstance(child, LeafParseNode): 283 | # if isinstance(child, InternalUncompletedParseNode): 284 | # if child.left <= left <= right <= child.right: 285 | # return child.oracle_splits(left, right) 286 | # else: 287 | # if child.left == left and child.right == right: 288 | # return child.oracle_splits2(left, right) 289 | # 290 | # return [] 291 | 292 | 293 | 294 | class 
InternalParseChunkNode(InternalParseNode): 295 | def __init__(self, label, children): 296 | super(InternalParseChunkNode, self).__init__(label, children) 297 | self.is_chunk_node = True 298 | self._chunknode = None 299 | 300 | def convert(self): 301 | children = [child.convert() for child in self.children] 302 | tree = InternalTreebankChunkNode(self.label[-1], children) 303 | for sublabel in reversed(self.label[:-1]): 304 | tree = InternalTreebankChunkNode(sublabel, [tree]) 305 | return tree 306 | 307 | def enclosing(self, left, right): 308 | assert self.left <= left < right <= self.right 309 | for child in self.children: 310 | if isinstance(child, LeafParseNode): 311 | continue 312 | if child.left <= left < right <= child.right: 313 | return child.enclosing(left, right) 314 | 315 | if self._chunknode is None: 316 | self._chunknode = InternalParseChunkNode(self.label, self.children) 317 | self._chunknode.children = [] 318 | return self._chunknode 319 | 320 | def oracle_splits(self, left, right): 321 | # return [ 322 | # child.left 323 | # for child in self.enclosing(left, right).children 324 | # if left < child.left < right 325 | # ] 326 | enclosing = self.enclosing(left, right) 327 | return [ 328 | child.left 329 | for child in enclosing.children 330 | if left < child.left < right 331 | ] 332 | 333 | 334 | class InternalUncompletedParseNode(InternalParseNode): 335 | def __init__(self, label, chunkleaves, chunks_in_scope:[], latent): 336 | assert isinstance(label, tuple) 337 | assert all(isinstance(sublabel, str) for sublabel in label) 338 | assert label 339 | self.label = label 340 | self.latent = latent 341 | 342 | #assert isinstance(children, collections.abc.Sequence) 343 | #assert all(isinstance(child, ParseNode) for child in children) 344 | #assert children 345 | #assert len(children) > 1 or isinstance(children[0], LeafParseNode) 346 | # assert all( 347 | # left.right == right.left 348 | # for left, right in zip(children, children[1:])) 349 | self.children = [] #tuple(children) 350 | self.chunkleaves = chunkleaves 351 | 352 | #self.left = children[0].left 353 | #self.right = children[-1].right 354 | self.left = chunkleaves[0].left #chunks_in_scope[0][1] 355 | self.right = chunkleaves[-1].right #chunks_in_scope[-1][2] 356 | self.chunks_in_scope = chunks_in_scope 357 | self.splits = [chunk[1] for chunk in self.chunks_in_scope] + [self.chunks_in_scope[-1][2]] 358 | 359 | def leaves(self): 360 | return self.chunkleaves 361 | 362 | def convert(self): 363 | # children = [child.convert() for child in self.children] 364 | # tree = InternalUncompletedTreebankNode(self.label[-1], children) 365 | # for sublabel in reversed(self.label[:-1]): 366 | # tree = InternalUncompletedTreebankNode(sublabel, [tree]) 367 | 368 | chunkleaves = [chunkleaf.convert() for chunkleaf in self.chunkleaves] 369 | tree = InternalUncompletedTreebankNode(self.label[-1], chunkleaves, (self.left, self.right), self.latent) 370 | return tree 371 | 372 | def enclosing(self, left, right): 373 | assert self.left <= left < right <= self.right 374 | for chunkleaf in self.chunkleaves: 375 | if isinstance(chunkleaf, LeafParseNode): 376 | continue 377 | if chunkleaf.left <= left < right <= chunkleaf.right: 378 | return chunkleaf.enclosing(left, right) 379 | 380 | # if left in self.splits and right in self.splits: 381 | # children = [chunkleaf for chunkleaf in self.chunkleaves if left <= chunkleaf.left and chunkleaf.right <= right] 382 | # # label = max([child.label for child in children],key=lambda 
l:self.latent.get_label_order(l[0])) 383 | # # label = (self.latent.non_terminal_label(label[0]),) 384 | # label = self.label 385 | # return InternalParseNode(label, children) 386 | 387 | children = [chunkleaf for chunkleaf in self.chunkleaves if left < chunkleaf.right] 388 | children = [chunkleaf for chunkleaf in children if right > chunkleaf.left] 389 | 390 | if self.latent.non_terminal_label_mode == 0 or self.latent.non_terminal_label_mode == 3: 391 | label = max([child.label for child in children], key=lambda l: self.latent.get_label_order(l[0])) 392 | label = (self.latent.non_terminal_label(label[0]),) 393 | elif self.latent.non_terminal_label_mode == 1: 394 | label = self.label 395 | else: #self.latent.non_terminal_label_mode == 2: 396 | import random 397 | label_id = random.randint(1 + self.latent.label_size + 0, 1 + self.latent.label_size + self.latent.label_size - 1) 398 | label = self.latent.label_vocab.value(label_id) 399 | return InternalParseNode(label, children) 400 | 401 | # self.children = self.chunkleaves 402 | # return self 403 | 404 | def oracle_label(self, left, right): 405 | # enclosing = self.enclosing(left, right) 406 | # if enclosing.left == left and enclosing.right == right: 407 | # return enclosing.label 408 | 409 | for chunk in self.chunks_in_scope: 410 | if chunk[1] == left and chunk[2] == right: 411 | return (chunk[0],) 412 | 413 | 414 | if left in self.splits and right in self.splits: 415 | return self.label 416 | 417 | return () 418 | 419 | def oracle_splits(self, left, right): 420 | 421 | 422 | ret = [p for p in self.splits if left < p and p < right] 423 | if len(ret) == 0: 424 | return [ 425 | child.left 426 | for child in self.enclosing(left, right).children 427 | if left < child.left < right 428 | ] 429 | 430 | 431 | return ret 432 | 433 | 434 | 435 | class LeafParseNode(ParseNode): 436 | def __init__(self, index, tag, word): 437 | assert isinstance(index, int) 438 | assert index >= 0 439 | self.left = index 440 | self.right = index + 1 441 | 442 | assert isinstance(tag, str) 443 | self.tag = tag 444 | 445 | assert isinstance(word, str) 446 | self.word = word 447 | 448 | def leaves(self): 449 | yield self 450 | 451 | def convert(self): 452 | return LeafTreebankNode(self.tag, self.word) 453 | 454 | 455 | def load_trees(path, normal, strip_top=True): 456 | with open(path, 'r', encoding='utf-8') as infile: 457 | tokens = infile.read().replace("(", " ( ").replace(")", " ) ").split() 458 | 459 | def helper(index): 460 | trees = [] 461 | 462 | while index < len(tokens) and tokens[index] == "(": 463 | paren_count = 0 464 | while tokens[index] == "(": 465 | index += 1 466 | paren_count += 1 467 | 468 | label = tokens[index] 469 | index += 1 470 | 471 | if tokens[index] == "(": 472 | children, index = helper(index) 473 | trees.append(InternalTreebankNode(label, children)) 474 | else: 475 | word = tokens[index] 476 | if normal == 1: 477 | newword = '' 478 | for c in word: 479 | if util.is_digit(c): 480 | newword += '0' 481 | else: 482 | newword += c 483 | else: 484 | newword = word 485 | index += 1 486 | trees.append(LeafTreebankNode(label, newword)) 487 | 488 | while paren_count > 0: 489 | assert tokens[index] == ")" 490 | index += 1 491 | paren_count -= 1 492 | 493 | return trees, index 494 | 495 | trees, index = helper(0) 496 | assert index == len(tokens) 497 | 498 | if strip_top: 499 | for i, tree in enumerate(trees): 500 | if tree.label == "TOP": 501 | assert len(tree.children) == 1 502 | trees[i] = tree.children[0] 503 | 504 | return trees 505 | 
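# Round-trip sketch (illustrative only; 'tree_file.txt' is a hypothetical
# path to a file of bracketed trees, the format load_trees expects):
#
#   treebank = load_trees('tree_file.txt', normal=1)
#   parse_tree = treebank[0].convert()   # TreebankNode -> ParseNode with spans
#   assert parse_tree.convert().linearize() == treebank[0].linearize()
#
# convert() collapses unary chains into a tuple of sublabels and attaches
# (left, right) character spans; the reverse convert() re-expands the chain,
# so linearization is preserved for ordinary trees.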
-------------------------------------------------------------------------------- /src/util.py: -------------------------------------------------------------------------------- 1 | import vocabulary 2 | 3 | 4 | 5 | chinese_digit = '一二三四五六七八九十百千万' 6 | 7 | def is_digit(tok:str): 8 | 9 | tok = tok.strip() 10 | 11 | if tok.isdigit(): 12 | return True 13 | 14 | if tok in chinese_digit: 15 | return True 16 | 17 | return False 18 | 19 | 20 | def seq2chunk(seq): 21 | chunks = [] 22 | label = None 23 | last_label = None 24 | start_idx = 0 25 | for i in range(len(seq)): 26 | tok = seq[i] 27 | label = tok[2:] if tok.startswith('B') or tok.startswith('I') else tok 28 | if tok.startswith('B') or last_label != label: 29 | if last_label == None: 30 | start_idx = i 31 | else: 32 | chunks.append((last_label, start_idx, i)) 33 | start_idx = i 34 | 35 | last_label = label 36 | 37 | chunks.append((label, start_idx, len(seq))) 38 | 39 | return chunks 40 | 41 | def read_chunks(filename, normal = 1): 42 | f = open(filename, 'r', encoding='utf-8') 43 | insts = [] 44 | inst = list() 45 | num_inst = 0 46 | max_chunk_length_limit = 36 47 | max_chunk_length = 0 48 | max_char_length = 0 49 | for line in f: 50 | line = line.strip() 51 | if line == "": 52 | if inst != None: 53 | 54 | inst = [tuple(x) for x in inst] 55 | 56 | tmp = list(zip(*inst)) 57 | 58 | x = tmp[0] 59 | new_x = [] 60 | for word in x: 61 | if normal == 1: 62 | newword = '' 63 | for c in word: 64 | if is_digit(c): 65 | newword += '0' 66 | else: 67 | newword += c 68 | else: 69 | newword = word 70 | 71 | new_x.append(newword) 72 | 73 | 74 | y = [x[:2] + x[2:].lower() for x in tmp[1]] 75 | 76 | chunks = seq2chunk(y) 77 | 78 | for chunk in chunks: 79 | if max_chunk_length < chunk[2] - chunk[1]: 80 | max_chunk_length = chunk[2] - chunk[1] 81 | 82 | if max_char_length < len(x): 83 | max_char_length = len(x) 84 | 85 | insts.append((new_x, chunks)) 86 | 87 | inst = list() 88 | else: 89 | inst.append(line.split()) 90 | f.close() 91 | print(filename + 'is loaded.','\t', 'max_chunk_length=',max_chunk_length, '\tmax_char_length:', max_char_length) 92 | return insts 93 | 94 | 95 | def load_trees_from_str(tokens, normal = 1, strip_top=True): 96 | from trees import InternalTreebankNode, LeafTreebankNode 97 | 98 | tokens = tokens.replace("(", " ( ").replace(")", " ) ").split() 99 | 100 | def helper(index): 101 | trees = [] 102 | 103 | while index < len(tokens) and tokens[index] == "(": 104 | paren_count = 0 105 | while tokens[index] == "(": 106 | index += 1 107 | paren_count += 1 108 | 109 | label = tokens[index] 110 | index += 1 111 | 112 | if tokens[index] == "(": 113 | children, index = helper(index) 114 | trees.append(InternalTreebankNode(label, children)) 115 | else: 116 | word = tokens[index] 117 | if normal == 1: 118 | newword = '' 119 | for c in word: 120 | if is_digit(c): 121 | newword += '0' 122 | else: 123 | newword += c 124 | else: 125 | newword = word 126 | index += 1 127 | trees.append(LeafTreebankNode(label, newword)) 128 | 129 | while paren_count > 0: 130 | assert tokens[index] == ")" 131 | index += 1 132 | paren_count -= 1 133 | 134 | return trees, index 135 | 136 | trees, index = helper(0) 137 | assert index == len(tokens) 138 | 139 | if strip_top: 140 | for i, tree in enumerate(trees): 141 | if tree.label == "TOP": 142 | assert len(tree.children) == 1 143 | trees[i] = tree.children[0] 144 | 145 | return trees 146 | 147 | 148 | 149 | def load_label_list(path): 150 | f = open(path, 'r', encoding='utf-8') 151 | label_list = [line.strip() for line in f] 
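    # labels.txt is expected to hold one chunk label per line; the order is
    # preserved here and fixes the label indices built in main_dyRBT.py.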
152 | f.close() 153 | return label_list 154 | 155 | 156 | def load_pretrain(filename : str, WORD_DIM, word_vocab : vocabulary.Vocabulary, saveemb=True): 157 | import numpy as np 158 | import parse 159 | import pickle 160 | 161 | if filename == 'none': 162 | print("Do not use pretrain embedding...") 163 | return None 164 | 165 | print('Loading Pretrained Embedding from ', filename,' ...') 166 | vocab_dic = {} 167 | with open(filename, encoding='utf-8') as f: 168 | for line in f: 169 | s_s = line.split() 170 | if s_s[0] in word_vocab.counts: 171 | vocab_dic[s_s[0]] = np.array([float(x) for x in s_s[1:]]) 172 | # vocab_dic[s_s[0]] = [float(x) for x in s_s[1:]] 173 | 174 | unknowns = np.random.uniform(-0.01, 0.01, WORD_DIM).astype("float32") 175 | numbers = np.random.uniform(-0.01, 0.01, WORD_DIM).astype("float32") 176 | 177 | vocab_dic[parse.UNK] = unknowns 178 | vocab_dic[parse.NUM] = numbers 179 | 180 | 181 | 182 | ret_mat = np.zeros((word_vocab.size, WORD_DIM)) 183 | unk_counter = 0 184 | for token_id in range(word_vocab.size): 185 | token = word_vocab.value(token_id) 186 | if token in vocab_dic: 187 | # ret_mat.append(vocab_dic[token]) 188 | ret_mat[token_id] = vocab_dic[token] 189 | # elif parse.is_digit(token) or token == '': 190 | # ret_mat[token_id] = numbers 191 | else: 192 | # ret_mat.append(unknowns) 193 | ret_mat[token_id] = unknowns 194 | # print "Unknown token:", token 195 | unk_counter += 1 196 | #print('unk:', token) 197 | ret_mat = np.array(ret_mat) 198 | 199 | print('ret_mat shape:', ret_mat.shape) 200 | 201 | if saveemb: 202 | with open('giga.emb', "wb") as f: 203 | pickle.dump(ret_mat, f) 204 | 205 | print("{0} unk out of {1} vocab".format(unk_counter, word_vocab.size)) 206 | print('Glove Embedding is loaded.') 207 | return ret_mat 208 | 209 | 210 | def inst2chunks(inst): 211 | chunks = [] 212 | x, x_chunks = inst 213 | for x_chunk in x_chunks: 214 | chunk = (x_chunk[0], x_chunk[1], x_chunk[2], x[x_chunk[1]:x_chunk[2]]) 215 | chunks.append(chunk) 216 | 217 | return chunks -------------------------------------------------------------------------------- /src/vocabulary.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | class Vocabulary(object): 4 | def __init__(self): 5 | self.frozen = False 6 | self.values = [] 7 | self.indices = {} 8 | self.counts = collections.defaultdict(int) 9 | 10 | @property 11 | def size(self): 12 | return len(self.values) 13 | 14 | def value(self, index): 15 | assert 0 <= index < len(self.values) 16 | return self.values[index] 17 | 18 | def index(self, value): 19 | if not self.frozen: 20 | self.counts[value] += 1 21 | 22 | if value in self.indices: 23 | return self.indices[value] 24 | 25 | elif not self.frozen: 26 | self.values.append(value) 27 | self.indices[value] = len(self.values) - 1 28 | return self.indices[value] 29 | 30 | else: 31 | raise ValueError("Unknown value: {}".format(value)) 32 | 33 | def count(self, value): 34 | return self.counts[value] 35 | 36 | def freeze(self): 37 | self.frozen = True 38 | --------------------------------------------------------------------------------
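A minimal sketch of the `Vocabulary` contract that `main_dyRBT.py` relies on (the strings below are illustrative only):

```python
# Sketch of the index()/count()/freeze() behaviour defined in src/vocabulary.py.
import vocabulary

vocab = vocabulary.Vocabulary()
for ch in "ABAC":
    vocab.index(ch)             # interns new values; counts every call while unfrozen

vocab.freeze()                  # afterwards, unseen values raise ValueError

assert vocab.size == 3                        # values: A, B, C
assert vocab.count("A") == 2                  # "A" was indexed twice before freezing
assert vocab.value(vocab.index("B")) == "B"   # stable value <-> index mapping
```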