├── LICENSE
├── README.md
├── code
│   ├── data.py
│   ├── eval
│   │   ├── __init__.py
│   │   ├── evaluate_moses.py
│   │   ├── evaluate_nist.py
│   │   ├── mteval-v11b.pl
│   │   └── multi-bleu.perl
│   ├── nmt.py
│   ├── sample.py
│   └── train.py
├── data
│   └── newstest2014.en.trans.txt
└── work
    ├── german.py
    ├── run.sh
    └── train.py

/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 | 
3 | Copyright (c) 2017, DeepLearnXMU
4 | All rights reserved.
5 | 
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 | 
9 | * Redistributions of source code must retain the above copyright notice, this
10 |   list of conditions and the following disclaimer.
11 | 
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 |   this list of conditions and the following disclaimer in the documentation
14 |   and/or other materials provided with the distribution.
15 | 
16 | * Neither the name of the copyright holder nor the names of its
17 |   contributors may be used to endorse or promote products derived from
18 |   this software without specific prior written permission.
19 | 
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CAEncoder-NMT
2 | Source code for *A Context-Aware Recurrent Encoder for Neural Machine Translation*. Our model is much faster than the standard encoder-attention-decoder model, and achieves a BLEU score of 22.57 on the English-German translation task, compared with the 20.87 obtained by dl4mt.
3 | 
4 | If you use this code, please cite our paper:
5 | ```
6 | @article{Zhang:2017:CRE:3180104.3180106,
7 |  author = {Zhang, Biao and Xiong, Deyi and Su, Jinsong and Duan, Hong},
8 |  title = {A Context-Aware Recurrent Encoder for Neural Machine Translation},
9 |  journal = {IEEE/ACM Trans. Audio, Speech and Lang. Proc.},
10 |  issue_date = {December 2017},
11 |  volume = {25},
12 |  number = {12},
13 |  month = dec,
14 |  year = {2017},
15 |  issn = {2329-9290},
16 |  pages = {2424--2432},
17 |  numpages = {9},
18 |  url = {https://doi.org/10.1109/TASLP.2017.2751420},
19 |  doi = {10.1109/TASLP.2017.2751420},
20 |  acmid = {3180106},
21 |  publisher = {IEEE Press},
22 |  address = {Piscataway, NJ, USA},
23 | }
24 | ```
25 | 
26 | ## How to Run?
27 | 
28 | A demo case is provided in the `work` directory.
29 | 
30 | ### Training
31 | You need to process your training data and set up a configuration file, as `german.py` does. The `train.py` script is used for training.
32 | 
33 | ### Testing
34 | All you need is the `sample.py` script.
Of course, the directories for the vocabularies and model files are required.
35 | 
36 | 
37 | For any comments or questions, please email Biao Zhang.
38 | 
--------------------------------------------------------------------------------
/code/data.py:
--------------------------------------------------------------------------------
1 | import cPickle as pkl
2 | import gzip
3 | import numpy
4 | 
5 | '''
6 | ID rule: UNK=>1
7 |          <eos>=>0
8 | '''
9 | 
10 | numpy.random.seed(1234)
11 | 
12 | def fopen(filename, mode='r'):
13 |     if filename.endswith('.gz'):
14 |         return gzip.open(filename, mode)
15 |     return open(filename, mode)
16 | 
17 | 
18 | class TextIterator:
19 |     """Simple Bitext iterator."""
20 |     def __init__(self, source, target,
21 |                  source_dict, target_dict,
22 |                  batch_size=128,
23 |                  maxlen=100,
24 |                  n_words_source=-1,
25 |                  n_words_target=-1,
26 |                  shuffle_prob=1):
27 |         self.source = fopen(source, 'r')
28 |         self.target = fopen(target, 'r')
29 |         with open(source_dict, 'rb') as f:
30 |             self.source_dict = pkl.load(f)
31 |         with open(target_dict, 'rb') as f:
32 |             self.target_dict = pkl.load(f)
33 | 
34 |         self.batch_size = batch_size
35 |         self.maxlen = maxlen
36 |         self.shuffle_prob = shuffle_prob
37 | 
38 |         self.n_words_source = n_words_source
39 |         self.n_words_target = n_words_target
40 | 
41 |         self.end_of_data = False
42 | 
43 |     def __iter__(self):
44 |         return self
45 | 
46 |     def reset(self):
47 |         self.source.seek(0)
48 |         self.target.seek(0)
49 | 
50 |     def next(self):
51 |         if self.end_of_data:
52 |             self.end_of_data = False
53 |             self.reset()
54 |             raise StopIteration
55 | 
56 |         source = []
57 |         target = []
58 | 
59 |         try:
60 | 
61 |             # actual work here
62 |             while True:
63 | 
64 |                 # read from source file and map to word index
65 |                 ss = self.source.readline()
66 |                 if ss == "":
67 |                     raise IOError
68 |                 ss = ss.strip().split()
69 |                 ss = [self.source_dict[w] if w in self.source_dict else 1
70 |                       for w in ss]
71 |                 if self.n_words_source > 0:
72 |                     ss = [w if w < self.n_words_source else 1 for w in ss]
73 | 
74 |                 # read from target file and map to word index
75 |                 tt = self.target.readline()
76 |                 if tt == "":
77 |                     raise IOError
78 |                 tt = tt.strip().split()
79 |                 tt = [self.target_dict[w] if w in self.target_dict else 1
80 |                       for w in tt]
81 |                 if self.n_words_target > 0:
82 |                     tt = [w if w < self.n_words_target else 1 for w in tt]
83 | 
84 |                 if len(ss) > self.maxlen and len(tt) > self.maxlen:
85 |                     continue
86 |                 if numpy.random.random() > self.shuffle_prob:
87 |                     continue
88 | 
89 |                 source.append(ss)
90 |                 target.append(tt)
91 | 
92 |                 if len(source) >= self.batch_size or \
93 |                         len(target) >= self.batch_size:
94 |                     break
95 |         except IOError:
96 |             self.end_of_data = True
97 | 
98 |         if len(source) <= 0 or len(target) <= 0:
99 |             self.end_of_data = False
100 |             self.reset()
101 |             raise StopIteration
102 | 
103 |         return source, target
104 | 
--------------------------------------------------------------------------------
/code/eval/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XMUDeepLIT/CAEncoder-NMT/8bb0e3d132c3fa12fb5792ab931268d68a3539ea/code/eval/__init__.py
--------------------------------------------------------------------------------
/code/eval/evaluate_moses.py:
--------------------------------------------------------------------------------
1 | #!
/usr/bin/python 2 | 3 | ''' 4 | Evaluate the translation result from Neural Machine Translation 5 | ''' 6 | 7 | import sys 8 | import os 9 | import re 10 | import time 11 | import string 12 | 13 | run = os.system 14 | path = os.path.dirname(os.path.realpath(__file__)) 15 | 16 | def eval_trans(src_sgm, ref_sgm, trs_plain): 17 | 18 | cmd = ("%s/multi-bleu.perl %s < %s > %s.eval.nmt" \ 19 | %(path, ref_sgm, trs_plain, trs_plain)) 20 | print cmd 21 | run(cmd) 22 | eval_nmt = ''.join(file('%s.eval.nmt' % trs_plain, 'rU').readlines()) 23 | 24 | bleu = float(eval_nmt.strip().split(',')[0].split(' ')[-1]) 25 | 26 | return bleu 27 | 28 | if __name__ == "__main__": 29 | if len(sys.argv) != 4: 30 | print '%s src_sgm(meaningless), ref_sgm(plain ref), trs_plain(plain trans)' % sys.argv[0] 31 | sys.exit(0) 32 | 33 | src_sgm = sys.argv[1] 34 | ref_sgm = sys.argv[2] 35 | trs_plain = sys.argv[3] 36 | 37 | eval_trans(src_sgm, ref_sgm, trs_plain) 38 | -------------------------------------------------------------------------------- /code/eval/evaluate_nist.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | 3 | ''' 4 | Evaluate the translation result from Neural Machine Translation 5 | ''' 6 | 7 | import sys 8 | import os 9 | import re 10 | import time 11 | import string 12 | 13 | run = os.system 14 | path = os.path.dirname(os.path.realpath(__file__)) 15 | 16 | print path 17 | 18 | def split_seg(line): 19 | p1 = line.find(">")+1 20 | p2 = line.rfind("<") 21 | return [ line[:p1], line[p1:p2], line[p2:] ] 22 | 23 | def plain2sgm(trg_plain, src_sgm, trg_sgm): 24 | "Converse plain format to sgm format" 25 | fin_trg_plain = file(trg_plain , "r") 26 | fin_src_sgm = file(src_sgm, "r") 27 | fout = file(trg_sgm, "w") 28 | 29 | #head 30 | doc_head = fin_src_sgm.readline().rstrip().replace("srcset", "tstset") 31 | if doc_head.find("trglang") == -1 : 32 | doc_head = doc_head.replace(">", " trglang=\"en\">") 33 | print >> fout, doc_head 34 | 35 | for line in fin_src_sgm: 36 | line = line.rstrip() 37 | #process doc tag 38 | if "> fout, '''''' %id 43 | elif line.startswith("> fout, head, fin_trg_plain.readline().rstrip(), tail 46 | elif line.strip() == "": 47 | print >> fout, "" 48 | else: 49 | print >> fout, line 50 | 51 | fout.close() 52 | fin_src_sgm.close() 53 | 54 | def get_bleu(fe): 55 | "Get the bleu score from result file printed by mteval" 56 | 57 | c = file(fe, "rU").read() 58 | reg = re.compile(r"BLEU score =\s+(.*?)\s+") 59 | r = reg.findall(c) 60 | assert len(r) == 1 61 | return float(r[0]) 62 | 63 | def eval_trans(src_sgm, ref_sgm, trs_plain): 64 | 65 | print path 66 | 67 | if src_sgm[:-4] != '.sgm': 68 | src_sgm = src_sgm + '.sgm' 69 | 70 | plain2sgm(trs_plain, src_sgm, "result.sgm") 71 | cmd = "%s/mteval-v11b.pl -s %s -r %s -t result.sgm > %s.eval.nmt" \ 72 | %(path, src_sgm, ref_sgm, trs_plain) 73 | print cmd 74 | run(cmd) 75 | eval_nmt = ''.join(file('%s.eval.nmt' % trs_plain, 'rU').readlines()) 76 | print >> file('%s.eval.nmt' % trs_plain, 'w'), eval_nmt.replace('hiero', 'Nerual Machine Translation') 77 | run('mv result.sgm %s.sgm' % trs_plain) 78 | 79 | return get_bleu('%s.eval.nmt' % trs_plain) 80 | 81 | if __name__ == "__main__": 82 | if len(sys.argv) != 4: 83 | print '%s src_sgm, ref_sgm, trs_plain' % sys.argv[0] 84 | sys.exit(0) 85 | 86 | src_sgm = sys.argv[1] 87 | ref_sgm = sys.argv[2] 88 | trs_plain = sys.argv[3] 89 | 90 | eval_trans(src_sgm, ref_sgm, trs_plain) 91 | 
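Both wrappers above expose the same `eval_trans(src, ref, hyp)` interface and return the resulting BLEU score as a float; `nmt.py` imports them as `eval_moses` and `eval_nist`. A minimal usage sketch (the file names are placeholders, and the `eval` package must be importable, e.g. by running from the `code` directory):

```python
# Usage sketch only -- the file paths below are placeholders, not repository files.
from eval.evaluate_moses import eval_trans as eval_moses
from eval.evaluate_nist import eval_trans as eval_nist

# multi-bleu.perl scoring: the first argument is unused; the reference and the
# hypothesis are plain-text files with one sentence per line.
print eval_moses('unused', 'newstest.ref.txt', 'newstest.trans.txt')

# mteval-v11b.pl scoring: requires NIST-style SGML source and reference files;
# the plain-text hypothesis is converted to SGML internally via plain2sgm().
print eval_nist('nist06.src.sgm', 'nist06.ref.sgm', 'nist06.trans.txt')
```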
-------------------------------------------------------------------------------- /code/eval/mteval-v11b.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl -w 2 | 3 | use strict; 4 | 5 | ################################# 6 | # History: 7 | # 8 | # version 11b -- text normalization modified: 9 | # * take out the join digit line because it joins digits 10 | # when it shouldn't have 11 | # $norm_text =~ s/(\d)\s+(?=\d)/$1/g; #join digits 12 | # 13 | # version 11a -- corrected output of individual n-gram precision values 14 | # 15 | # version 11 -- bug fixes: 16 | # * make filehandle operate in binary mode to prevent Perl from operating 17 | # (by default in Red Hat 9) in UTF-8 18 | # * fix failure on joining digits 19 | # version 10 -- updated output to include more details of n-gram scoring. 20 | # Defaults to generate both NIST and BLEU scores. Use -b for BLEU 21 | # only, use -n for NIST only 22 | # 23 | # version 09d -- bug fix (for BLEU scoring, ngrams were fixed at 4 24 | # being the max, regardless what was entered on the command line.) 25 | # 26 | # version 09c -- bug fix (During the calculation of ngram information, 27 | # each ngram was being counted only once for each segment. This has 28 | # been fixed so that each ngram is counted correctly in each segment.) 29 | # 30 | # version 09b -- text normalization modified: 31 | # * option flag added to preserve upper case 32 | # * non-ASCII characters left in place. 33 | # 34 | # version 09a -- text normalization modified: 35 | # * " and & converted to "" and &, respectively 36 | # * non-ASCII characters kept together (bug fix) 37 | # 38 | # version 09 -- modified to accommodate sgml tag and attribute 39 | # names revised to conform to default SGML conventions. 40 | # 41 | # version 08 -- modifies the NIST metric in accordance with the 42 | # findings on the 2001 Chinese-English dry run corpus. Also 43 | # incorporates the BLEU metric as an option and supports the 44 | # output of ngram detail. 45 | # 46 | # version 07 -- in response to the MT meeting on 28 Jan 2002 at ISI 47 | # Keep strings of non-ASCII characters together as one word 48 | # (rather than splitting them into one-character words). 49 | # Change length penalty so that translations that are longer than 50 | # the average reference translation are not penalized. 51 | # 52 | # version 06 53 | # Prevent divide-by-zero when a segment has no evaluation N-grams. 54 | # Correct segment index for level 3 debug output. 55 | # 56 | # version 05 57 | # improve diagnostic error messages 58 | # 59 | # version 04 60 | # tag segments 61 | # 62 | # version 03 63 | # add detailed output option (intermediate document and segment scores) 64 | # 65 | # version 02 66 | # accommodation of modified sgml tags and attributes 67 | # 68 | # version 01 69 | # same as bleu version 15, but modified to provide formal score output. 70 | # 71 | # original IBM version 72 | # Author: Kishore Papineni 73 | # Date: 06/10/2001 74 | ################################# 75 | 76 | ###### 77 | # Intro 78 | my ($date, $time) = date_time_stamp(); 79 | print "MT evaluation scorer began on $date at $time\n"; 80 | print "command line: ", $0, " ", join(" ", @ARGV), "\n"; 81 | my $usage = "\n\nUsage: $0 [-h] -r -s src_file -t \n\n". 82 | "Description: This Perl script evaluates MT system performance.\n". 83 | "\n". 84 | "Required arguments:\n". 85 | " -r is a file containing the reference translations for\n". 86 | " the documents to be evaluated.\n". 
87 | " -s is a file containing the source documents for which\n". 88 | " translations are to be evaluated\n". 89 | " -t is a file containing the translations to be evaluated\n". 90 | "\n". 91 | "Optional arguments:\n". 92 | " -c preserves upper-case alphabetic characters\n". 93 | " -b generate BLEU scores only\n". 94 | " -n generate NIST scores only\n". 95 | " -d detailed output flag used in conjunction with \"-b\" or \"-n\" flags:\n". 96 | " 0 (default) for system-level score only\n". 97 | " 1 to include document-level scores\n". 98 | " 2 to include segment-level scores\n". 99 | " 3 to include ngram-level scores\n". 100 | " -h prints this help message to STDOUT\n". 101 | "\n"; 102 | 103 | use vars qw ($opt_r $opt_s $opt_t $opt_d $opt_h $opt_b $opt_n $opt_c $opt_x); 104 | use Getopt::Std; 105 | getopts ('r:s:t:d:hbncx:'); 106 | die $usage if defined($opt_h); 107 | die "Error in command line: ref_file not defined$usage" unless defined $opt_r; 108 | die "Error in command line: src_file not defined$usage" unless defined $opt_s; 109 | die "Error in command line: tst_file not defined$usage" unless defined $opt_t; 110 | my $max_Ngram = 9; 111 | my $detail = defined $opt_d ? $opt_d : 0; 112 | my $preserve_case = defined $opt_c ? 1 : 0; 113 | 114 | my $METHOD = "BOTH"; 115 | if (defined $opt_b) { $METHOD = "BLEU"; } 116 | if (defined $opt_n) { $METHOD = "NIST"; } 117 | my $method; 118 | 119 | my ($ref_file) = $opt_r; 120 | my ($src_file) = $opt_s; 121 | my ($tst_file) = $opt_t; 122 | 123 | ###### 124 | # Global variables 125 | my ($src_lang, $tgt_lang, @tst_sys, @ref_sys); # evaluation parameters 126 | my (%tst_data, %ref_data); # the data -- with structure: {system}{document}[segments] 127 | my ($src_id, $ref_id, $tst_id); # unique identifiers for ref and tst translation sets 128 | my %eval_docs; # document information for the evaluation data set 129 | my %ngram_info; # the information obtained from (the last word in) the ngram 130 | 131 | ###### 132 | # Get source document ID's 133 | ($src_id) = get_source_info ($src_file); 134 | 135 | ###### 136 | # Get reference translations 137 | ($ref_id) = get_MT_data (\%ref_data, "RefSet", $ref_file); 138 | 139 | compute_ngram_info (); 140 | 141 | ###### 142 | # Get translations to evaluate 143 | ($tst_id) = get_MT_data (\%tst_data, "TstSet", $tst_file); 144 | 145 | ###### 146 | # Check data for completeness and correctness 147 | check_MT_data (); 148 | 149 | ###### 150 | # 151 | my %NISTmt = (); 152 | my %BLEUmt = (); 153 | 154 | ###### 155 | # Evaluate 156 | print " Evaluation of $src_lang-to-$tgt_lang translation using:\n"; 157 | my $cum_seg = 0; 158 | foreach my $doc (sort keys %eval_docs) { 159 | $cum_seg += @{$eval_docs{$doc}{SEGS}}; 160 | } 161 | print " src set \"$src_id\" (", scalar keys %eval_docs, " docs, $cum_seg segs)\n"; 162 | print " ref set \"$ref_id\" (", scalar keys %ref_data, " refs)\n"; 163 | print " tst set \"$tst_id\" (", scalar keys %tst_data, " systems)\n\n"; 164 | 165 | foreach my $sys (sort @tst_sys) { 166 | for (my $n=1; $n<=$max_Ngram; $n++) { 167 | $NISTmt{$n}{$sys}{cum} = 0; 168 | $NISTmt{$n}{$sys}{ind} = 0; 169 | $BLEUmt{$n}{$sys}{cum} = 0; 170 | $BLEUmt{$n}{$sys}{ind} = 0; 171 | } 172 | 173 | if (($METHOD eq "BOTH") || ($METHOD eq "NIST")) { 174 | $method="NIST"; 175 | score_system ($sys, %NISTmt); 176 | } 177 | if (($METHOD eq "BOTH") || ($METHOD eq "BLEU")) { 178 | $method="BLEU"; 179 | score_system ($sys, %BLEUmt); 180 | } 181 | } 182 | 183 | ###### 184 | printout_report (); 185 | 186 | ($date, $time) = date_time_stamp(); 
187 | print "MT evaluation scorer ended on $date at $time\n"; 188 | 189 | exit 0; 190 | 191 | ################################# 192 | 193 | sub get_source_info { 194 | 195 | my ($file) = @_; 196 | my ($name, $id, $src, $doc); 197 | my ($data, $tag, $span); 198 | 199 | 200 | #read data from file 201 | open (FILE, $file) or die "\nUnable to open translation data file '$file'", $usage; 202 | binmode FILE; 203 | $data .= $_ while ; 204 | close (FILE); 205 | 206 | #get source set info 207 | die "\n\nFATAL INPUT ERROR: no 'src_set' tag in src_file '$file'\n\n" 208 | unless ($tag, $span, $data) = extract_sgml_tag_and_span ("SrcSet", $data); 209 | 210 | die "\n\nFATAL INPUT ERROR: no tag attribute '$name' in file '$file'\n\n" 211 | unless ($id) = extract_sgml_tag_attribute ($name="SetID", $tag); 212 | 213 | die "\n\nFATAL INPUT ERROR: no tag attribute '$name' in file '$file'\n\n" 214 | unless ($src) = extract_sgml_tag_attribute ($name="SrcLang", $tag); 215 | die "\n\nFATAL INPUT ERROR: $name ('$src') in file '$file' inconsistent\n" 216 | ." with $name in previous input data ('$src_lang')\n\n" 217 | unless (not defined $src_lang or $src eq $src_lang); 218 | $src_lang = $src; 219 | 220 | #get doc info -- ID and # of segs 221 | $data = $span; 222 | while (($tag, $span, $data) = extract_sgml_tag_and_span ("Doc", $data)) { 223 | die "\n\nFATAL INPUT ERROR: no tag attribute '$name' in file '$file'\n\n" 224 | unless ($doc) = extract_sgml_tag_attribute ($name="DocID", $tag); 225 | die "\n\nFATAL INPUT ERROR: duplicate '$name' in file '$file'\n\n" 226 | if defined $eval_docs{$doc}; 227 | $span =~ s/[\s\n\r]+/ /g; # concatenate records 228 | my $jseg=0, my $seg_data = $span; 229 | while (($tag, $span, $seg_data) = extract_sgml_tag_and_span ("Seg", $seg_data)) { 230 | ($eval_docs{$doc}{SEGS}[$jseg++]) = NormalizeText ($span); 231 | } 232 | die "\n\nFATAL INPUT ERROR: no segments in document '$doc' in file '$file'\n\n" 233 | if $jseg == 0; 234 | } 235 | die "\n\nFATAL INPUT ERROR: no documents in file '$file'\n\n" 236 | unless keys %eval_docs > 0; 237 | return $id; 238 | } 239 | 240 | ################################# 241 | 242 | sub get_MT_data { 243 | 244 | my ($docs, $set_tag, $file) = @_; 245 | my ($name, $id, $src, $tgt, $sys, $doc); 246 | my ($tag, $span, $data); 247 | 248 | #read data from file 249 | open (FILE, $file) or die "\nUnable to open translation data file '$file'", $usage; 250 | binmode FILE; 251 | $data .= $_ while ; 252 | close (FILE); 253 | 254 | #get tag info 255 | while (($tag, $span, $data) = extract_sgml_tag_and_span ($set_tag, $data)) { 256 | die "\n\nFATAL INPUT ERROR: no tag attribute '$name' in file '$file'\n\n" unless 257 | ($id) = extract_sgml_tag_attribute ($name="SetID", $tag); 258 | 259 | die "\n\nFATAL INPUT ERROR: no tag attribute '$name' in file '$file'\n\n" unless 260 | ($src) = extract_sgml_tag_attribute ($name="SrcLang", $tag); 261 | die "\n\nFATAL INPUT ERROR: $name ('$src') in file '$file' inconsistent\n" 262 | ." with $name of source ('$src_lang')\n\n" 263 | unless $src eq $src_lang; 264 | 265 | die "\n\nFATAL INPUT ERROR: no tag attribute '$name' in file '$file'\n\n" unless 266 | ($tgt) = extract_sgml_tag_attribute ($name="TrgLang", $tag); 267 | die "\n\nFATAL INPUT ERROR: $name ('$tgt') in file '$file' inconsistent\n" 268 | ." 
with $name of the evaluation ('$tgt_lang')\n\n" 269 | unless (not defined $tgt_lang or $tgt eq $tgt_lang); 270 | $tgt_lang = $tgt; 271 | 272 | my $mtdata = $span; 273 | while (($tag, $span, $mtdata) = extract_sgml_tag_and_span ("Doc", $mtdata)) { 274 | die "\n\nFATAL INPUT ERROR: no tag attribute '$name' in file '$file'\n\n" unless 275 | (my $sys) = extract_sgml_tag_attribute ($name="SysID", $tag); 276 | 277 | die "\n\nFATAL INPUT ERROR: no tag attribute '$name' in file '$file'\n\n" unless 278 | $doc = extract_sgml_tag_attribute ($name="DocID", $tag); 279 | 280 | die "\n\nFATAL INPUT ERROR: document '$doc' for system '$sys' in file '$file'\n" 281 | ." previously loaded from file '$docs->{$sys}{$doc}{FILE}'\n\n" 282 | unless (not defined $docs->{$sys}{$doc}); 283 | 284 | $span =~ s/[\s\n\r]+/ /g; # concatenate records 285 | my $jseg=0, my $seg_data = $span; 286 | while (($tag, $span, $seg_data) = extract_sgml_tag_and_span ("Seg", $seg_data)) { 287 | ($docs->{$sys}{$doc}{SEGS}[$jseg++]) = NormalizeText ($span); 288 | } 289 | die "\n\nFATAL INPUT ERROR: no segments in document '$doc' in file '$file'\n\n" 290 | if $jseg == 0; 291 | $docs->{$sys}{$doc}{FILE} = $file; 292 | } 293 | } 294 | return $id; 295 | } 296 | 297 | ################################# 298 | 299 | sub check_MT_data { 300 | 301 | @tst_sys = sort keys %tst_data; 302 | @ref_sys = sort keys %ref_data; 303 | 304 | #every evaluation document must be represented for every system and every reference 305 | foreach my $doc (sort keys %eval_docs) { 306 | my $nseg_source = @{$eval_docs{$doc}{SEGS}}; 307 | foreach my $sys (@tst_sys) { 308 | die "\n\nFATAL ERROR: no document '$doc' for system '$sys'\n\n" 309 | unless defined $tst_data{$sys}{$doc}; 310 | my $nseg = @{$tst_data{$sys}{$doc}{SEGS}}; 311 | die "\n\nFATAL ERROR: translated documents must contain the same # of segments as the source, but\n" 312 | ." document '$doc' for system '$sys' contains $nseg segments, while\n" 313 | ." the source document contains $nseg_source segments.\n\n" 314 | unless $nseg == $nseg_source; 315 | } 316 | 317 | foreach my $sys (@ref_sys) { 318 | die "\n\nFATAL ERROR: no document '$doc' for reference '$sys'\n\n" 319 | unless defined $ref_data{$sys}{$doc}; 320 | my $nseg = @{$ref_data{$sys}{$doc}{SEGS}}; 321 | die "\n\nFATAL ERROR: translated documents must contain the same # of segments as the source, but\n" 322 | ." document '$doc' for system '$sys' contains $nseg segments, while\n" 323 | ." the source document contains $nseg_source segments.\n\n" 324 | unless $nseg == $nseg_source; 325 | } 326 | } 327 | } 328 | 329 | ################################# 330 | 331 | sub compute_ngram_info { 332 | 333 | my ($ref, $doc, $seg); 334 | my (@wrds, $tot_wrds, %ngrams, $ngram, $mgram); 335 | my (%ngram_count, @tot_ngrams); 336 | 337 | foreach $ref (keys %ref_data) { 338 | foreach $doc (keys %{$ref_data{$ref}}) { 339 | foreach $seg (@{$ref_data{$ref}{$doc}{SEGS}}) { 340 | @wrds = split /\s+/, $seg; 341 | $tot_wrds += @wrds; 342 | %ngrams = %{Words2Ngrams (@wrds)}; 343 | foreach $ngram (keys %ngrams) { 344 | $ngram_count{$ngram} += $ngrams{$ngram}; 345 | } 346 | } 347 | } 348 | } 349 | 350 | foreach $ngram (keys %ngram_count) { 351 | @wrds = split / /, $ngram; 352 | pop @wrds, $mgram = join " ", @wrds; 353 | $ngram_info{$ngram} = - log 354 | ($mgram ? 
$ngram_count{$ngram}/$ngram_count{$mgram} 355 | : $ngram_count{$ngram}/$tot_wrds) / log 2; 356 | if (defined $opt_x and $opt_x eq "ngram info") { 357 | @wrds = split / /, $ngram; 358 | printf "ngram info:%9.4f%6d%6d%8d%3d %s\n", $ngram_info{$ngram}, $ngram_count{$ngram}, 359 | $mgram ? $ngram_count{$mgram} : $tot_wrds, $tot_wrds, scalar @wrds, $ngram; 360 | } 361 | } 362 | } 363 | 364 | ################################# 365 | 366 | sub score_system { 367 | 368 | my ($sys, $ref, $doc, %SCOREmt); 369 | ($sys, %SCOREmt) = @_; 370 | my ($shortest_ref_length, $match_cnt, $tst_cnt, $ref_cnt, $tst_info, $ref_info); 371 | my ($cum_ref_length, @cum_match, @cum_tst_cnt, @cum_ref_cnt, @cum_tst_info, @cum_ref_info); 372 | 373 | $cum_ref_length = 0; 374 | for (my $j=1; $j<=$max_Ngram; $j++) { 375 | $cum_match[$j] = $cum_tst_cnt[$j] = $cum_ref_cnt[$j] = $cum_tst_info[$j] = $cum_ref_info[$j] = 0; 376 | } 377 | 378 | foreach $doc (sort keys %eval_docs) { 379 | ($shortest_ref_length, $match_cnt, $tst_cnt, $ref_cnt, $tst_info, $ref_info) = score_document ($sys, $doc); 380 | 381 | #output document summary score 382 | if (($detail >= 1 ) && ($METHOD eq "NIST")) { 383 | my %DOCmt = (); 384 | printf "$method score using 5-grams = %.4f for system \"$sys\" on document \"$doc\" (%d segments, %d words)\n", 385 | nist_score (scalar @ref_sys, $match_cnt, $tst_cnt, $ref_cnt, $tst_info, $ref_info, $sys, %DOCmt), 386 | scalar @{$tst_data{$sys}{$doc}{SEGS}}, $tst_cnt->[1]; 387 | } 388 | if (($detail >= 1 ) && ($METHOD eq "BLEU")) { 389 | my %DOCmt = (); 390 | printf "$method score using 4-grams = %.4f for system \"$sys\" on document \"$doc\" (%d segments, %d words)\n", 391 | bleu_score($shortest_ref_length, $match_cnt, $tst_cnt, $sys, %DOCmt), 392 | scalar @{$tst_data{$sys}{$doc}{SEGS}}, $tst_cnt->[1]; 393 | } 394 | 395 | $cum_ref_length += $shortest_ref_length; 396 | for (my $j=1; $j<=$max_Ngram; $j++) { 397 | $cum_match[$j] += $match_cnt->[$j]; 398 | $cum_tst_cnt[$j] += $tst_cnt->[$j]; 399 | $cum_ref_cnt[$j] += $ref_cnt->[$j]; 400 | $cum_tst_info[$j] += $tst_info->[$j]; 401 | $cum_ref_info[$j] += $ref_info->[$j]; 402 | printf "document info: $sys $doc %d-gram %d %d %d %9.4f %9.4f\n", $j, $match_cnt->[$j], 403 | $tst_cnt->[$j], $ref_cnt->[$j], $tst_info->[$j], $ref_info->[$j] 404 | if (defined $opt_x and $opt_x eq "document info"); 405 | } 406 | } 407 | 408 | #x #output system summary score 409 | #x printf "$method score = %.4f for system \"$sys\"\n", 410 | #x $method eq "BLEU" ? 
bleu_score($cum_ref_length, \@cum_match, \@cum_tst_cnt) : 411 | #x nist_score (scalar @ref_sys, \@cum_match, \@cum_tst_cnt, \@cum_ref_cnt, \@cum_tst_info, \@cum_ref_info, $sys, %SCOREmt); 412 | if ($method eq "BLEU") { 413 | bleu_score($cum_ref_length, \@cum_match, \@cum_tst_cnt, $sys, %SCOREmt); 414 | } 415 | if ($method eq "NIST") { 416 | nist_score (scalar @ref_sys, \@cum_match, \@cum_tst_cnt, \@cum_ref_cnt, \@cum_tst_info, \@cum_ref_info, $sys, %SCOREmt); 417 | } 418 | } 419 | 420 | ################################# 421 | 422 | sub score_document { 423 | 424 | my ($sys, $ref, $doc); 425 | ($sys, $doc) = @_; 426 | my ($shortest_ref_length, $match_cnt, $tst_cnt, $ref_cnt, $tst_info, $ref_info); 427 | my ($cum_ref_length, @cum_match, @cum_tst_cnt, @cum_ref_cnt, @cum_tst_info, @cum_ref_info); 428 | 429 | $cum_ref_length = 0; 430 | for (my $j=1; $j<=$max_Ngram; $j++) { 431 | $cum_match[$j] = $cum_tst_cnt[$j] = $cum_ref_cnt[$j] = $cum_tst_info[$j] = $cum_ref_info[$j] = 0; 432 | } 433 | 434 | #score each segment 435 | for (my $jseg=0; $jseg<@{$tst_data{$sys}{$doc}{SEGS}}; $jseg++) { 436 | my @ref_segments = (); 437 | foreach $ref (@ref_sys) { 438 | push @ref_segments, $ref_data{$ref}{$doc}{SEGS}[$jseg]; 439 | printf "ref '$ref', seg %d: %s\n", $jseg+1, $ref_data{$ref}{$doc}{SEGS}[$jseg] 440 | if $detail >= 3; 441 | } 442 | printf "sys '$sys', seg %d: %s\n", $jseg+1, $tst_data{$sys}{$doc}{SEGS}[$jseg] 443 | if $detail >= 3; 444 | ($shortest_ref_length, $match_cnt, $tst_cnt, $ref_cnt, $tst_info, $ref_info) = 445 | score_segment ($tst_data{$sys}{$doc}{SEGS}[$jseg], @ref_segments); 446 | 447 | #output segment summary score 448 | #x printf "$method score = %.4f for system \"$sys\" on segment %d of document \"$doc\" (%d words)\n", 449 | #x $method eq "BLEU" ? 
bleu_score($shortest_ref_length, $match_cnt, $tst_cnt) : 450 | #x nist_score (scalar @ref_sys, $match_cnt, $tst_cnt, $ref_cnt, $tst_info, $ref_info), 451 | #x $jseg+1, $tst_cnt->[1] 452 | #x if $detail >= 2; 453 | if (($detail >=2) && ($METHOD eq "BLEU")) { 454 | my %DOCmt = (); 455 | printf " $method score using 4-grams = %.4f for system \"$sys\" on segment %d of document \"$doc\" (%d words)\n", 456 | bleu_score($shortest_ref_length, $match_cnt, $tst_cnt, $sys, %DOCmt), $jseg+1, $tst_cnt->[1]; 457 | } 458 | if (($detail >=2) && ($METHOD eq "NIST")) { 459 | my %DOCmt = (); 460 | printf " $method score using 5-grams = %.4f for system \"$sys\" on segment %d of document \"$doc\" (%d words)\n", 461 | nist_score (scalar @ref_sys, $match_cnt, $tst_cnt, $ref_cnt, $tst_info, $ref_info, $sys, %DOCmt), $jseg+1, $tst_cnt->[1]; 462 | } 463 | 464 | 465 | $cum_ref_length += $shortest_ref_length; 466 | for (my $j=1; $j<=$max_Ngram; $j++) { 467 | $cum_match[$j] += $match_cnt->[$j]; 468 | $cum_tst_cnt[$j] += $tst_cnt->[$j]; 469 | $cum_ref_cnt[$j] += $ref_cnt->[$j]; 470 | $cum_tst_info[$j] += $tst_info->[$j]; 471 | $cum_ref_info[$j] += $ref_info->[$j]; 472 | } 473 | } 474 | return ($cum_ref_length, [@cum_match], [@cum_tst_cnt], [@cum_ref_cnt], [@cum_tst_info], [@cum_ref_info]); 475 | } 476 | 477 | ################################# 478 | 479 | sub score_segment { 480 | 481 | my ($tst_seg, @ref_segs) = @_; 482 | my (@tst_wrds, %tst_ngrams, @match_count, @tst_count, @tst_info); 483 | my (@ref_wrds, $ref_seg, %ref_ngrams, %ref_ngrams_max, @ref_count, @ref_info); 484 | my ($ngram); 485 | my (@nwrds_ref); 486 | my $shortest_ref_length; 487 | 488 | for (my $j=1; $j<= $max_Ngram; $j++) { 489 | $match_count[$j] = $tst_count[$j] = $ref_count[$j] = $tst_info[$j] = $ref_info[$j] = 0; 490 | } 491 | 492 | # get the ngram counts for the test segment 493 | @tst_wrds = split /\s+/, $tst_seg; 494 | %tst_ngrams = %{Words2Ngrams (@tst_wrds)}; 495 | for (my $j=1; $j<=$max_Ngram; $j++) { # compute ngram counts 496 | $tst_count[$j] = $j<=@tst_wrds ? (@tst_wrds - $j + 1) : 0; 497 | } 498 | 499 | # get the ngram counts for the reference segments 500 | foreach $ref_seg (@ref_segs) { 501 | @ref_wrds = split /\s+/, $ref_seg; 502 | %ref_ngrams = %{Words2Ngrams (@ref_wrds)}; 503 | foreach $ngram (keys %ref_ngrams) { # find the maximum # of occurrences 504 | my @wrds = split / /, $ngram; 505 | $ref_info[@wrds] += $ngram_info{$ngram}; 506 | $ref_ngrams_max{$ngram} = defined $ref_ngrams_max{$ngram} ? 507 | max ($ref_ngrams_max{$ngram}, $ref_ngrams{$ngram}) : 508 | $ref_ngrams{$ngram}; 509 | } 510 | for (my $j=1; $j<=$max_Ngram; $j++) { # update ngram counts 511 | $ref_count[$j] += $j<=@ref_wrds ? 
(@ref_wrds - $j + 1) : 0; 512 | } 513 | $shortest_ref_length = scalar @ref_wrds # find the shortest reference segment 514 | if (not defined $shortest_ref_length) or @ref_wrds < $shortest_ref_length; 515 | } 516 | 517 | # accumulate scoring stats for tst_seg ngrams that match ref_seg ngrams 518 | foreach $ngram (keys %tst_ngrams) { 519 | next unless defined $ref_ngrams_max{$ngram}; 520 | my @wrds = split / /, $ngram; 521 | $tst_info[@wrds] += $ngram_info{$ngram} * min($tst_ngrams{$ngram},$ref_ngrams_max{$ngram}); 522 | $match_count[@wrds] += my $count = min($tst_ngrams{$ngram},$ref_ngrams_max{$ngram}); 523 | printf "%.2f info for each of $count %d-grams = '%s'\n", $ngram_info{$ngram}, scalar @wrds, $ngram 524 | if $detail >= 3; 525 | } 526 | 527 | return ($shortest_ref_length, [@match_count], [@tst_count], [@ref_count], [@tst_info], [@ref_info]); 528 | } 529 | 530 | ################################# 531 | 532 | sub bleu_score { 533 | 534 | my ($shortest_ref_length, $matching_ngrams, $tst_ngrams, $sys, %SCOREmt) = @_; 535 | 536 | my $score = 0; 537 | my $iscore = 0; 538 | my $len_score = min (0, 1-$shortest_ref_length/$tst_ngrams->[1]); 539 | 540 | for (my $j=1; $j<=$max_Ngram; $j++) { 541 | if ($matching_ngrams->[$j] == 0) { 542 | $SCOREmt{$j}{$sys}{cum}=0; 543 | } else { 544 | # Cumulative N-Gram score 545 | $score += log ($matching_ngrams->[$j]/$tst_ngrams->[$j]); 546 | $SCOREmt{$j}{$sys}{cum} = exp($score/$j + $len_score); 547 | # Individual N-Gram score 548 | $iscore = log ($matching_ngrams->[$j]/$tst_ngrams->[$j]); 549 | $SCOREmt{$j}{$sys}{ind} = exp($iscore); 550 | } 551 | } 552 | return $SCOREmt{4}{$sys}{cum}; 553 | } 554 | 555 | ################################# 556 | 557 | sub nist_score { 558 | 559 | my ($nsys, $matching_ngrams, $tst_ngrams, $ref_ngrams, $tst_info, $ref_info, $sys, %SCOREmt) = @_; 560 | 561 | my $score = 0; 562 | my $iscore = 0; 563 | 564 | 565 | for (my $n=1; $n<=$max_Ngram; $n++) { 566 | $score += $tst_info->[$n]/max($tst_ngrams->[$n],1); 567 | $SCOREmt{$n}{$sys}{cum} = $score * nist_length_penalty($tst_ngrams->[1]/($ref_ngrams->[1]/$nsys)); 568 | 569 | $iscore = $tst_info->[$n]/max($tst_ngrams->[$n],1); 570 | $SCOREmt{$n}{$sys}{ind} = $iscore * nist_length_penalty($tst_ngrams->[1]/($ref_ngrams->[1]/$nsys)); 571 | } 572 | return $SCOREmt{5}{$sys}{cum}; 573 | } 574 | 575 | ################################# 576 | 577 | sub Words2Ngrams { #convert a string of words to an Ngram count hash 578 | 579 | my %count = (); 580 | 581 | for (; @_; shift) { 582 | my ($j, $ngram, $word); 583 | for ($j=0; $j<$max_Ngram and defined($word=$_[$j]); $j++) { 584 | $ngram .= defined $ngram ? 
" $word" : $word; 585 | $count{$ngram}++; 586 | } 587 | } 588 | return {%count}; 589 | } 590 | 591 | ################################# 592 | 593 | sub NormalizeText { 594 | my ($norm_text) = @_; 595 | 596 | # language-independent part: 597 | $norm_text =~ s///g; # strip "skipped" tags 598 | $norm_text =~ s/-\n//g; # strip end-of-line hyphenation and join lines 599 | $norm_text =~ s/\n/ /g; # join lines 600 | $norm_text =~ s/"/"/g; # convert SGML tag for quote to " 601 | $norm_text =~ s/&/&/g; # convert SGML tag for ampersand to & 602 | $norm_text =~ s/</ 603 | $norm_text =~ s/>/>/g; # convert SGML tag for greater-than to < 604 | 605 | # language-dependent part (assuming Western languages): 606 | $norm_text = " $norm_text "; 607 | $norm_text =~ tr/[A-Z]/[a-z]/ unless $preserve_case; 608 | $norm_text =~ s/([\{-\~\[-\` -\&\(-\+\:-\@\/])/ $1 /g; # tokenize punctuation 609 | $norm_text =~ s/([^0-9])([\.,])/$1 $2 /g; # tokenize period and comma unless preceded by a digit 610 | $norm_text =~ s/([\.,])([^0-9])/ $1 $2/g; # tokenize period and comma unless followed by a digit 611 | $norm_text =~ s/([0-9])(-)/$1 $2 /g; # tokenize dash when preceded by a digit 612 | $norm_text =~ s/\s+/ /g; # one space only between words 613 | $norm_text =~ s/^\s+//; # no leading space 614 | $norm_text =~ s/\s+$//; # no trailing space 615 | 616 | return $norm_text; 617 | } 618 | 619 | ################################# 620 | 621 | sub nist_length_penalty { 622 | 623 | my ($ratio) = @_; 624 | return 1 if $ratio >= 1; 625 | return 0 if $ratio <= 0; 626 | my $ratio_x = 1.5; 627 | my $score_x = 0.5; 628 | my $beta = -log($score_x)/log($ratio_x)/log($ratio_x); 629 | return exp (-$beta*log($ratio)*log($ratio)); 630 | } 631 | 632 | ################################# 633 | 634 | sub date_time_stamp { 635 | 636 | my ($sec, $min, $hour, $mday, $mon, $year, $wday, $yday, $isdst) = localtime(); 637 | my @months = qw(Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec); 638 | my ($date, $time); 639 | 640 | $time = sprintf "%2.2d:%2.2d:%2.2d", $hour, $min, $sec; 641 | $date = sprintf "%4.4s %3.3s %s", 1900+$year, $months[$mon], $mday; 642 | return ($date, $time); 643 | } 644 | 645 | ################################# 646 | 647 | sub extract_sgml_tag_and_span { 648 | 649 | my ($name, $data) = @_; 650 | 651 | ($data =~ m|<$name\s*([^>]*)>(.*?)(.*)|si) ? ($1, $2, $3) : (); 652 | } 653 | 654 | ################################# 655 | 656 | sub extract_sgml_tag_attribute { 657 | 658 | my ($name, $data) = @_; 659 | 660 | ($data =~ m|$name\s*=\s*\"([^\"]*)\"|si) ? 
($1) : (); 661 | } 662 | 663 | ################################# 664 | 665 | sub max { 666 | 667 | my ($max, $next); 668 | 669 | return unless defined ($max=pop); 670 | while (defined ($next=pop)) { 671 | $max = $next if $next > $max; 672 | } 673 | return $max; 674 | } 675 | 676 | ################################# 677 | 678 | sub min { 679 | 680 | my ($min, $next); 681 | 682 | return unless defined ($min=pop); 683 | while (defined ($next=pop)) { 684 | $min = $next if $next < $min; 685 | } 686 | return $min; 687 | } 688 | 689 | ################################# 690 | 691 | sub printout_report 692 | { 693 | 694 | if ( $METHOD eq "BOTH" ) { 695 | foreach my $sys (sort @tst_sys) { 696 | printf "NIST score = %2.4f BLEU score = %.4f for system \"$sys\"\n",$NISTmt{5}{$sys}{cum},$BLEUmt{4}{$sys}{cum}; 697 | } 698 | } elsif ($METHOD eq "NIST" ) { 699 | foreach my $sys (sort @tst_sys) { 700 | printf "NIST score = %2.4f for system \"$sys\"\n",$NISTmt{5}{$sys}{cum}; 701 | } 702 | } elsif ($METHOD eq "BLEU" ) { 703 | foreach my $sys (sort @tst_sys) { 704 | printf "\nBLEU score = %.4f for system \"$sys\"\n",$BLEUmt{4}{$sys}{cum}; 705 | } 706 | } 707 | 708 | 709 | printf "\n# ------------------------------------------------------------------------\n\n"; 710 | printf "Individual N-gram scoring\n"; 711 | printf " 1-gram 2-gram 3-gram 4-gram 5-gram 6-gram 7-gram 8-gram 9-gram\n"; 712 | printf " ------ ------ ------ ------ ------ ------ ------ ------ ------\n"; 713 | 714 | if (( $METHOD eq "BOTH" ) || ($METHOD eq "NIST")) { 715 | foreach my $sys (sort @tst_sys) { 716 | printf " NIST:"; 717 | for (my $i=1; $i<=$max_Ngram; $i++) { 718 | printf " %2.4f ",$NISTmt{$i}{$sys}{ind} 719 | } 720 | printf " \"$sys\"\n"; 721 | } 722 | printf "\n"; 723 | } 724 | 725 | if (( $METHOD eq "BOTH" ) || ($METHOD eq "BLEU")) { 726 | foreach my $sys (sort @tst_sys) { 727 | printf " BLEU:"; 728 | for (my $i=1; $i<=$max_Ngram; $i++) { 729 | printf " %2.4f ",$BLEUmt{$i}{$sys}{ind} 730 | } 731 | printf " \"$sys\"\n"; 732 | } 733 | } 734 | 735 | printf "\n# ------------------------------------------------------------------------\n"; 736 | printf "Cumulative N-gram scoring\n"; 737 | printf " 1-gram 2-gram 3-gram 4-gram 5-gram 6-gram 7-gram 8-gram 9-gram\n"; 738 | printf " ------ ------ ------ ------ ------ ------ ------ ------ ------\n"; 739 | 740 | if (( $METHOD eq "BOTH" ) || ($METHOD eq "NIST")) { 741 | foreach my $sys (sort @tst_sys) { 742 | printf " NIST:"; 743 | for (my $i=1; $i<=$max_Ngram; $i++) { 744 | printf " %2.4f ",$NISTmt{$i}{$sys}{cum} 745 | } 746 | printf " \"$sys\"\n"; 747 | } 748 | } 749 | printf "\n"; 750 | 751 | 752 | if (( $METHOD eq "BOTH" ) || ($METHOD eq "BLEU")) { 753 | foreach my $sys (sort @tst_sys) { 754 | printf " BLEU:"; 755 | for (my $i=1; $i<=$max_Ngram; $i++) { 756 | printf " %2.4f ",$BLEUmt{$i}{$sys}{cum} 757 | } 758 | printf " \"$sys\"\n"; 759 | } 760 | } 761 | } 762 | -------------------------------------------------------------------------------- /code/eval/multi-bleu.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 
5 | 6 | # $Id$ 7 | use warnings; 8 | use strict; 9 | 10 | my $lowercase = 0; 11 | if ($ARGV[0] eq "-lc") { 12 | $lowercase = 1; 13 | shift; 14 | } 15 | 16 | my $stem = $ARGV[0]; 17 | if (!defined $stem) { 18 | print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n"; 19 | print STDERR "Reads the references from reference or reference0, reference1, ...\n"; 20 | exit(1); 21 | } 22 | 23 | $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0"; 24 | 25 | my @REF; 26 | my $ref=0; 27 | while(-e "$stem$ref") { 28 | &add_to_ref("$stem$ref",\@REF); 29 | $ref++; 30 | } 31 | &add_to_ref($stem,\@REF) if -e $stem; 32 | die("ERROR: could not find reference file $stem") unless scalar @REF; 33 | 34 | # add additional references explicitly specified on the command line 35 | shift; 36 | foreach my $stem (@ARGV) { 37 | &add_to_ref($stem,\@REF) if -e $stem; 38 | } 39 | 40 | 41 | 42 | sub add_to_ref { 43 | my ($file,$REF) = @_; 44 | my $s=0; 45 | if ($file =~ /.gz$/) { 46 | open(REF,"gzip -dc $file|") or die "Can't read $file"; 47 | } else { 48 | open(REF,$file) or die "Can't read $file"; 49 | } 50 | while() { 51 | chop; 52 | push @{$$REF[$s++]}, $_; 53 | } 54 | close(REF); 55 | } 56 | 57 | my(@CORRECT,@TOTAL,$length_translation,$length_reference); 58 | my $s=0; 59 | while() { 60 | chop; 61 | $_ = lc if $lowercase; 62 | my @WORD = split; 63 | my %REF_NGRAM = (); 64 | my $length_translation_this_sentence = scalar(@WORD); 65 | my ($closest_diff,$closest_length) = (9999,9999); 66 | foreach my $reference (@{$REF[$s]}) { 67 | # print "$s $_ <=> $reference\n"; 68 | $reference = lc($reference) if $lowercase; 69 | my @WORD = split(' ',$reference); 70 | my $length = scalar(@WORD); 71 | my $diff = abs($length_translation_this_sentence-$length); 72 | if ($diff < $closest_diff) { 73 | $closest_diff = $diff; 74 | $closest_length = $length; 75 | # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n"; 76 | } elsif ($diff == $closest_diff) { 77 | $closest_length = $length if $length < $closest_length; 78 | # from two references with the same closeness to me 79 | # take the *shorter* into account, not the "first" one. 80 | } 81 | for(my $n=1;$n<=4;$n++) { 82 | my %REF_NGRAM_N = (); 83 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 84 | my $ngram = "$n"; 85 | for(my $w=0;$w<$n;$w++) { 86 | $ngram .= " ".$WORD[$start+$w]; 87 | } 88 | $REF_NGRAM_N{$ngram}++; 89 | } 90 | foreach my $ngram (keys %REF_NGRAM_N) { 91 | if (!defined($REF_NGRAM{$ngram}) || 92 | $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) { 93 | $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram}; 94 | # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}
\n"; 95 | } 96 | } 97 | } 98 | } 99 | $length_translation += $length_translation_this_sentence; 100 | $length_reference += $closest_length; 101 | for(my $n=1;$n<=4;$n++) { 102 | my %T_NGRAM = (); 103 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 104 | my $ngram = "$n"; 105 | for(my $w=0;$w<$n;$w++) { 106 | $ngram .= " ".$WORD[$start+$w]; 107 | } 108 | $T_NGRAM{$ngram}++; 109 | } 110 | foreach my $ngram (keys %T_NGRAM) { 111 | $ngram =~ /^(\d+) /; 112 | my $n = $1; 113 | # my $corr = 0; 114 | # print "$i e $ngram $T_NGRAM{$ngram}
\n"; 115 | $TOTAL[$n] += $T_NGRAM{$ngram}; 116 | if (defined($REF_NGRAM{$ngram})) { 117 | if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) { 118 | $CORRECT[$n] += $T_NGRAM{$ngram}; 119 | # $corr = $T_NGRAM{$ngram}; 120 | # print "$i e correct1 $T_NGRAM{$ngram}
\n"; 121 | } 122 | else { 123 | $CORRECT[$n] += $REF_NGRAM{$ngram}; 124 | # $corr = $REF_NGRAM{$ngram}; 125 | # print "$i e correct2 $REF_NGRAM{$ngram}
\n"; 126 | } 127 | } 128 | # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram}; 129 | # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n" 130 | } 131 | } 132 | $s++; 133 | } 134 | my $brevity_penalty = 1; 135 | my $bleu = 0; 136 | 137 | my @bleu=(); 138 | 139 | for(my $n=1;$n<=4;$n++) { 140 | if (defined ($TOTAL[$n])){ 141 | $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0; 142 | # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n"; 143 | }else{ 144 | $bleu[$n]=0; 145 | } 146 | } 147 | 148 | if ($length_reference==0){ 149 | printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n"; 150 | exit(1); 151 | } 152 | 153 | if ($length_translation<$length_reference) { 154 | $brevity_penalty = exp(1-$length_reference/$length_translation); 155 | } 156 | $bleu = $brevity_penalty * exp((my_log( $bleu[1] ) + 157 | my_log( $bleu[2] ) + 158 | my_log( $bleu[3] ) + 159 | my_log( $bleu[4] ) ) / 4) ; 160 | printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n", 161 | 100*$bleu, 162 | 100*$bleu[1], 163 | 100*$bleu[2], 164 | 100*$bleu[3], 165 | 100*$bleu[4], 166 | $brevity_penalty, 167 | $length_translation / $length_reference, 168 | $length_translation, 169 | $length_reference; 170 | 171 | sub my_log { 172 | return -9999999999 unless $_[0]; 173 | return log($_[0]); 174 | } 175 | -------------------------------------------------------------------------------- /code/nmt.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Build a neural machine translation model with soft attention 3 | ''' 4 | 5 | ''' 6 | The main framework is shared from `dl4mt`. The proposed model `CAEncoder` is developped by Biao Zhang. 7 | For any quetion, welcome to contact me (zb@stu.xmu.edu.cn) 8 | ''' 9 | 10 | import theano 11 | import theano.tensor as tensor 12 | from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams 13 | 14 | import cPickle as pkl 15 | import ipdb 16 | import numpy 17 | import copy 18 | 19 | import os 20 | import warnings 21 | import sys 22 | import time 23 | 24 | from collections import OrderedDict 25 | 26 | from data import TextIterator 27 | 28 | profile = False 29 | numpy.random.seed(1234) 30 | 31 | from eval.evaluate_nist import eval_trans as eval_nist 32 | from eval.evaluate_moses import eval_trans as eval_moses 33 | run = os.system 34 | root_path = os.path.dirname(os.path.realpath(__file__)) 35 | 36 | # push parameters to Theano shared variables 37 | def zipp(params, tparams): 38 | for kk, vv in params.iteritems(): 39 | tparams[kk].set_value(vv) 40 | 41 | 42 | # pull parameters from Theano shared variables 43 | def unzip(zipped): 44 | new_params = OrderedDict() 45 | for kk, vv in zipped.iteritems(): 46 | new_params[kk] = vv.get_value() 47 | return new_params 48 | 49 | 50 | # get the list of parameters: Note that tparams must be OrderedDict 51 | def itemlist(tparams): 52 | return [vv for kk, vv in tparams.iteritems()] 53 | 54 | 55 | # dropout 56 | def dropout_layer(state_before, use_noise, trng): 57 | proj = tensor.switch( 58 | use_noise, 59 | state_before * trng.binomial(state_before.shape, p=0.5, n=1, 60 | dtype=state_before.dtype), 61 | state_before * 0.5) 62 | return proj 63 | 64 | 65 | # make prefix-appended name 66 | def _p(pp, name): 67 | return '%s_%s' % (pp, name) 68 | 69 | 70 | # initialize Theano shared variables according to the initial parameters 71 | def init_tparams(params): 72 | tparams = OrderedDict() 73 | for kk, pp in params.iteritems(): 74 | 
tparams[kk] = theano.shared(params[kk], name=kk) 75 | return tparams 76 | 77 | 78 | # load parameters 79 | def load_params(path, params): 80 | pp = numpy.load(path) 81 | for kk, vv in params.iteritems(): 82 | if kk not in pp: 83 | warnings.warn('%s is not in the archive' % kk) 84 | continue 85 | if vv.shape != pp[kk].shape: 86 | warnings.warn('%s is not in the same shap' % kk) 87 | continue 88 | params[kk] = pp[kk] 89 | 90 | return params 91 | 92 | # layers: 'name': ('parameter initializer', 'feedforward') 93 | layers = {'ff': ('param_init_fflayer', 'fflayer'), 94 | 'gru': ('param_init_gru', 'gru_layer'), 95 | 'gru_cond': ('param_init_gru_cond', 'gru_cond_layer'), 96 | 'gru_cae': ('param_init_gru_cae', 'gru_cae_layer'), 97 | } 98 | 99 | 100 | def get_layer(name): 101 | fns = layers[name] 102 | return (eval(fns[0]), eval(fns[1])) 103 | 104 | 105 | # some utilities 106 | def ortho_weight(ndim): 107 | W = numpy.random.randn(ndim, ndim) 108 | u, s, v = numpy.linalg.svd(W) 109 | return u.astype('float32') 110 | 111 | 112 | def norm_weight(nin, nout=None, scale=0.01, ortho=True): 113 | if nout is None: 114 | nout = nin 115 | if nout == nin and ortho: 116 | W = ortho_weight(nin) 117 | else: 118 | W = scale * numpy.random.randn(nin, nout) 119 | return W.astype('float32') 120 | 121 | 122 | def tanh(x): 123 | return tensor.tanh(x) 124 | 125 | 126 | def linear(x): 127 | return x 128 | 129 | 130 | def concatenate(tensor_list, axis=0): 131 | """ 132 | Alternative implementation of `theano.tensor.concatenate`. 133 | This function does exactly the same thing, but contrary to Theano's own 134 | implementation, the gradient is implemented on the GPU. 135 | Backpropagating through `theano.tensor.concatenate` yields slowdowns 136 | because the inverse operation (splitting) needs to be done on the CPU. 137 | This implementation does not have that problem. 138 | :usage: 139 | >>> x, y = theano.tensor.matrices('x', 'y') 140 | >>> c = concatenate([x, y], axis=1) 141 | :parameters: 142 | - tensor_list : list 143 | list of Theano tensor expressions that should be concatenated. 144 | - axis : int 145 | the tensors will be joined along this axis. 146 | :returns: 147 | - out : tensor 148 | the concatenated tensor expression. 
149 | """ 150 | concat_size = sum(tt.shape[axis] for tt in tensor_list) 151 | 152 | output_shape = () 153 | for k in range(axis): 154 | output_shape += (tensor_list[0].shape[k],) 155 | output_shape += (concat_size,) 156 | for k in range(axis + 1, tensor_list[0].ndim): 157 | output_shape += (tensor_list[0].shape[k],) 158 | 159 | out = tensor.zeros(output_shape) 160 | offset = 0 161 | for tt in tensor_list: 162 | indices = () 163 | for k in range(axis): 164 | indices += (slice(None),) 165 | indices += (slice(offset, offset + tt.shape[axis]),) 166 | for k in range(axis + 1, tensor_list[0].ndim): 167 | indices += (slice(None),) 168 | 169 | out = tensor.set_subtensor(out[indices], tt) 170 | offset += tt.shape[axis] 171 | 172 | return out 173 | 174 | 175 | # batch preparation 176 | def prepare_data(seqs_x, seqs_y, maxlen=None, n_words_src=30000, 177 | n_words=30000): 178 | # x: a list of sentences 179 | lengths_x = [len(s) for s in seqs_x] 180 | lengths_y = [len(s) for s in seqs_y] 181 | 182 | if maxlen is not None: 183 | new_seqs_x = [] 184 | new_seqs_y = [] 185 | new_lengths_x = [] 186 | new_lengths_y = [] 187 | for l_x, s_x, l_y, s_y in zip(lengths_x, seqs_x, lengths_y, seqs_y): 188 | if l_x < maxlen and l_y < maxlen: 189 | new_seqs_x.append(s_x) 190 | new_lengths_x.append(l_x) 191 | new_seqs_y.append(s_y) 192 | new_lengths_y.append(l_y) 193 | lengths_x = new_lengths_x 194 | seqs_x = new_seqs_x 195 | lengths_y = new_lengths_y 196 | seqs_y = new_seqs_y 197 | 198 | if len(lengths_x) < 1 or len(lengths_y) < 1: 199 | return None, None, None, None 200 | 201 | n_samples = len(seqs_x) 202 | maxlen_x = numpy.max(lengths_x) + 1 203 | maxlen_y = numpy.max(lengths_y) + 1 204 | 205 | x = numpy.zeros((maxlen_x, n_samples)).astype('int64') 206 | y = numpy.zeros((maxlen_y, n_samples)).astype('int64') 207 | x_mask = numpy.zeros((maxlen_x, n_samples)).astype('float32') 208 | y_mask = numpy.zeros((maxlen_y, n_samples)).astype('float32') 209 | for idx, [s_x, s_y] in enumerate(zip(seqs_x, seqs_y)): 210 | x[:lengths_x[idx], idx] = s_x 211 | x_mask[:lengths_x[idx]+1, idx] = 1. 212 | y[:lengths_y[idx], idx] = s_y 213 | y_mask[:lengths_y[idx]+1, idx] = 1. 
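    # Illustrative note (not in the original source): each column of x / y holds
    # one sentence padded with the end-of-sentence id 0. For a 3-word source in a
    # batch whose longest source has 5 words (maxlen_x = 6), the column of x is
    # [w0, w1, w2, 0, 0, 0] and the matching x_mask column is [1, 1, 1, 1, 0, 0];
    # the extra 1 covers the implicit end-of-sentence position.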
214 | 215 | return x, x_mask, y, y_mask 216 | 217 | 218 | # feedforward layer: affine transformation + point-wise nonlinearity 219 | def param_init_fflayer(options, params, prefix='ff', nin=None, nout=None, 220 | ortho=True): 221 | if nin is None: 222 | nin = options['dim_proj'] 223 | if nout is None: 224 | nout = options['dim_proj'] 225 | params[_p(prefix, 'W')] = norm_weight(nin, nout, scale=0.01, ortho=ortho) 226 | params[_p(prefix, 'b')] = numpy.zeros((nout,)).astype('float32') 227 | 228 | return params 229 | 230 | 231 | def fflayer(tparams, state_below, options, prefix='rconv', 232 | activ='lambda x: tensor.tanh(x)', **kwargs): 233 | return eval(activ)( 234 | tensor.dot(state_below, tparams[_p(prefix, 'W')]) + 235 | tparams[_p(prefix, 'b')]) 236 | 237 | 238 | # GRU layer 239 | def param_init_gru(options, params, prefix='gru', nin=None, dim=None): 240 | if nin is None: 241 | nin = options['dim_proj'] 242 | if dim is None: 243 | dim = options['dim_proj'] 244 | 245 | # embedding to gates transformation weights, biases 246 | W = numpy.concatenate([norm_weight(nin, dim), 247 | norm_weight(nin, dim)], axis=1) 248 | params[_p(prefix, 'W')] = W 249 | params[_p(prefix, 'b')] = numpy.zeros((2 * dim,)).astype('float32') 250 | 251 | # recurrent transformation weights for gates 252 | U = numpy.concatenate([ortho_weight(dim), 253 | ortho_weight(dim)], axis=1) 254 | params[_p(prefix, 'U')] = U 255 | 256 | # embedding to hidden state proposal weights, biases 257 | Wx = norm_weight(nin, dim) 258 | params[_p(prefix, 'Wx')] = Wx 259 | params[_p(prefix, 'bx')] = numpy.zeros((dim,)).astype('float32') 260 | 261 | # recurrent transformation weights for hidden state proposal 262 | Ux = ortho_weight(dim) 263 | params[_p(prefix, 'Ux')] = Ux 264 | 265 | return params 266 | 267 | 268 | def gru_layer(tparams, state_below, options, prefix='gru', mask=None, 269 | **kwargs): 270 | nsteps = state_below.shape[0] 271 | if state_below.ndim == 3: 272 | n_samples = state_below.shape[1] 273 | else: 274 | n_samples = 1 275 | 276 | dim = tparams[_p(prefix, 'Ux')].shape[1] 277 | 278 | if mask is None: 279 | mask = tensor.alloc(1., state_below.shape[0], 1) 280 | 281 | # utility function to slice a tensor 282 | def _slice(_x, n, dim): 283 | if _x.ndim == 3: 284 | return _x[:, :, n*dim:(n+1)*dim] 285 | return _x[:, n*dim:(n+1)*dim] 286 | 287 | # state_below is the input word embeddings 288 | # input to the gates, concatenated 289 | state_below_ = tensor.dot(state_below, tparams[_p(prefix, 'W')]) + \ 290 | tparams[_p(prefix, 'b')] 291 | # input to compute the hidden state proposal 292 | state_belowx = tensor.dot(state_below, tparams[_p(prefix, 'Wx')]) + \ 293 | tparams[_p(prefix, 'bx')] 294 | 295 | # step function to be used by scan 296 | # arguments | sequences |outputs-info| non-seqs 297 | def _step_slice(m_, x_, xx_, h_, U, Ux): 298 | preact = tensor.dot(h_, U) 299 | preact += x_ 300 | 301 | # reset and update gates 302 | r = tensor.nnet.sigmoid(_slice(preact, 0, dim)) 303 | u = tensor.nnet.sigmoid(_slice(preact, 1, dim)) 304 | 305 | # compute the hidden state proposal 306 | preactx = tensor.dot(h_, Ux) 307 | preactx = preactx * r 308 | preactx = preactx + xx_ 309 | 310 | # hidden state proposal 311 | h = tensor.tanh(preactx) 312 | 313 | # leaky integrate and obtain next hidden state 314 | h = u * h_ + (1. - u) * h 315 | h = m_[:, None] * h + (1. 
- m_)[:, None] * h_ 316 | 317 | return h 318 | 319 | # prepare scan arguments 320 | seqs = [mask, state_below_, state_belowx] 321 | init_states = [tensor.alloc(0., n_samples, dim)] 322 | _step = _step_slice 323 | shared_vars = [tparams[_p(prefix, 'U')], 324 | tparams[_p(prefix, 'Ux')]] 325 | 326 | rval, updates = theano.scan(_step, 327 | sequences=seqs, 328 | outputs_info=init_states, 329 | non_sequences=shared_vars, 330 | name=_p(prefix, '_layers'), 331 | n_steps=nsteps, 332 | profile=profile, 333 | strict=True) 334 | rval = [rval] 335 | return rval 336 | 337 | # GRU Deep layer 338 | def param_init_gru_cae(options, params, prefix='gru', nin=None, dim=None, dimctx=None): 339 | if nin is None: 340 | nin = options['dim_proj'] 341 | if dim is None: 342 | dim = options['dim_proj'] 343 | if dimctx is None: 344 | dimctx = options['dim_proj'] 345 | 346 | # embedding to gates transformation weights, biases 347 | W = numpy.concatenate([norm_weight(nin, dim), 348 | norm_weight(nin, dim)], axis=1) 349 | params[_p(prefix, 'W')] = W 350 | params[_p(prefix, 'b')] = numpy.zeros((2 * dim,)).astype('float32') 351 | 352 | # recurrent transformation weights for gates 353 | U = numpy.concatenate([ortho_weight(dim), 354 | ortho_weight(dim)], axis=1) 355 | params[_p(prefix, 'U')] = U 356 | 357 | # embedding to hidden state proposal weights, biases 358 | Wx = norm_weight(nin, dim) 359 | params[_p(prefix, 'Wx')] = Wx 360 | params[_p(prefix, 'bx')] = numpy.zeros((dim,)).astype('float32') 361 | 362 | # recurrent transformation weights for hidden state proposal 363 | Ux = ortho_weight(dim) 364 | params[_p(prefix, 'Ux')] = Ux 365 | 366 | U_nl = numpy.concatenate([ortho_weight(dim), 367 | ortho_weight(dim)], axis=1) 368 | params[_p(prefix, 'U_nl')] = U_nl 369 | params[_p(prefix, 'b_nl')] = numpy.zeros((2 * dim,)).astype('float32') 370 | 371 | Ux_nl = ortho_weight(dim) 372 | params[_p(prefix, 'Ux_nl')] = Ux_nl 373 | params[_p(prefix, 'bx_nl')] = numpy.zeros((dim,)).astype('float32') 374 | 375 | # context to LSTM 376 | Wc = norm_weight(dimctx, dim*2) 377 | params[_p(prefix, 'Wc')] = Wc 378 | 379 | Wcx = norm_weight(dimctx, dim) 380 | params[_p(prefix, 'Wcx')] = Wcx 381 | 382 | return params 383 | 384 | 385 | def gru_cae_layer(tparams, state_below, options, prefix='gru', mask=None, context_below=None, 386 | **kwargs): 387 | nsteps = state_below.shape[0] 388 | if state_below.ndim == 3: 389 | if context_below: 390 | assert context_below.ndim == 3, 'Hi, context should have the same dimension as states' 391 | n_samples = state_below.shape[1] 392 | else: 393 | n_samples = 1 394 | if context_below is None: 395 | context_below = tensor.alloc(0, nsteps, n_samples, tparams[_p(prefix, 'Wc')].shape[0]) 396 | 397 | dim = tparams[_p(prefix, 'Ux')].shape[1] 398 | 399 | if mask is None: 400 | mask = tensor.alloc(1., state_below.shape[0], 1) 401 | 402 | # utility function to slice a tensor 403 | def _slice(_x, n, dim): 404 | if _x.ndim == 3: 405 | return _x[:, :, n*dim:(n+1)*dim] 406 | return _x[:, n*dim:(n+1)*dim] 407 | 408 | # state_below is the input word embeddings 409 | # input to the gates, concatenated 410 | state_below_ = tensor.dot(state_below, tparams[_p(prefix, 'W')]) + \ 411 | tparams[_p(prefix, 'b')] 412 | # input to compute the hidden state proposal 413 | state_belowx = tensor.dot(state_below, tparams[_p(prefix, 'Wx')]) + \ 414 | tparams[_p(prefix, 'bx')] 415 | 416 | # step function to be used by scan 417 | # arguments | sequences |outputs-info| non-seqs 418 | def _step_slice(m_, x_, xx_, ctx_, h_, U, Ux, Wc, Wcx, U_nl, 
Ux_nl, b_nl, bx_nl): 419 | preact = tensor.dot(h_, U) 420 | preact += x_ 421 | 422 | # reset and update gates 423 | r = tensor.nnet.sigmoid(_slice(preact, 0, dim)) 424 | u = tensor.nnet.sigmoid(_slice(preact, 1, dim)) 425 | 426 | # compute the hidden state proposal 427 | preactx = tensor.dot(h_, Ux) 428 | preactx = preactx * r 429 | preactx = preactx + xx_ 430 | 431 | # hidden state proposal 432 | h = tensor.tanh(preactx) 433 | 434 | # leaky integrate and obtain next hidden state 435 | h = u * h_ + (1. - u) * h 436 | h1 = m_[:, None] * h + (1. - m_)[:, None] * h_ 437 | 438 | preact2 = tensor.dot(h1, U_nl)+b_nl 439 | preact2 += tensor.dot(ctx_, Wc) 440 | preact2 = tensor.nnet.sigmoid(preact2) 441 | 442 | r2 = _slice(preact2, 0, dim) 443 | u2 = _slice(preact2, 1, dim) 444 | 445 | preactx2 = tensor.dot(h1, Ux_nl)+bx_nl 446 | preactx2 *= r2 447 | preactx2 += tensor.dot(ctx_, Wcx) 448 | 449 | h2 = tensor.tanh(preactx2) 450 | 451 | h2 = u2 * h1 + (1. - u2) * h2 452 | h2 = m_[:, None] * h2 + (1. - m_)[:, None] * h1 453 | 454 | return h2 455 | 456 | # prepare scan arguments 457 | seqs = [mask, state_below_, state_belowx, context_below] 458 | init_states = [tensor.alloc(0., n_samples, dim)] 459 | _step = _step_slice 460 | shared_vars = [tparams[_p(prefix, 'U')], 461 | tparams[_p(prefix, 'Ux')], 462 | tparams[_p(prefix, 'Wc')], 463 | tparams[_p(prefix, 'Wcx')], 464 | tparams[_p(prefix, 'U_nl')], 465 | tparams[_p(prefix, 'Ux_nl')], 466 | tparams[_p(prefix, 'b_nl')], 467 | tparams[_p(prefix, 'bx_nl')]] 468 | 469 | rval, updates = theano.scan(_step, 470 | sequences=seqs, 471 | outputs_info=init_states, 472 | non_sequences=shared_vars, 473 | name=_p(prefix, '_layers'), 474 | n_steps=nsteps, 475 | profile=profile, 476 | strict=True) 477 | rval = [rval] 478 | return rval 479 | 480 | 481 | # Conditional GRU layer with Attention 482 | def param_init_gru_cond(options, params, prefix='gru_cond', 483 | nin=None, dim=None, dimctx=None, 484 | nin_nonlin=None, dim_nonlin=None): 485 | if nin is None: 486 | nin = options['dim'] 487 | if dim is None: 488 | dim = options['dim'] 489 | if dimctx is None: 490 | dimctx = options['dim'] 491 | if nin_nonlin is None: 492 | nin_nonlin = nin 493 | if dim_nonlin is None: 494 | dim_nonlin = dim 495 | 496 | W = numpy.concatenate([norm_weight(nin, dim), 497 | norm_weight(nin, dim)], axis=1) 498 | params[_p(prefix, 'W')] = W 499 | params[_p(prefix, 'b')] = numpy.zeros((2 * dim,)).astype('float32') 500 | U = numpy.concatenate([ortho_weight(dim_nonlin), 501 | ortho_weight(dim_nonlin)], axis=1) 502 | params[_p(prefix, 'U')] = U 503 | 504 | Wx = norm_weight(nin_nonlin, dim_nonlin) 505 | params[_p(prefix, 'Wx')] = Wx 506 | Ux = ortho_weight(dim_nonlin) 507 | params[_p(prefix, 'Ux')] = Ux 508 | params[_p(prefix, 'bx')] = numpy.zeros((dim_nonlin,)).astype('float32') 509 | 510 | U_nl = numpy.concatenate([ortho_weight(dim_nonlin), 511 | ortho_weight(dim_nonlin)], axis=1) 512 | params[_p(prefix, 'U_nl')] = U_nl 513 | params[_p(prefix, 'b_nl')] = numpy.zeros((2 * dim_nonlin,)).astype('float32') 514 | 515 | Ux_nl = ortho_weight(dim_nonlin) 516 | params[_p(prefix, 'Ux_nl')] = Ux_nl 517 | params[_p(prefix, 'bx_nl')] = numpy.zeros((dim_nonlin,)).astype('float32') 518 | 519 | # context to LSTM 520 | Wc = norm_weight(dimctx, dim*2) 521 | params[_p(prefix, 'Wc')] = Wc 522 | 523 | Wcx = norm_weight(dimctx, dim) 524 | params[_p(prefix, 'Wcx')] = Wcx 525 | 526 | # attention: combined -> hidden 527 | W_comb_att = norm_weight(dim, dimctx) 528 | params[_p(prefix, 'W_comb_att')] = W_comb_att 529 | 530 | # 
attention: context -> hidden 531 | Wc_att = norm_weight(dimctx) 532 | params[_p(prefix, 'Wc_att')] = Wc_att 533 | 534 | # attention: hidden bias 535 | b_att = numpy.zeros((dimctx,)).astype('float32') 536 | params[_p(prefix, 'b_att')] = b_att 537 | 538 | # attention: 539 | U_att = norm_weight(dimctx, 1) 540 | params[_p(prefix, 'U_att')] = U_att 541 | c_att = numpy.zeros((1,)).astype('float32') 542 | params[_p(prefix, 'c_tt')] = c_att 543 | 544 | return params 545 | 546 | 547 | def gru_cond_layer(tparams, state_below, options, prefix='gru', 548 | mask=None, context=None, one_step=False, 549 | init_memory=None, init_state=None, 550 | context_mask=None, 551 | **kwargs): 552 | 553 | assert context, 'Context must be provided' 554 | 555 | if one_step: 556 | assert init_state, 'previous state must be provided' 557 | 558 | nsteps = state_below.shape[0] 559 | if state_below.ndim == 3: 560 | n_samples = state_below.shape[1] 561 | else: 562 | n_samples = 1 563 | 564 | # mask 565 | if mask is None: 566 | mask = tensor.alloc(1., state_below.shape[0], 1) 567 | 568 | dim = tparams[_p(prefix, 'Wcx')].shape[1] 569 | 570 | # initial/previous state 571 | if init_state is None: 572 | init_state = tensor.alloc(0., n_samples, dim) 573 | 574 | # projected context 575 | assert context.ndim == 3, \ 576 | 'Context must be 3-d: #annotation x #sample x dim' 577 | pctx_ = tensor.dot(context, tparams[_p(prefix, 'Wc_att')]) +\ 578 | tparams[_p(prefix, 'b_att')] 579 | 580 | def _slice(_x, n, dim): 581 | if _x.ndim == 3: 582 | return _x[:, :, n*dim:(n+1)*dim] 583 | return _x[:, n*dim:(n+1)*dim] 584 | 585 | # projected x 586 | state_belowx = tensor.dot(state_below, tparams[_p(prefix, 'Wx')]) +\ 587 | tparams[_p(prefix, 'bx')] 588 | state_below_ = tensor.dot(state_below, tparams[_p(prefix, 'W')]) +\ 589 | tparams[_p(prefix, 'b')] 590 | 591 | def _step_slice(m_, x_, xx_, h_, ctx_, alpha_, pctx_, cc_, 592 | U, Wc, W_comb_att, U_att, c_tt, Ux, Wcx, 593 | U_nl, Ux_nl, b_nl, bx_nl): 594 | preact1 = tensor.dot(h_, U) 595 | preact1 += x_ 596 | preact1 = tensor.nnet.sigmoid(preact1) 597 | 598 | r1 = _slice(preact1, 0, dim) 599 | u1 = _slice(preact1, 1, dim) 600 | 601 | preactx1 = tensor.dot(h_, Ux) 602 | preactx1 *= r1 603 | preactx1 += xx_ 604 | 605 | h1 = tensor.tanh(preactx1) 606 | 607 | h1 = u1 * h_ + (1. - u1) * h1 608 | h1 = m_[:, None] * h1 + (1. - m_)[:, None] * h_ 609 | 610 | # attention 611 | pstate_ = tensor.dot(h1, W_comb_att) 612 | pctx__ = pctx_ + pstate_[None, :, :] 613 | #pctx__ += xc_ 614 | pctx__ = tensor.tanh(pctx__) 615 | alpha = tensor.dot(pctx__, U_att)+c_tt 616 | alpha = alpha.reshape([alpha.shape[0], alpha.shape[1]]) 617 | alpha = tensor.exp(alpha) 618 | if context_mask: 619 | alpha = alpha * context_mask 620 | alpha = alpha / alpha.sum(0, keepdims=True) 621 | ctx_ = (cc_ * alpha[:, :, None]).sum(0) # current context 622 | 623 | preact2 = tensor.dot(h1, U_nl)+b_nl 624 | preact2 += tensor.dot(ctx_, Wc) 625 | preact2 = tensor.nnet.sigmoid(preact2) 626 | 627 | r2 = _slice(preact2, 0, dim) 628 | u2 = _slice(preact2, 1, dim) 629 | 630 | preactx2 = tensor.dot(h1, Ux_nl)+bx_nl 631 | preactx2 *= r2 632 | preactx2 += tensor.dot(ctx_, Wcx) 633 | 634 | h2 = tensor.tanh(preactx2) 635 | 636 | h2 = u2 * h1 + (1. - u2) * h2 637 | h2 = m_[:, None] * h2 + (1. 
- m_)[:, None] * h1 638 | 639 | return h2, ctx_, alpha.T # pstate_, preact, preactx, r, u 640 | 641 | seqs = [mask, state_below_, state_belowx] 642 | #seqs = [mask, state_below_, state_belowx, state_belowc] 643 | _step = _step_slice 644 | 645 | shared_vars = [tparams[_p(prefix, 'U')], 646 | tparams[_p(prefix, 'Wc')], 647 | tparams[_p(prefix, 'W_comb_att')], 648 | tparams[_p(prefix, 'U_att')], 649 | tparams[_p(prefix, 'c_tt')], 650 | tparams[_p(prefix, 'Ux')], 651 | tparams[_p(prefix, 'Wcx')], 652 | tparams[_p(prefix, 'U_nl')], 653 | tparams[_p(prefix, 'Ux_nl')], 654 | tparams[_p(prefix, 'b_nl')], 655 | tparams[_p(prefix, 'bx_nl')]] 656 | 657 | if one_step: 658 | rval = _step(*(seqs + [init_state, None, None, pctx_, context] + 659 | shared_vars)) 660 | else: 661 | rval, updates = theano.scan(_step, 662 | sequences=seqs, 663 | outputs_info=[init_state, 664 | tensor.alloc(0., n_samples, 665 | context.shape[2]), 666 | tensor.alloc(0., n_samples, 667 | context.shape[0])], 668 | non_sequences=[pctx_, context]+shared_vars, 669 | name=_p(prefix, '_layers'), 670 | n_steps=nsteps, 671 | profile=profile, 672 | strict=True) 673 | return rval 674 | 675 | 676 | # initialize all parameters 677 | def init_params(options): 678 | params = OrderedDict() 679 | 680 | # embedding 681 | params['Wemb'] = norm_weight(options['n_words_src'], options['dim_word']) 682 | params['Wemb_dec'] = norm_weight(options['n_words'], options['dim_word']) 683 | 684 | # encoder: bidirectional RNN 685 | params = get_layer('gru')[0](options, params, 686 | prefix='encoder', 687 | nin=options['dim_word'], 688 | dim=options['dim']) 689 | params = get_layer('gru_cae')[0](options, params, 690 | prefix='encoder_r', 691 | nin=options['dim_word'], 692 | dim=options['dim'], 693 | dimctx=options['dim']) 694 | ctxdim = options['dim'] 695 | 696 | # init_state, init_cell 697 | params = get_layer('ff')[0](options, params, prefix='ff_state', 698 | nin=ctxdim, nout=options['dim']) 699 | # decoder 700 | params = get_layer(options['decoder'])[0](options, params, 701 | prefix='decoder', 702 | nin=options['dim_word'], 703 | dim=options['dim'], 704 | dimctx=ctxdim) 705 | # readout 706 | params = get_layer('ff')[0](options, params, prefix='ff_logit_lstm', 707 | nin=options['dim'], nout=options['dim_word'], 708 | ortho=False) 709 | params = get_layer('ff')[0](options, params, prefix='ff_logit_prev', 710 | nin=options['dim_word'], 711 | nout=options['dim_word'], ortho=False) 712 | params = get_layer('ff')[0](options, params, prefix='ff_logit_ctx', 713 | nin=ctxdim, nout=options['dim_word'], 714 | ortho=False) 715 | params = get_layer('ff')[0](options, params, prefix='ff_logit', 716 | nin=options['dim_word'], 717 | nout=options['n_words']) 718 | 719 | return params 720 | 721 | 722 | # build a training model 723 | def build_model(tparams, options): 724 | opt_ret = dict() 725 | 726 | trng = RandomStreams(1234) 727 | use_noise = theano.shared(numpy.float32(0.)) 728 | 729 | # description string: #words x #samples 730 | x = tensor.matrix('x', dtype='int64') 731 | x_mask = tensor.matrix('x_mask', dtype='float32') 732 | y = tensor.matrix('y', dtype='int64') 733 | y_mask = tensor.matrix('y_mask', dtype='float32') 734 | 735 | # for the backward rnn, we just need to invert x and x_mask 736 | xr = x[::-1] 737 | xr_mask = x_mask[::-1] 738 | 739 | n_timesteps = x.shape[0] 740 | n_timesteps_trg = y.shape[0] 741 | n_samples = x.shape[1] 742 | 743 | # word embedding for forward rnn (source) 744 | emb = tparams['Wemb'][x.flatten()] 745 | emb = emb.reshape([n_timesteps, 
n_samples, options['dim_word']]) 746 | proj = get_layer('gru')[1](tparams, emb, options, 747 | prefix='encoder', 748 | mask=x_mask) 749 | # word embedding for backward rnn (source) 750 | embr = tparams['Wemb'][xr.flatten()] 751 | embr = embr.reshape([n_timesteps, n_samples, options['dim_word']]) 752 | projr = get_layer('gru_cae')[1](tparams, embr, options, 753 | prefix='encoder_r', 754 | context_below=proj[0][::-1], 755 | mask=xr_mask) 756 | 757 | # context will be the concatenation of forward and backward rnns 758 | ctx = projr[0][::-1] ## concatenate([proj[0], projr[0][::-1]], axis=proj[0].ndim-1) 759 | 760 | # mean of the context (across time) will be used to initialize decoder rnn 761 | ctx_mean = (ctx * x_mask[:, :, None]).sum(0) / x_mask.sum(0)[:, None] 762 | 763 | # or you can use the last state of forward + backward encoder rnns 764 | # ctx_mean = concatenate([proj[0][-1], projr[0][-1]], axis=proj[0].ndim-2) 765 | 766 | # initial decoder state 767 | init_state = get_layer('ff')[1](tparams, ctx_mean, options, 768 | prefix='ff_state', activ='tanh') 769 | 770 | # word embedding (target), we will shift the target sequence one time step 771 | # to the right. This is done because of the bi-gram connections in the 772 | # readout and decoder rnn. The first target will be all zeros and we will 773 | # not condition on the last output. 774 | emb = tparams['Wemb_dec'][y.flatten()] 775 | emb = emb.reshape([n_timesteps_trg, n_samples, options['dim_word']]) 776 | emb_shifted = tensor.zeros_like(emb) 777 | emb_shifted = tensor.set_subtensor(emb_shifted[1:], emb[:-1]) 778 | emb = emb_shifted 779 | 780 | # decoder - pass through the decoder conditional gru with attention 781 | proj = get_layer(options['decoder'])[1](tparams, emb, options, 782 | prefix='decoder', 783 | mask=y_mask, context=ctx, 784 | context_mask=x_mask, 785 | one_step=False, 786 | init_state=init_state) 787 | # hidden states of the decoder gru 788 | proj_h = proj[0] 789 | 790 | # weighted averages of context, generated by attention module 791 | ctxs = proj[1] 792 | 793 | # weights (alignment matrix) 794 | opt_ret['dec_alphas'] = proj[2] 795 | 796 | # compute word probabilities 797 | logit_lstm = get_layer('ff')[1](tparams, proj_h, options, 798 | prefix='ff_logit_lstm', activ='linear') 799 | logit_prev = get_layer('ff')[1](tparams, emb, options, 800 | prefix='ff_logit_prev', activ='linear') 801 | logit_ctx = get_layer('ff')[1](tparams, ctxs, options, 802 | prefix='ff_logit_ctx', activ='linear') 803 | logit = tensor.tanh(logit_lstm+logit_prev+logit_ctx) 804 | if options['use_dropout']: 805 | logit = dropout_layer(logit, use_noise, trng) 806 | logit = get_layer('ff')[1](tparams, logit, options, 807 | prefix='ff_logit', activ='linear') 808 | logit_shp = logit.shape 809 | probs = tensor.nnet.softmax(logit.reshape([logit_shp[0]*logit_shp[1], 810 | logit_shp[2]])) 811 | 812 | # cost 813 | y_flat = y.flatten() 814 | y_flat_idx = tensor.arange(y_flat.shape[0]) * options['n_words'] + y_flat 815 | cost = -tensor.log(probs.flatten()[y_flat_idx]) 816 | cost = cost.reshape([y.shape[0], y.shape[1]]) 817 | cost = (cost * y_mask).sum(0) 818 | 819 | return trng, use_noise, x, x_mask, y, y_mask, opt_ret, cost 820 | 821 | 822 | # build a sampler 823 | def build_sampler(tparams, options, trng, use_noise): 824 | x = tensor.matrix('x', dtype='int64') 825 | xr = x[::-1] 826 | n_timesteps = x.shape[0] 827 | n_samples = x.shape[1] 828 | 829 | # word embedding (source), forward and backward 830 | emb = tparams['Wemb'][x.flatten()] 831 | emb = 
emb.reshape([n_timesteps, n_samples, options['dim_word']]) 832 | embr = tparams['Wemb'][xr.flatten()] 833 | embr = embr.reshape([n_timesteps, n_samples, options['dim_word']]) 834 | 835 | # encoder 836 | proj = get_layer('gru')[1](tparams, emb, options, 837 | prefix='encoder') 838 | projr = get_layer('gru_cae')[1](tparams, embr, options, 839 | prefix='encoder_r', 840 | context_below=proj[0][::-1]) 841 | 842 | # concatenate forward and backward rnn hidden states 843 | ctx = projr[0][::-1] ##concatenate([proj[0], projr[0][::-1]], axis=proj[0].ndim-1) 844 | 845 | # get the input for decoder rnn initializer mlp 846 | ctx_mean = ctx.mean(0) 847 | # ctx_mean = concatenate([proj[0][-1],projr[0][-1]], axis=proj[0].ndim-2) 848 | init_state = get_layer('ff')[1](tparams, ctx_mean, options, 849 | prefix='ff_state', activ='tanh') 850 | 851 | print 'Building f_init...', 852 | outs = [init_state, ctx] 853 | f_init = theano.function([x], outs, name='f_init', profile=profile) 854 | print 'Done' 855 | 856 | # x: 1 x 1 857 | y = tensor.vector('y_sampler', dtype='int64') 858 | init_state = tensor.matrix('init_state', dtype='float32') 859 | 860 | # if it's the first word, emb should be all zero and it is indicated by -1 861 | emb = tensor.switch(y[:, None] < 0, 862 | tensor.alloc(0., 1, tparams['Wemb_dec'].shape[1]), 863 | tparams['Wemb_dec'][y]) 864 | 865 | # apply one step of conditional gru with attention 866 | proj = get_layer(options['decoder'])[1](tparams, emb, options, 867 | prefix='decoder', 868 | mask=None, context=ctx, 869 | one_step=True, 870 | init_state=init_state) 871 | # get the next hidden state 872 | next_state = proj[0] 873 | 874 | # get the weighted averages of context for this target word y 875 | ctxs = proj[1] 876 | 877 | logit_lstm = get_layer('ff')[1](tparams, next_state, options, 878 | prefix='ff_logit_lstm', activ='linear') 879 | logit_prev = get_layer('ff')[1](tparams, emb, options, 880 | prefix='ff_logit_prev', activ='linear') 881 | logit_ctx = get_layer('ff')[1](tparams, ctxs, options, 882 | prefix='ff_logit_ctx', activ='linear') 883 | logit = tensor.tanh(logit_lstm+logit_prev+logit_ctx) 884 | if options['use_dropout']: 885 | logit = dropout_layer(logit, use_noise, trng) 886 | logit = get_layer('ff')[1](tparams, logit, options, 887 | prefix='ff_logit', activ='linear') 888 | 889 | # compute the softmax probability 890 | next_probs = tensor.nnet.softmax(logit) 891 | 892 | # sample from softmax distribution to get the sample 893 | next_sample = trng.multinomial(pvals=next_probs).argmax(1) 894 | 895 | # compile a function to do the whole thing above, next word probability, 896 | # sampled word for the next target, next hidden state to be used 897 | print 'Building f_next..', 898 | inps = [y, ctx, init_state] 899 | outs = [next_probs, next_sample, next_state] 900 | f_next = theano.function(inps, outs, name='f_next', profile=profile) 901 | print 'Done' 902 | 903 | return f_init, f_next 904 | 905 | 906 | # generate sample, either with stochastic sampling or beam search. Note that, 907 | # this function iteratively calls f_init and f_next functions. 
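# A rough sketch of the beam-search bookkeeping used below (stochastic=False):
# hyp_samples / hyp_scores / hyp_states hold the live_k unfinished hypotheses.
# At every step the accumulated negative log-probabilities are expanded over
# the whole target vocabulary and only the best (k - dead_k) continuations
# survive, i.e.
#     cand_scores = hyp_scores[:, None] - numpy.log(next_p)   # (live_k, n_words)
#     ranks_flat  = cand_scores.flatten().argsort()[:k - dead_k]
# A hypothesis that has just emitted token 0 (the end-of-sentence id) is moved
# to sample / sample_score and dead_k grows; the loop stops once no live
# hypothesis remains or dead_k >= k, and any leftovers are flushed at the end.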
908 | def gen_sample(tparams, f_init, f_next, x, options, trng=None, k=1, maxlen=30, 909 | stochastic=True, argmax=False): 910 | 911 | # k is the beam size we have 912 | if k > 1: 913 | assert not stochastic, \ 914 | 'Beam search does not support stochastic sampling' 915 | 916 | sample = [] 917 | sample_score = [] 918 | if stochastic: 919 | sample_score = 0 920 | 921 | live_k = 1 922 | dead_k = 0 923 | 924 | hyp_samples = [[]] * live_k 925 | hyp_scores = numpy.zeros(live_k).astype('float32') 926 | hyp_states = [] 927 | 928 | # get initial state of decoder rnn and encoder context 929 | ret = f_init(x) 930 | next_state, ctx0 = ret[0], ret[1] 931 | next_w = -1 * numpy.ones((1,)).astype('int64') # bos indicator 932 | 933 | for ii in xrange(maxlen): 934 | ctx = numpy.tile(ctx0, [live_k, 1]) 935 | inps = [next_w, ctx, next_state] 936 | ret = f_next(*inps) 937 | next_p, next_w, next_state = ret[0], ret[1], ret[2] 938 | 939 | if stochastic: 940 | if argmax: 941 | nw = next_p[0].argmax() 942 | else: 943 | nw = next_w[0] 944 | sample.append(nw) 945 | sample_score -= numpy.log(next_p[0, nw]) 946 | if nw == 0: 947 | break 948 | else: 949 | cand_scores = hyp_scores[:, None] - numpy.log(next_p) 950 | cand_flat = cand_scores.flatten() 951 | ranks_flat = cand_flat.argsort()[:(k-dead_k)] 952 | 953 | voc_size = next_p.shape[1] 954 | trans_indices = ranks_flat / voc_size 955 | word_indices = ranks_flat % voc_size 956 | costs = cand_flat[ranks_flat] 957 | 958 | new_hyp_samples = [] 959 | new_hyp_scores = numpy.zeros(k-dead_k).astype('float32') 960 | new_hyp_states = [] 961 | 962 | for idx, [ti, wi] in enumerate(zip(trans_indices, word_indices)): 963 | new_hyp_samples.append(hyp_samples[ti]+[wi]) 964 | new_hyp_scores[idx] = copy.copy(costs[idx]) 965 | new_hyp_states.append(copy.copy(next_state[ti])) 966 | 967 | # check the finished samples 968 | new_live_k = 0 969 | hyp_samples = [] 970 | hyp_scores = [] 971 | hyp_states = [] 972 | 973 | for idx in xrange(len(new_hyp_samples)): 974 | if new_hyp_samples[idx][-1] == 0: 975 | sample.append(new_hyp_samples[idx]) 976 | sample_score.append(new_hyp_scores[idx]) 977 | dead_k += 1 978 | else: 979 | new_live_k += 1 980 | hyp_samples.append(new_hyp_samples[idx]) 981 | hyp_scores.append(new_hyp_scores[idx]) 982 | hyp_states.append(new_hyp_states[idx]) 983 | hyp_scores = numpy.array(hyp_scores) 984 | live_k = new_live_k 985 | 986 | if new_live_k < 1: 987 | break 988 | if dead_k >= k: 989 | break 990 | 991 | next_w = numpy.array([w[-1] for w in hyp_samples]) 992 | next_state = numpy.array(hyp_states) 993 | 994 | if not stochastic: 995 | # dump every remaining one 996 | if live_k > 0: 997 | for idx in xrange(live_k): 998 | sample.append(hyp_samples[idx]) 999 | sample_score.append(hyp_scores[idx]) 1000 | 1001 | return sample, sample_score 1002 | 1003 | # calculate the BLEU probabilities on a given corpus using translation model 1004 | def pred_bleus(source, refer, dic_src, dic_tgt, model, saveto="hypo.trans.plain", 1005 | k_=5, n_=False, c_=False, b_=1, is_nist=True): 1006 | # step 1. run the translation script 1007 | cmd = "python %s/sample.py -b %d -k %d" % (root_path, b_, k_) 1008 | if n_: 1009 | cmd += ' -n' 1010 | if c_: 1011 | cmd += ' -c' 1012 | cmd += ' %s %s %s %s %s' % (model, dic_src, dic_tgt, source, saveto+".bpe") 1013 | print cmd 1014 | run(cmd) 1015 | 1016 | cmd = 'sed "s/@@ //g" < %s > %s' % (saveto+".bpe", saveto) 1017 | print cmd 1018 | run(cmd) 1019 | 1020 | # step 2. 
beginning evaluation 1021 | if is_nist: 1022 | dev_bleu = eval_nist(source, refer, saveto) 1023 | else: 1024 | dev_bleu = eval_moses(source, refer, saveto) 1025 | 1026 | return -1*dev_bleu 1027 | 1028 | # calculate the log probablities on a given corpus using translation model 1029 | def pred_probs(f_log_probs, prepare_data, options, iterator, verbose=True): 1030 | probs = [] 1031 | 1032 | n_done = 0 1033 | 1034 | for x, y in iterator: 1035 | n_done += len(x) 1036 | 1037 | x, x_mask, y, y_mask = prepare_data(x, y, 1038 | n_words_src=options['n_words_src'], 1039 | n_words=options['n_words']) 1040 | 1041 | pprobs = f_log_probs(x, x_mask, y, y_mask) 1042 | for pp in pprobs: 1043 | probs.append(pp) 1044 | 1045 | if numpy.isnan(numpy.mean(probs)): 1046 | ipdb.set_trace() 1047 | 1048 | if verbose: 1049 | print >>sys.stderr, '%d samples computed' % (n_done) 1050 | 1051 | return numpy.array(probs) 1052 | 1053 | 1054 | # optimizers 1055 | # name(hyperp, tparams, grads, inputs (list), cost) = f_grad_shared, f_update 1056 | def adam(lr, tparams, grads, inp, cost, beta1=0.9, beta2=0.999, e=1e-8): 1057 | 1058 | gshared = [theano.shared(p.get_value() * 0., name='%s_grad' % k) 1059 | for k, p in tparams.iteritems()] 1060 | gsup = [(gs, g) for gs, g in zip(gshared, grads)] 1061 | 1062 | f_grad_shared = theano.function(inp, cost, updates=gsup, profile=profile) 1063 | 1064 | updates = [] 1065 | 1066 | t_prev = theano.shared(numpy.float32(0.)) 1067 | t = t_prev + 1. 1068 | lr_t = lr * tensor.sqrt(1. - beta2**t) / (1. - beta1**t) 1069 | 1070 | for p, g in zip(tparams.values(), gshared): 1071 | m = theano.shared(p.get_value() * 0., p.name + '_mean') 1072 | v = theano.shared(p.get_value() * 0., p.name + '_variance') 1073 | m_t = beta1 * m + (1. - beta1) * g 1074 | v_t = beta2 * v + (1. 
- beta2) * g**2 1075 | step = lr_t * m_t / (tensor.sqrt(v_t) + e) 1076 | p_t = p - step 1077 | updates.append((m, m_t)) 1078 | updates.append((v, v_t)) 1079 | updates.append((p, p_t)) 1080 | updates.append((t_prev, t)) 1081 | 1082 | f_update = theano.function([lr], [], updates=updates, 1083 | on_unused_input='ignore', profile=profile) 1084 | 1085 | return f_grad_shared, f_update 1086 | 1087 | 1088 | def adadelta(lr, tparams, grads, inp, cost): 1089 | zipped_grads = [theano.shared(p.get_value() * numpy.float32(0.), 1090 | name='%s_grad' % k) 1091 | for k, p in tparams.iteritems()] 1092 | running_up2 = [theano.shared(p.get_value() * numpy.float32(0.), 1093 | name='%s_rup2' % k) 1094 | for k, p in tparams.iteritems()] 1095 | running_grads2 = [theano.shared(p.get_value() * numpy.float32(0.), 1096 | name='%s_rgrad2' % k) 1097 | for k, p in tparams.iteritems()] 1098 | 1099 | zgup = [(zg, g) for zg, g in zip(zipped_grads, grads)] 1100 | rg2up = [(rg2, 0.95 * rg2 + 0.05 * (g ** 2)) 1101 | for rg2, g in zip(running_grads2, grads)] 1102 | 1103 | f_grad_shared = theano.function(inp, cost, updates=zgup+rg2up, 1104 | profile=profile) 1105 | 1106 | updir = [-tensor.sqrt(ru2 + 1e-6) / tensor.sqrt(rg2 + 1e-6) * zg 1107 | for zg, ru2, rg2 in zip(zipped_grads, running_up2, 1108 | running_grads2)] 1109 | ru2up = [(ru2, 0.95 * ru2 + 0.05 * (ud ** 2)) 1110 | for ru2, ud in zip(running_up2, updir)] 1111 | param_up = [(p, p + lr*ud) for p, ud in zip(itemlist(tparams), updir)] 1112 | 1113 | f_update = theano.function([lr], [], updates=ru2up+param_up, 1114 | on_unused_input='ignore', profile=profile) 1115 | 1116 | return f_grad_shared, f_update 1117 | 1118 | 1119 | def rmsprop(lr, tparams, grads, inp, cost): 1120 | zipped_grads = [theano.shared(p.get_value() * numpy.float32(0.), 1121 | name='%s_grad' % k) 1122 | for k, p in tparams.iteritems()] 1123 | running_grads = [theano.shared(p.get_value() * numpy.float32(0.), 1124 | name='%s_rgrad' % k) 1125 | for k, p in tparams.iteritems()] 1126 | running_grads2 = [theano.shared(p.get_value() * numpy.float32(0.), 1127 | name='%s_rgrad2' % k) 1128 | for k, p in tparams.iteritems()] 1129 | 1130 | zgup = [(zg, g) for zg, g in zip(zipped_grads, grads)] 1131 | # change pho to 0.99 1132 | ### rgup = [(rg, 0.95 * rg + 0.05 * g) for rg, g in zip(running_grads, grads)] 1133 | ### rg2up = [(rg2, 0.95 * rg2 + 0.05 * (g ** 2)) 1134 | ### for rg2, g in zip(running_grads2, grads)] 1135 | rgup = [(rg, 0.99 * rg + 0.01 * g) for rg, g in zip(running_grads, grads)] 1136 | rg2up = [(rg2, 0.99 * rg2 + 0.01 * (g ** 2)) 1137 | for rg2, g in zip(running_grads2, grads)] 1138 | 1139 | f_grad_shared = theano.function(inp, cost, updates=zgup+rgup+rg2up, 1140 | profile=profile) 1141 | 1142 | updir = [theano.shared(p.get_value() * numpy.float32(0.), 1143 | name='%s_updir' % k) 1144 | for k, p in tparams.iteritems()] 1145 | # changed 0.9 to 0; change 1e-4 to lr 1146 | ### updir_new = [(ud, 0.9 * ud - 1e-4 * zg / tensor.sqrt(rg2 - rg ** 2 + 1e-4)) 1147 | ### for ud, zg, rg, rg2 in zip(updir, zipped_grads, running_grads, 1148 | ### running_grads2)] 1149 | updir_new = [(ud, 0. 
* ud - lr * zg / tensor.sqrt(rg2 - rg ** 2 + 1e-4)) 1150 | for ud, zg, rg, rg2 in zip(updir, zipped_grads, running_grads, 1151 | running_grads2)] 1152 | param_up = [(p, p + udn[1]) 1153 | for p, udn in zip(itemlist(tparams), updir_new)] 1154 | f_update = theano.function([lr], [], updates=updir_new+param_up, 1155 | on_unused_input='ignore', profile=profile) 1156 | 1157 | return f_grad_shared, f_update 1158 | 1159 | 1160 | def sgd(lr, tparams, grads, x, mask, y, cost): 1161 | gshared = [theano.shared(p.get_value() * 0., 1162 | name='%s_grad' % k) 1163 | for k, p in tparams.iteritems()] 1164 | gsup = [(gs, g) for gs, g in zip(gshared, grads)] 1165 | 1166 | f_grad_shared = theano.function([x, mask, y], cost, updates=gsup, 1167 | profile=profile) 1168 | 1169 | pup = [(p, p - lr * g) for p, g in zip(itemlist(tparams), gshared)] 1170 | f_update = theano.function([lr], [], updates=pup, profile=profile) 1171 | 1172 | return f_grad_shared, f_update 1173 | 1174 | 1175 | def train(dim_word=100, # word vector dimensionality 1176 | dim=1000, # the number of LSTM units 1177 | max_epochs=5000, 1178 | finish_after=10000000, # finish after this many updates 1179 | dispFreq=100, 1180 | decay_c=0., # L2 regularization penalty 1181 | alpha_c=0., # alignment regularization 1182 | clip_c=-1., # gradient clipping threshold 1183 | lrate=0.01, # learning rate 1184 | n_words_src=100000, # source vocabulary size 1185 | n_words=100000, # target vocabulary size 1186 | maxlen=100, # maximum length of the description 1187 | optimizer='rmsprop', 1188 | batch_size=16, 1189 | valid_batch_size=16, 1190 | saveto='model.npz', 1191 | validFreq=1000, 1192 | validFreqLeast=10000, # at least greater this can be validated 1193 | validFreqFires=10000, # a split, before is relatively meaningless, after is more important 1194 | validFreqRefine=1000, # refined valid frequency 1195 | saveFreq=1000, # save the parameters after every saveFreq updates 1196 | sampleFreq=100, # generate some samples after every sampleFreq 1197 | datasets=[ 1198 | '/data/lisatmp3/chokyun/europarl/europarl-v7.fr-en.en.tok', 1199 | '/data/lisatmp3/chokyun/europarl/europarl-v7.fr-en.fr.tok'], 1200 | valid_datasets=['../data/dev/newstest2011.en.tok', 1201 | '../data/dev/newstest2011.fr.tok'], 1202 | dictionaries=[ 1203 | '/data/lisatmp3/chokyun/europarl/europarl-v7.fr-en.en.tok.pkl', 1204 | '/data/lisatmp3/chokyun/europarl/europarl-v7.fr-en.fr.tok.pkl'], 1205 | use_dropout=False, 1206 | reload_=False, 1207 | overwrite=False, 1208 | use_bleueval=True, 1209 | save_devscore_to='score.log', 1210 | save_devtrans_to='trans.txt', 1211 | beam_size=10, 1212 | normalize=False, 1213 | output_nbest=1, 1214 | shuffle_train=1, 1215 | is_eval_nist=True, 1216 | **known_args): 1217 | 1218 | # Model options 1219 | model_options = locals().copy() 1220 | 1221 | # load dictionaries and invert them 1222 | worddicts = [None] * len(dictionaries) 1223 | worddicts_r = [None] * len(dictionaries) 1224 | for ii, dd in enumerate(dictionaries): 1225 | with open(dd, 'rb') as f: 1226 | worddicts[ii] = pkl.load(f) 1227 | worddicts_r[ii] = dict() 1228 | for kk, vv in worddicts[ii].iteritems(): 1229 | worddicts_r[ii][vv] = kk 1230 | 1231 | # reload options 1232 | if reload_ and os.path.exists(saveto): 1233 | print 'Reloading model options' 1234 | with open('%s.pkl' % saveto, 'rb') as f: 1235 | model_options.update(pkl.load(f)) 1236 | 1237 | print 'Loading data' 1238 | train = TextIterator(datasets[0], datasets[1], 1239 | dictionaries[0], dictionaries[1], 1240 | n_words_source=n_words_src, 
n_words_target=n_words, 1241 | batch_size=batch_size, 1242 | maxlen=maxlen,shuffle_prob=shuffle_train) 1243 | try: 1244 | valid = TextIterator(valid_datasets[0], valid_datasets[1], 1245 | dictionaries[0], dictionaries[1], 1246 | n_words_source=n_words_src, n_words_target=n_words, 1247 | batch_size=valid_batch_size, 1248 | maxlen=maxlen,shuffle_prob=1) 1249 | except IOError: 1250 | if use_bleueval: 1251 | pass 1252 | else: 1253 | print 'You do not use BLEU eval, please check your validate dataset!' 1254 | raise 1255 | 1256 | print 'Building model' 1257 | params = init_params(model_options) 1258 | # reload parameters 1259 | if reload_ and os.path.exists(saveto): 1260 | print 'Reloading model parameters' 1261 | params = load_params(saveto, params) 1262 | 1263 | tparams = init_tparams(params) 1264 | 1265 | trng, use_noise, \ 1266 | x, x_mask, y, y_mask, \ 1267 | opt_ret, \ 1268 | cost = \ 1269 | build_model(tparams, model_options) 1270 | inps = [x, x_mask, y, y_mask] 1271 | 1272 | print 'Building sampler' 1273 | f_init, f_next = build_sampler(tparams, model_options, trng, use_noise) 1274 | 1275 | # before any regularizer 1276 | print 'Building f_log_probs...', 1277 | f_log_probs = theano.function(inps, cost, profile=profile) 1278 | print 'Done' 1279 | 1280 | cost = cost.mean() 1281 | 1282 | # apply L2 regularization on weights 1283 | if decay_c > 0.: 1284 | decay_c = theano.shared(numpy.float32(decay_c), name='decay_c') 1285 | weight_decay = 0. 1286 | for kk, vv in tparams.iteritems(): 1287 | weight_decay += (vv ** 2).sum() 1288 | weight_decay *= decay_c 1289 | cost += weight_decay 1290 | 1291 | # regularize the alpha weights 1292 | if alpha_c > 0. and not model_options['decoder'].endswith('simple'): 1293 | alpha_c = theano.shared(numpy.float32(alpha_c), name='alpha_c') 1294 | alpha_reg = alpha_c * ( 1295 | (tensor.cast(y_mask.sum(0)//x_mask.sum(0), 'float32')[:, None] - 1296 | opt_ret['dec_alphas'].sum(0))**2).sum(1).mean() 1297 | cost += alpha_reg 1298 | 1299 | # after all regularizers - compile the computational graph for cost 1300 | print 'Building f_cost...', 1301 | f_cost = theano.function(inps, cost, profile=profile) 1302 | print 'Done' 1303 | 1304 | print 'Computing gradient...', 1305 | grads = tensor.grad(cost, wrt=itemlist(tparams)) 1306 | print 'Done' 1307 | 1308 | # apply gradient clipping here 1309 | if clip_c > 0.: 1310 | g2 = 0. 1311 | for g in grads: 1312 | g2 += (g**2).sum() 1313 | new_grads = [] 1314 | for g in grads: 1315 | new_grads.append(tensor.switch(g2 > (clip_c**2), 1316 | g / tensor.sqrt(g2) * clip_c, 1317 | g)) 1318 | grads = new_grads 1319 | 1320 | # compile the optimizer, the actual computational graph is compiled here 1321 | lr = tensor.scalar(name='lr') 1322 | print 'Building optimizers...', 1323 | f_grad_shared, f_update = eval(optimizer)(lr, tparams, grads, inps, cost) 1324 | print 'Done' 1325 | 1326 | print 'Optimization' 1327 | 1328 | best_p = None 1329 | uidx = 0 1330 | estop = False 1331 | history_errs = [] 1332 | # reload history 1333 | if reload_ and os.path.exists(saveto): 1334 | rmodel = numpy.load(saveto) 1335 | history_errs = list(rmodel['history_errs']) 1336 | if 'uidx' in rmodel: 1337 | uidx = rmodel['uidx'] 1338 | 1339 | if validFreq == -1: 1340 | validFreq = len(train[0])/batch_size 1341 | if saveFreq == -1: 1342 | saveFreq = len(train[0])/batch_size 1343 | if sampleFreq == -1: 1344 | sampleFreq = len(train[0])/batch_size 1345 | 1346 | for eidx in xrange(max_epochs): 1347 | if eidx >= 1: 1348 | lrate = lrate / 2. 
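# note: the halving above fires at the start of every epoch after the first;
# with adadelta (the default optimizer in work/german.py) lrate only enters
# the update through param_up, so this simply halves the step size per epoch.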
1349 | print 'learning rate decay from %s to %s' % (lrate * 2., lrate) 1350 | n_samples = 0 1351 | 1352 | for x, y in train: 1353 | n_samples += len(x) 1354 | uidx += 1 1355 | use_noise.set_value(1.) 1356 | 1357 | x, x_mask, y, y_mask = prepare_data(x, y, maxlen=maxlen, 1358 | n_words_src=n_words_src, 1359 | n_words=n_words) 1360 | 1361 | if x is None: 1362 | print 'Minibatch with zero sample under length ', maxlen 1363 | uidx -= 1 1364 | continue 1365 | 1366 | ud_start = time.time() 1367 | 1368 | # compute cost, grads and copy grads to shared variables 1369 | cost = f_grad_shared(x, x_mask, y, y_mask) 1370 | 1371 | # do the update on parameters 1372 | f_update(lrate) 1373 | 1374 | ud = time.time() - ud_start 1375 | 1376 | # check for bad numbers, usually we remove non-finite elements 1377 | # and continue training - but not done here 1378 | if numpy.isnan(cost) or numpy.isinf(cost): 1379 | print 'NaN detected' 1380 | return 1., 1., 1. 1381 | 1382 | # verbose 1383 | if numpy.mod(uidx, dispFreq) == 0: 1384 | print 'Epoch ', eidx, 'Update ', uidx, 'Cost ', cost, 'UD ', ud, 's' 1385 | 1386 | # save the best model so far, in addition, save the latest model 1387 | # into a separate file with the iteration number for external eval 1388 | if numpy.mod(uidx, saveFreq) == 0: 1389 | print 'Saving the best model...', 1390 | if best_p is not None: 1391 | params = best_p 1392 | else: 1393 | params = unzip(tparams) 1394 | numpy.savez("best_"+saveto, history_errs=history_errs, uidx=uidx, **params) 1395 | pkl.dump(model_options, open('%s.pkl' % ("best_"+saveto), 'wb')) 1396 | print 'Done' 1397 | 1398 | # save with uidx 1399 | if not overwrite: 1400 | print 'Saving the model at iteration {}...'.format(uidx), 1401 | saveto_uidx = '{}.iter{}.npz'.format( 1402 | os.path.splitext(saveto)[0], uidx) 1403 | numpy.savez(saveto_uidx, history_errs=history_errs, 1404 | uidx=uidx, **unzip(tparams)) 1405 | print 'Done' 1406 | 1407 | 1408 | # generate some samples with the model and display them 1409 | if numpy.mod(uidx, sampleFreq) == 0: 1410 | # FIXME: random selection? 1411 | for jj in xrange(numpy.minimum(5, x.shape[1])): 1412 | stochastic = True 1413 | sample, score = gen_sample(tparams, f_init, f_next, 1414 | x[:, jj][:, None], 1415 | model_options, trng=trng, k=1, 1416 | maxlen=30, 1417 | stochastic=stochastic, 1418 | argmax=False) 1419 | print 'Source ', jj, ': ', 1420 | for vv in x[:, jj]: 1421 | if vv == 0: 1422 | break 1423 | if vv in worddicts_r[0]: 1424 | print worddicts_r[0][vv], 1425 | else: 1426 | print 'UNK', 1427 | print 1428 | print 'Truth ', jj, ' : ', 1429 | for vv in y[:, jj]: 1430 | if vv == 0: 1431 | break 1432 | if vv in worddicts_r[1]: 1433 | print worddicts_r[1][vv], 1434 | else: 1435 | print 'UNK', 1436 | print 1437 | print 'Sample ', jj, ': ', 1438 | if stochastic: 1439 | ss = sample 1440 | else: 1441 | score = score / numpy.array([len(s) for s in sample]) 1442 | ss = sample[score.argmin()] 1443 | for vv in ss: 1444 | if vv == 0: 1445 | break 1446 | if vv in worddicts_r[1]: 1447 | print worddicts_r[1][vv], 1448 | else: 1449 | print 'UNK', 1450 | print 1451 | 1452 | # validate model on validation set and early stop if necessary 1453 | if (uidx >= validFreqLeast) and \ 1454 | ((uidx < validFreqFires and numpy.mod(uidx, validFreq) == 0) \ 1455 | or (uidx >= validFreqFires and numpy.mod(uidx, validFreqRefine) == 0)): 1456 | use_noise.set_value(0.) 
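# the validation signal below is oriented so that smaller is always better:
# either the mean negative log-likelihood from pred_probs, or the negated
# dev-set BLEU returned by pred_bleus (-1 * BLEU), which is why the best
# checkpoint is the one that minimises history_errs.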
1457 | if not use_bleueval: 1458 | valid_errs = pred_probs(f_log_probs, prepare_data, 1459 | model_options, valid) 1460 | valid_err = valid_errs.mean() 1461 | else: 1462 | params = unzip(tparams) 1463 | numpy.savez(saveto, history_errs=history_errs, uidx=uidx, **params) 1464 | pkl.dump(model_options, open('%s.pkl' % saveto, 'wb')) 1465 | valid_err = pred_bleus(valid_datasets[0], valid_datasets[1], 1466 | dictionaries[0], dictionaries[1], saveto, 1467 | saveto=save_devtrans_to, k_=beam_size, n_=normalize, 1468 | b_=output_nbest, is_nist=is_eval_nist) 1469 | print "development set bleu:\t", valid_err * -1 1470 | run('echo "Step:%s Development set bleu:%s" | cat >> %s' \ 1471 | % (uidx, valid_err * -1, save_devscore_to)) 1472 | 1473 | history_errs.append(valid_err) 1474 | if valid_err <= numpy.array(history_errs).min(): 1475 | best_p = unzip(tparams) 1476 | numpy.savez("best_"+saveto, history_errs=history_errs, uidx=uidx, **params) 1477 | pkl.dump(model_options, open('%s.pkl' % ("best_"+saveto), 'wb')) 1478 | os.system('cp %s best_%s' % (save_devtrans_to, save_devtrans_to)) 1479 | os.system('cp %s.eval.nmt best_%s.eval.nmt' % (save_devtrans_to, save_devtrans_to)) 1480 | 1481 | if numpy.isnan(valid_err): 1482 | ipdb.set_trace() 1483 | 1484 | print 'Valid ', valid_err 1485 | 1486 | # finish after this many updates 1487 | if uidx >= finish_after: 1488 | print 'Finishing after %d iterations!' % uidx 1489 | estop = True 1490 | break 1491 | 1492 | print 'Seen %d samples' % n_samples 1493 | 1494 | if estop: 1495 | break 1496 | 1497 | if best_p is not None: 1498 | zipp(best_p, tparams) 1499 | 1500 | use_noise.set_value(0.) 1501 | valid_err = pred_probs(f_log_probs, prepare_data, 1502 | model_options, valid).mean() 1503 | 1504 | print 'Valid ', valid_err 1505 | 1506 | params = copy.copy(best_p) 1507 | numpy.savez(saveto, zipped_params=best_p, 1508 | history_errs=history_errs, 1509 | uidx=uidx, 1510 | **params) 1511 | 1512 | return valid_err 1513 | 1514 | 1515 | if __name__ == '__main__': 1516 | pass 1517 | -------------------------------------------------------------------------------- /code/sample.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Translates a source file using a translation model. 
3 | ''' 4 | import argparse 5 | 6 | import numpy 7 | import cPickle as pkl 8 | 9 | from nmt import (build_sampler, gen_sample, load_params, 10 | init_params, init_tparams) 11 | 12 | def translate_model(model, options, k, normalize, n_best): 13 | 14 | from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams 15 | from theano import shared 16 | trng = RandomStreams(1234) 17 | use_noise = shared(numpy.float32(0.)) 18 | 19 | # allocate model parameters 20 | params = init_params(options) 21 | 22 | # load model parameters and set theano shared variables 23 | params = load_params(model, params) 24 | tparams = init_tparams(params) 25 | 26 | # word index 27 | f_init, f_next = build_sampler(tparams, options, trng, use_noise) 28 | 29 | def _translate(seq): 30 | # sample given an input sequence and obtain scores 31 | sample, score = gen_sample(tparams, f_init, f_next, 32 | numpy.array(seq).reshape([len(seq), 1]), 33 | options, trng=trng, k=k, maxlen=200, 34 | stochastic=False, argmax=False) 35 | 36 | # normalize scores according to sequence lengths 37 | if normalize: 38 | lengths = numpy.array([len(s) for s in sample]) 39 | score = score / lengths 40 | if n_best > 1: 41 | sidx = numpy.argsort(score)[:n_best] 42 | else: 43 | sidx = numpy.argmin(score) 44 | return numpy.array(sample)[sidx], numpy.array(score)[sidx] 45 | 46 | return _translate 47 | 48 | 49 | def main(model, dictionary, dictionary_target, source_file, saveto, k=5, 50 | normalize=False, n_process=5, chr_level=False, n_best=1): 51 | 52 | # load model model_options 53 | with open('%s.pkl' % model, 'rb') as f: 54 | options = pkl.load(f) 55 | 56 | # load source dictionary and invert 57 | with open(dictionary, 'rb') as f: 58 | word_dict = pkl.load(f) 59 | word_idict = dict() 60 | for kk, vv in word_dict.iteritems(): 61 | word_idict[vv] = kk 62 | word_idict[0] = '' 63 | word_idict[1] = 'UNK' 64 | 65 | # load target dictionary and invert 66 | with open(dictionary_target, 'rb') as f: 67 | word_dict_trg = pkl.load(f) 68 | word_idict_trg = dict() 69 | for kk, vv in word_dict_trg.iteritems(): 70 | word_idict_trg[vv] = kk 71 | word_idict_trg[0] = '' 72 | word_idict_trg[1] = 'UNK' 73 | 74 | # create input and output queues for processes 75 | trser = translate_model(model, options, k, normalize, n_best) 76 | 77 | # utility function 78 | def _seqs2words(caps): 79 | capsw = [] 80 | for cc in caps: 81 | ww = [] 82 | for w in cc: 83 | if w == 0: 84 | break 85 | ww.append(word_idict_trg[w]) 86 | trs = ' '.join(ww) 87 | if trs.strip() == '': 88 | trs = 'UNK' 89 | capsw.append(trs) 90 | return capsw 91 | 92 | xs = [] 93 | srcs = [] 94 | with open(source_file, 'r') as f: 95 | for idx, line in enumerate(f): 96 | if chr_level: 97 | words = list(line.decode('utf-8').strip()) 98 | else: 99 | words = line.strip().split() 100 | x = map(lambda w: word_dict[w] if w in word_dict else 1, words) 101 | x = map(lambda ii: ii if ii < options['n_words'] else 1, x) 102 | x += [0] 103 | xs.append((idx, x)) 104 | srcs.append(line.strip()) 105 | print 'Data loading over' 106 | 107 | print 'Translating ', source_file, '...' 
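# decoding below is purely sequential: each source sentence is translated on
# its own with beam size k (maxlen=200); the n_process argument is accepted
# but not used here. When n_best > 1, every surviving hypothesis is later
# written out as "<sentence-id>|||<translation>|||<score>".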
108 | trans = [] 109 | scores = [] 110 | for req in xs: 111 | idx, x = req[0], req[1] 112 | tran, score = trser(x) 113 | trans.append(tran) 114 | scores.append(score) 115 | print 'the %d-th sentence' % idx 116 | print 'source side:\t%s' % srcs[idx] 117 | print 'target translation:\t%s' % ''.join(_seqs2words([trans[-1]])) 118 | 119 | if n_best == 1: 120 | trans = _seqs2words(trans) 121 | else: 122 | n_best_trans = [] 123 | for idx, (n_best_tr, score_) in enumerate(zip(trans, scores)): 124 | sentences = _seqs2words(n_best_tr) 125 | for ids, trans_ in enumerate(sentences): 126 | n_best_trans.append( 127 | '|||'.join( 128 | ['{}'.format(idx), trans_, 129 | '{}'.format(score_[ids])])) 130 | trans = n_best_trans 131 | 132 | with open(saveto, 'w') as f: 133 | print >>f, '\n'.join(trans) 134 | print 'Done' 135 | 136 | 137 | if __name__ == "__main__": 138 | parser = argparse.ArgumentParser() 139 | parser.add_argument('-k', type=int, default=5, help="Beam size") 140 | parser.add_argument('-n', action="store_true", default=False, 141 | help="Normalize wrt sequence length") 142 | parser.add_argument('-c', action="store_true", default=False, 143 | help="Character level") 144 | parser.add_argument('-b', type=int, default=1, help="Output n-best list") 145 | parser.add_argument('model', type=str) 146 | parser.add_argument('dictionary', type=str) 147 | parser.add_argument('dictionary_target', type=str) 148 | parser.add_argument('source', type=str) 149 | parser.add_argument('saveto', type=str) 150 | 151 | args = parser.parse_args() 152 | 153 | main(args.model, args.dictionary, args.dictionary_target, args.source, 154 | args.saveto, k=args.k, normalize=args.n, 155 | chr_level=args.c, n_best=args.b) 156 | -------------------------------------------------------------------------------- /code/train.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import os 3 | import sys 4 | 5 | from nmt import train 6 | 7 | def main(job_id, params): 8 | print params 9 | validerr = train(**params) 10 | return validerr 11 | 12 | if __name__ == '__main__': 13 | if len(sys.argv) < 2: 14 | print '%s config.py' % sys.argv[0] 15 | sys.exit(1) 16 | 17 | options = eval(open(sys.argv[-1]).read()) 18 | numpy.random.seed(options['seed']) 19 | 20 | main(0, options) 21 | -------------------------------------------------------------------------------- /work/german.py: -------------------------------------------------------------------------------- 1 | dict( 2 | # network structure 3 | dim_word=620, # word vector dimensionality 4 | dim=1000, # the number of LSTM units 5 | n_words_src=40000, # source vocabulary size 6 | n_words=40000, # target vocabulary size 7 | maxlen=80, # maximum length of the description 8 | 9 | # process control 10 | max_epochs=10, 11 | finish_after=100000000, # finish after this many updates 12 | dispFreq=1, 13 | saveto='search_model.npz', 14 | validFreq=5000, 15 | validFreqLeast=100000, 16 | validFreqFires=150000, 17 | validFreqRefine=3000, 18 | saveFreq=1000, # save the parameters after every saveFreq updates 19 | sampleFreq=1000, # generate some samples after every sampleFreq 20 | reload_=True, 21 | overwrite=True, 22 | is_eval_nist=False, 23 | 24 | # optimization 25 | decay_c=0., # L2 regularization penalty 26 | alpha_c=0., # alignment regularization 27 | clip_c=5., # gradient clipping threshold 28 | lrate=1.0, # learning rate 29 | optimizer='adadelta', 30 | batch_size=80, 31 | valid_batch_size=80, 32 | use_dropout=False, 33 | shuffle_train=0.999, 34 | 
seed=1234, 35 | 36 | # development evaluation 37 | use_bleueval=True, 38 | save_devscore_to='search_bleu.log', 39 | save_devtrans_to='search_trans.txt', 40 | beam_size=10, 41 | proc_num=1, 42 | normalize=False, 43 | output_nbest=1, 44 | 45 | # datasets 46 | use_bpe=True, 47 | datasets=[ 48 | '/Path-to-training-data/train.en.bpe', 49 | '/Path-to-training-data/train.de.bpe'], 50 | valid_datasets=['/Path-to-dev-data/dev.en.plain.bpe', 51 | '/Path-to-dev-data/dev.de'], 52 | dictionaries=[ 53 | '/Vocabulary/vocab.en.pkl', 54 | '/Vocabulary/vocab.de.pkl'], 55 | ) 56 | -------------------------------------------------------------------------------- /work/run.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | THEANO_FLAGS='floatX=float32,device=gpu3,nvcc.fastmath=True' python train.py --proto=english-german-caencoder-nmt-bzhang --state german.py 4 | -------------------------------------------------------------------------------- /work/train.py: -------------------------------------------------------------------------------- 1 | ../code/train.py --------------------------------------------------------------------------------
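For reference, a minimal decoding sketch that mirrors what code/sample.py does when run from the command line. The model name follows the "best_" + saveto convention used during training in nmt.py; the vocabulary and test-file paths are placeholders, not files shipped with this repository.

    # hypothetical driver, assuming the code/ directory is on the Python path
    from sample import main

    main('best_search_model.npz',      # best checkpoint written during training
         '/Vocabulary/vocab.en.pkl',   # source dictionary (placeholder path)
         '/Vocabulary/vocab.de.pkl',   # target dictionary (placeholder path)
         'test.en.bpe',                # BPE-segmented source side (placeholder)
         'test.trans',                 # translations are written here
         k=10, normalize=True, n_best=1)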