├── LICENSE
├── README.md
├── code
│   ├── data.py
│   ├── eval
│   │   ├── __init__.py
│   │   ├── evaluate_moses.py
│   │   ├── evaluate_nist.py
│   │   ├── mteval-v11b.pl
│   │   └── multi-bleu.perl
│   ├── nmt.py
│   ├── sample.py
│   └── train.py
├── data
│   └── newstest2014.en.trans.txt
└── work
    ├── german.py
    ├── run.sh
    └── train.py
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2017, DeepLearnXMU
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CAEncoder-NMT
2 | Source code for A Context-Aware Recurrent Encoder for Neural Machine Translation. Our model is much faster than the standard encoder-attention-decoder model and obtains a BLEU score of 22.57 on the English-German translation task, compared with 20.87 for dl4mt.
3 |
4 | If you use this code, please cite our paper:
5 | ```
6 | @article{Zhang:2017:CRE:3180104.3180106,
7 | author = {Zhang, Biao and Xiong, Deyi and Su, Jinsong and Duan, Hong},
8 | title = {A Context-Aware Recurrent Encoder for Neural Machine Translation},
9 | journal = {IEEE/ACM Trans. Audio, Speech and Lang. Proc.},
10 | issue_date = {December 2017},
11 | volume = {25},
12 | number = {12},
13 | month = dec,
14 | year = {2017},
15 | issn = {2329-9290},
16 | pages = {2424--2432},
17 | numpages = {9},
18 | url = {https://doi.org/10.1109/TASLP.2017.2751420},
19 | doi = {10.1109/TASLP.2017.2751420},
20 | acmid = {3180106},
21 | publisher = {IEEE Press},
22 | address = {Piscataway, NJ, USA},
23 | }
24 | ```
25 |
26 | ## How to Run?
27 |
28 | A demo case is provided in the `work` directory.
29 |
30 | ### Training
31 | You need to process your training data and set up a configuration file, as `german.py` does. The `train.py` script is used for training.
32 |
33 | ### Testing
34 | All you need is the `sample.py` script, plus the directories containing the vocabularies and model files.
35 |
36 |
37 | For any comments or questions, please email Biao Zhang.
38 |
--------------------------------------------------------------------------------
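
As a companion to the "How to Run" notes in the README above, here is a minimal sketch (not part of the repository) of scoring a decoded output with the bundled `code/eval` package. All file names are hypothetical placeholders, and the sketch uses Python 2 syntax, matching the rest of the code base.

```python
# Hedged sketch: run from inside the `code` directory; every path below is a placeholder.
from eval.evaluate_moses import eval_trans

# evaluate_moses.eval_trans ignores its first argument (the script itself calls it
# "meaningless"), scores a plain-text hypothesis against a plain-text reference with
# multi-bleu.perl, and returns the BLEU score as a float.
bleu = eval_trans('unused.sgm', 'newstest2014.de.reference', 'newstest2014.de.decoded')
print 'BLEU = %.2f' % bleu
```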
/code/data.py:
--------------------------------------------------------------------------------
1 | import cPickle as pkl
2 | import gzip
3 | import numpy
4 |
5 | '''
6 | ID rule: UNK=>1
7 | <eos>=>0
8 | '''
9 |
10 | numpy.random.seed(1234)
11 |
12 | def fopen(filename, mode='r'):
13 | if filename.endswith('.gz'):
14 | return gzip.open(filename, mode)
15 | return open(filename, mode)
16 |
17 |
18 | class TextIterator:
19 | """Simple Bitext iterator."""
20 | def __init__(self, source, target,
21 | source_dict, target_dict,
22 | batch_size=128,
23 | maxlen=100,
24 | n_words_source=-1,
25 | n_words_target=-1,
26 | shuffle_prob=1):
27 | self.source = fopen(source, 'r')
28 | self.target = fopen(target, 'r')
29 | with open(source_dict, 'rb') as f:
30 | self.source_dict = pkl.load(f)
31 | with open(target_dict, 'rb') as f:
32 | self.target_dict = pkl.load(f)
33 |
34 | self.batch_size = batch_size
35 | self.maxlen = maxlen
36 | self.shuffle_prob=shuffle_prob
37 |
38 | self.n_words_source = n_words_source
39 | self.n_words_target = n_words_target
40 |
41 | self.end_of_data = False
42 |
43 | def __iter__(self):
44 | return self
45 |
46 | def reset(self):
47 | self.source.seek(0)
48 | self.target.seek(0)
49 |
50 | def next(self):
51 | if self.end_of_data:
52 | self.end_of_data = False
53 | self.reset()
54 | raise StopIteration
55 |
56 | source = []
57 | target = []
58 |
59 | try:
60 |
61 | # actual work here
62 | while True:
63 |
64 | # read from source file and map to word index
65 | ss = self.source.readline()
66 | if ss == "":
67 | raise IOError
68 | ss = ss.strip().split()
69 | ss = [self.source_dict[w] if w in self.source_dict else 1
70 | for w in ss]
71 | if self.n_words_source > 0:
72 | ss = [w if w < self.n_words_source else 1 for w in ss]
73 |
74 | # read from target file and map to word index
75 | tt = self.target.readline()
76 | if tt == "":
77 | raise IOError
78 | tt = tt.strip().split()
79 | tt = [self.target_dict[w] if w in self.target_dict else 1
80 | for w in tt]
81 | if self.n_words_target > 0:
82 | tt = [w if w < self.n_words_target else 1 for w in tt]
83 |
84 | if len(ss) > self.maxlen and len(tt) > self.maxlen:
85 | continue
86 | if numpy.random.random() > self.shuffle_prob:
87 | continue
88 |
89 | source.append(ss)
90 | target.append(tt)
91 |
92 | if len(source) >= self.batch_size or \
93 | len(target) >= self.batch_size:
94 | break
95 | except IOError:
96 | self.end_of_data = True
97 |
98 | if len(source) <= 0 or len(target) <= 0:
99 | self.end_of_data = False
100 | self.reset()
101 | raise StopIteration
102 |
103 | return source, target
104 |
--------------------------------------------------------------------------------
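
A minimal usage sketch for the `TextIterator` above (not part of the repository; the corpus and vocabulary paths are hypothetical). The pickled dictionaries are assumed to map words to integer ids, with 1 reserved for UNK and 0 for the end-of-sentence symbol, as the module docstring notes.

```python
# Hedged sketch in Python 2, matching the rest of the code base; all file names are placeholders.
from data import TextIterator

train_iter = TextIterator('train.en', 'train.de',          # parallel text, one sentence per line
                          'vocab.en.pkl', 'vocab.de.pkl',  # pickled word -> id dictionaries
                          batch_size=80, maxlen=50,
                          n_words_source=30000, n_words_target=30000)

for src_batch, trg_batch in train_iter:
    # each batch is a pair of lists; every element is a list of token ids,
    # with out-of-vocabulary words already mapped to 1
    print len(src_batch), len(trg_batch)
    break
```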
/code/eval/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XMUDeepLIT/CAEncoder-NMT/8bb0e3d132c3fa12fb5792ab931268d68a3539ea/code/eval/__init__.py
--------------------------------------------------------------------------------
/code/eval/evaluate_moses.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python
2 |
3 | '''
4 | Evaluate the translation result from Neural Machine Translation
5 | '''
6 |
7 | import sys
8 | import os
9 | import re
10 | import time
11 | import string
12 |
13 | run = os.system
14 | path = os.path.dirname(os.path.realpath(__file__))
15 |
16 | def eval_trans(src_sgm, ref_sgm, trs_plain):
17 |
18 | cmd = ("%s/multi-bleu.perl %s < %s > %s.eval.nmt" \
19 | %(path, ref_sgm, trs_plain, trs_plain))
20 | print cmd
21 | run(cmd)
22 | eval_nmt = ''.join(file('%s.eval.nmt' % trs_plain, 'rU').readlines())
23 |
24 | bleu = float(eval_nmt.strip().split(',')[0].split(' ')[-1])
25 |
26 | return bleu
27 |
28 | if __name__ == "__main__":
29 | if len(sys.argv) != 4:
30 | print '%s src_sgm(meaningless), ref_sgm(plain ref), trs_plain(plain trans)' % sys.argv[0]
31 | sys.exit(0)
32 |
33 | src_sgm = sys.argv[1]
34 | ref_sgm = sys.argv[2]
35 | trs_plain = sys.argv[3]
36 |
37 | eval_trans(src_sgm, ref_sgm, trs_plain)
38 |
--------------------------------------------------------------------------------
/code/eval/evaluate_nist.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python
2 |
3 | '''
4 | Evaluate the translation result from Neural Machine Translation
5 | '''
6 |
7 | import sys
8 | import os
9 | import re
10 | import time
11 | import string
12 |
13 | run = os.system
14 | path = os.path.dirname(os.path.realpath(__file__))
15 |
16 | print path
17 |
18 | def split_seg(line):
19 | p1 = line.find(">")+1
20 | p2 = line.rfind("<")
21 | return [ line[:p1], line[p1:p2], line[p2:] ]
22 |
23 | def plain2sgm(trg_plain, src_sgm, trg_sgm):
24 | "Converse plain format to sgm format"
25 | fin_trg_plain = file(trg_plain , "r")
26 | fin_src_sgm = file(src_sgm, "r")
27 | fout = file(trg_sgm, "w")
28 |
29 | #head
30 | doc_head = fin_src_sgm.readline().rstrip().replace("srcset", "tstset")
31 | if doc_head.find("trglang") == -1 :
32 | doc_head = doc_head.replace(">", " trglang=\"en\">")
33 | print >> fout, doc_head
34 |
35 | for line in fin_src_sgm:
36 | line = line.rstrip()
37 | #process doc tag
38 | if "> fout, '''''' %id
43 | elif line.startswith("> fout, head, fin_trg_plain.readline().rstrip(), tail
46 | elif line.strip() == "":
47 | print >> fout, ""
48 | else:
49 | print >> fout, line
50 |
51 | fout.close()
52 | fin_src_sgm.close()
53 |
54 | def get_bleu(fe):
55 | "Get the bleu score from result file printed by mteval"
56 |
57 | c = file(fe, "rU").read()
58 | reg = re.compile(r"BLEU score =\s+(.*?)\s+")
59 | r = reg.findall(c)
60 | assert len(r) == 1
61 | return float(r[0])
62 |
63 | def eval_trans(src_sgm, ref_sgm, trs_plain):
64 |
65 | print path
66 |
67 | if src_sgm[-4:] != '.sgm':
68 | src_sgm = src_sgm + '.sgm'
69 |
70 | plain2sgm(trs_plain, src_sgm, "result.sgm")
71 | cmd = "%s/mteval-v11b.pl -s %s -r %s -t result.sgm > %s.eval.nmt" \
72 | %(path, src_sgm, ref_sgm, trs_plain)
73 | print cmd
74 | run(cmd)
75 | eval_nmt = ''.join(file('%s.eval.nmt' % trs_plain, 'rU').readlines())
76 | print >> file('%s.eval.nmt' % trs_plain, 'w'), eval_nmt.replace('hiero', 'Neural Machine Translation')
77 | run('mv result.sgm %s.sgm' % trs_plain)
78 |
79 | return get_bleu('%s.eval.nmt' % trs_plain)
80 |
81 | if __name__ == "__main__":
82 | if len(sys.argv) != 4:
83 | print '%s src_sgm, ref_sgm, trs_plain' % sys.argv[0]
84 | sys.exit(0)
85 |
86 | src_sgm = sys.argv[1]
87 | ref_sgm = sys.argv[2]
88 | trs_plain = sys.argv[3]
89 |
90 | eval_trans(src_sgm, ref_sgm, trs_plain)
91 |
--------------------------------------------------------------------------------
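
The NIST-style wrapper above can be driven the same way; a minimal sketch with hypothetical file names (unlike the Moses variant, it needs the SGML source and reference sets expected by `mteval-v11b.pl`):

```python
# Hedged sketch: all .sgm and output paths below are placeholders.
from eval.evaluate_nist import eval_trans

# plain2sgm wraps the plain-text hypothesis into an SGML test set, mteval-v11b.pl is
# invoked on it, and the BLEU score is parsed back out of the report.
bleu = eval_trans('newstest.src.sgm', 'newstest.ref.sgm', 'newstest.decoded')
print 'BLEU = %.2f' % bleu
```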
/code/eval/mteval-v11b.pl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/perl -w
2 |
3 | use strict;
4 |
5 | #################################
6 | # History:
7 | #
8 | # version 11b -- text normalization modified:
9 | # * take out the join digit line because it joins digits
10 | # when it shouldn't have
11 | # $norm_text =~ s/(\d)\s+(?=\d)/$1/g; #join digits
12 | #
13 | # version 11a -- corrected output of individual n-gram precision values
14 | #
15 | # version 11 -- bug fixes:
16 | # * make filehandle operate in binary mode to prevent Perl from operating
17 | # (by default in Red Hat 9) in UTF-8
18 | # * fix failure on joining digits
19 | # version 10 -- updated output to include more details of n-gram scoring.
20 | # Defaults to generate both NIST and BLEU scores. Use -b for BLEU
21 | # only, use -n for NIST only
22 | #
23 | # version 09d -- bug fix (for BLEU scoring, ngrams were fixed at 4
24 | # being the max, regardless what was entered on the command line.)
25 | #
26 | # version 09c -- bug fix (During the calculation of ngram information,
27 | # each ngram was being counted only once for each segment. This has
28 | # been fixed so that each ngram is counted correctly in each segment.)
29 | #
30 | # version 09b -- text normalization modified:
31 | # * option flag added to preserve upper case
32 | # * non-ASCII characters left in place.
33 | #
34 | # version 09a -- text normalization modified:
35 | # * &quot; and &amp; converted to " and &, respectively
36 | # * non-ASCII characters kept together (bug fix)
37 | #
38 | # version 09 -- modified to accommodate sgml tag and attribute
39 | # names revised to conform to default SGML conventions.
40 | #
41 | # version 08 -- modifies the NIST metric in accordance with the
42 | # findings on the 2001 Chinese-English dry run corpus. Also
43 | # incorporates the BLEU metric as an option and supports the
44 | # output of ngram detail.
45 | #
46 | # version 07 -- in response to the MT meeting on 28 Jan 2002 at ISI
47 | # Keep strings of non-ASCII characters together as one word
48 | # (rather than splitting them into one-character words).
49 | # Change length penalty so that translations that are longer than
50 | # the average reference translation are not penalized.
51 | #
52 | # version 06
53 | # Prevent divide-by-zero when a segment has no evaluation N-grams.
54 | # Correct segment index for level 3 debug output.
55 | #
56 | # version 05
57 | # improve diagnostic error messages
58 | #
59 | # version 04
60 | # tag segments
61 | #
62 | # version 03
63 | # add detailed output option (intermediate document and segment scores)
64 | #
65 | # version 02
66 | # accommodation of modified sgml tags and attributes
67 | #
68 | # version 01
69 | # same as bleu version 15, but modified to provide formal score output.
70 | #
71 | # original IBM version
72 | # Author: Kishore Papineni
73 | # Date: 06/10/2001
74 | #################################
75 |
76 | ######
77 | # Intro
78 | my ($date, $time) = date_time_stamp();
79 | print "MT evaluation scorer began on $date at $time\n";
80 | print "command line: ", $0, " ", join(" ", @ARGV), "\n";
81 | my $usage = "\n\nUsage: $0 [-h] -r <ref_file> -s src_file -t <tst_file>\n\n".
82 | "Description: This Perl script evaluates MT system performance.\n".
83 | "\n".
84 | "Required arguments:\n".
85 | " -r <ref_file> is a file containing the reference translations for\n".
86 | " the documents to be evaluated.\n".
87 | " -s <src_file> is a file containing the source documents for which\n".
88 | " translations are to be evaluated\n".
89 | " -t <tst_file> is a file containing the translations to be evaluated\n".
90 | "\n".
91 | "Optional arguments:\n".
92 | " -c preserves upper-case alphabetic characters\n".
93 | " -b generate BLEU scores only\n".
94 | " -n generate NIST scores only\n".
95 | " -d detailed output flag used in conjunction with \"-b\" or \"-n\" flags:\n".
96 | " 0 (default) for system-level score only\n".
97 | " 1 to include document-level scores\n".
98 | " 2 to include segment-level scores\n".
99 | " 3 to include ngram-level scores\n".
100 | " -h prints this help message to STDOUT\n".
101 | "\n";
102 |
103 | use vars qw ($opt_r $opt_s $opt_t $opt_d $opt_h $opt_b $opt_n $opt_c $opt_x);
104 | use Getopt::Std;
105 | getopts ('r:s:t:d:hbncx:');
106 | die $usage if defined($opt_h);
107 | die "Error in command line: ref_file not defined$usage" unless defined $opt_r;
108 | die "Error in command line: src_file not defined$usage" unless defined $opt_s;
109 | die "Error in command line: tst_file not defined$usage" unless defined $opt_t;
110 | my $max_Ngram = 9;
111 | my $detail = defined $opt_d ? $opt_d : 0;
112 | my $preserve_case = defined $opt_c ? 1 : 0;
113 |
114 | my $METHOD = "BOTH";
115 | if (defined $opt_b) { $METHOD = "BLEU"; }
116 | if (defined $opt_n) { $METHOD = "NIST"; }
117 | my $method;
118 |
119 | my ($ref_file) = $opt_r;
120 | my ($src_file) = $opt_s;
121 | my ($tst_file) = $opt_t;
122 |
123 | ######
124 | # Global variables
125 | my ($src_lang, $tgt_lang, @tst_sys, @ref_sys); # evaluation parameters
126 | my (%tst_data, %ref_data); # the data -- with structure: {system}{document}[segments]
127 | my ($src_id, $ref_id, $tst_id); # unique identifiers for ref and tst translation sets
128 | my %eval_docs; # document information for the evaluation data set
129 | my %ngram_info; # the information obtained from (the last word in) the ngram
130 |
131 | ######
132 | # Get source document ID's
133 | ($src_id) = get_source_info ($src_file);
134 |
135 | ######
136 | # Get reference translations
137 | ($ref_id) = get_MT_data (\%ref_data, "RefSet", $ref_file);
138 |
139 | compute_ngram_info ();
140 |
141 | ######
142 | # Get translations to evaluate
143 | ($tst_id) = get_MT_data (\%tst_data, "TstSet", $tst_file);
144 |
145 | ######
146 | # Check data for completeness and correctness
147 | check_MT_data ();
148 |
149 | ######
150 | #
151 | my %NISTmt = ();
152 | my %BLEUmt = ();
153 |
154 | ######
155 | # Evaluate
156 | print " Evaluation of $src_lang-to-$tgt_lang translation using:\n";
157 | my $cum_seg = 0;
158 | foreach my $doc (sort keys %eval_docs) {
159 | $cum_seg += @{$eval_docs{$doc}{SEGS}};
160 | }
161 | print " src set \"$src_id\" (", scalar keys %eval_docs, " docs, $cum_seg segs)\n";
162 | print " ref set \"$ref_id\" (", scalar keys %ref_data, " refs)\n";
163 | print " tst set \"$tst_id\" (", scalar keys %tst_data, " systems)\n\n";
164 |
165 | foreach my $sys (sort @tst_sys) {
166 | for (my $n=1; $n<=$max_Ngram; $n++) {
167 | $NISTmt{$n}{$sys}{cum} = 0;
168 | $NISTmt{$n}{$sys}{ind} = 0;
169 | $BLEUmt{$n}{$sys}{cum} = 0;
170 | $BLEUmt{$n}{$sys}{ind} = 0;
171 | }
172 |
173 | if (($METHOD eq "BOTH") || ($METHOD eq "NIST")) {
174 | $method="NIST";
175 | score_system ($sys, %NISTmt);
176 | }
177 | if (($METHOD eq "BOTH") || ($METHOD eq "BLEU")) {
178 | $method="BLEU";
179 | score_system ($sys, %BLEUmt);
180 | }
181 | }
182 |
183 | ######
184 | printout_report ();
185 |
186 | ($date, $time) = date_time_stamp();
187 | print "MT evaluation scorer ended on $date at $time\n";
188 |
189 | exit 0;
190 |
191 | #################################
192 |
193 | sub get_source_info {
194 |
195 | my ($file) = @_;
196 | my ($name, $id, $src, $doc);
197 | my ($data, $tag, $span);
198 |
199 |
200 | #read data from file
201 | open (FILE, $file) or die "\nUnable to open translation data file '$file'", $usage;
202 | binmode FILE;
203 | $data .= $_ while <FILE>;
204 | close (FILE);
205 |
206 | #get source set info
207 | die "\n\nFATAL INPUT ERROR: no 'src_set' tag in src_file '$file'\n\n"
208 | unless ($tag, $span, $data) = extract_sgml_tag_and_span ("SrcSet", $data);
209 |
210 | die "\n\nFATAL INPUT ERROR: no tag attribute '$name' in file '$file'\n\n"
211 | unless ($id) = extract_sgml_tag_attribute ($name="SetID", $tag);
212 |
213 | die "\n\nFATAL INPUT ERROR: no tag attribute '$name' in file '$file'\n\n"
214 | unless ($src) = extract_sgml_tag_attribute ($name="SrcLang", $tag);
215 | die "\n\nFATAL INPUT ERROR: $name ('$src') in file '$file' inconsistent\n"
216 | ." with $name in previous input data ('$src_lang')\n\n"
217 | unless (not defined $src_lang or $src eq $src_lang);
218 | $src_lang = $src;
219 |
220 | #get doc info -- ID and # of segs
221 | $data = $span;
222 | while (($tag, $span, $data) = extract_sgml_tag_and_span ("Doc", $data)) {
223 | die "\n\nFATAL INPUT ERROR: no tag attribute '$name' in file '$file'\n\n"
224 | unless ($doc) = extract_sgml_tag_attribute ($name="DocID", $tag);
225 | die "\n\nFATAL INPUT ERROR: duplicate '$name' in file '$file'\n\n"
226 | if defined $eval_docs{$doc};
227 | $span =~ s/[\s\n\r]+/ /g; # concatenate records
228 | my $jseg=0, my $seg_data = $span;
229 | while (($tag, $span, $seg_data) = extract_sgml_tag_and_span ("Seg", $seg_data)) {
230 | ($eval_docs{$doc}{SEGS}[$jseg++]) = NormalizeText ($span);
231 | }
232 | die "\n\nFATAL INPUT ERROR: no segments in document '$doc' in file '$file'\n\n"
233 | if $jseg == 0;
234 | }
235 | die "\n\nFATAL INPUT ERROR: no documents in file '$file'\n\n"
236 | unless keys %eval_docs > 0;
237 | return $id;
238 | }
239 |
240 | #################################
241 |
242 | sub get_MT_data {
243 |
244 | my ($docs, $set_tag, $file) = @_;
245 | my ($name, $id, $src, $tgt, $sys, $doc);
246 | my ($tag, $span, $data);
247 |
248 | #read data from file
249 | open (FILE, $file) or die "\nUnable to open translation data file '$file'", $usage;
250 | binmode FILE;
251 | $data .= $_ while <FILE>;
252 | close (FILE);
253 |
254 | #get tag info
255 | while (($tag, $span, $data) = extract_sgml_tag_and_span ($set_tag, $data)) {
256 | die "\n\nFATAL INPUT ERROR: no tag attribute '$name' in file '$file'\n\n" unless
257 | ($id) = extract_sgml_tag_attribute ($name="SetID", $tag);
258 |
259 | die "\n\nFATAL INPUT ERROR: no tag attribute '$name' in file '$file'\n\n" unless
260 | ($src) = extract_sgml_tag_attribute ($name="SrcLang", $tag);
261 | die "\n\nFATAL INPUT ERROR: $name ('$src') in file '$file' inconsistent\n"
262 | ." with $name of source ('$src_lang')\n\n"
263 | unless $src eq $src_lang;
264 |
265 | die "\n\nFATAL INPUT ERROR: no tag attribute '$name' in file '$file'\n\n" unless
266 | ($tgt) = extract_sgml_tag_attribute ($name="TrgLang", $tag);
267 | die "\n\nFATAL INPUT ERROR: $name ('$tgt') in file '$file' inconsistent\n"
268 | ." with $name of the evaluation ('$tgt_lang')\n\n"
269 | unless (not defined $tgt_lang or $tgt eq $tgt_lang);
270 | $tgt_lang = $tgt;
271 |
272 | my $mtdata = $span;
273 | while (($tag, $span, $mtdata) = extract_sgml_tag_and_span ("Doc", $mtdata)) {
274 | die "\n\nFATAL INPUT ERROR: no tag attribute '$name' in file '$file'\n\n" unless
275 | (my $sys) = extract_sgml_tag_attribute ($name="SysID", $tag);
276 |
277 | die "\n\nFATAL INPUT ERROR: no tag attribute '$name' in file '$file'\n\n" unless
278 | $doc = extract_sgml_tag_attribute ($name="DocID", $tag);
279 |
280 | die "\n\nFATAL INPUT ERROR: document '$doc' for system '$sys' in file '$file'\n"
281 | ." previously loaded from file '$docs->{$sys}{$doc}{FILE}'\n\n"
282 | unless (not defined $docs->{$sys}{$doc});
283 |
284 | $span =~ s/[\s\n\r]+/ /g; # concatenate records
285 | my $jseg=0, my $seg_data = $span;
286 | while (($tag, $span, $seg_data) = extract_sgml_tag_and_span ("Seg", $seg_data)) {
287 | ($docs->{$sys}{$doc}{SEGS}[$jseg++]) = NormalizeText ($span);
288 | }
289 | die "\n\nFATAL INPUT ERROR: no segments in document '$doc' in file '$file'\n\n"
290 | if $jseg == 0;
291 | $docs->{$sys}{$doc}{FILE} = $file;
292 | }
293 | }
294 | return $id;
295 | }
296 |
297 | #################################
298 |
299 | sub check_MT_data {
300 |
301 | @tst_sys = sort keys %tst_data;
302 | @ref_sys = sort keys %ref_data;
303 |
304 | #every evaluation document must be represented for every system and every reference
305 | foreach my $doc (sort keys %eval_docs) {
306 | my $nseg_source = @{$eval_docs{$doc}{SEGS}};
307 | foreach my $sys (@tst_sys) {
308 | die "\n\nFATAL ERROR: no document '$doc' for system '$sys'\n\n"
309 | unless defined $tst_data{$sys}{$doc};
310 | my $nseg = @{$tst_data{$sys}{$doc}{SEGS}};
311 | die "\n\nFATAL ERROR: translated documents must contain the same # of segments as the source, but\n"
312 | ." document '$doc' for system '$sys' contains $nseg segments, while\n"
313 | ." the source document contains $nseg_source segments.\n\n"
314 | unless $nseg == $nseg_source;
315 | }
316 |
317 | foreach my $sys (@ref_sys) {
318 | die "\n\nFATAL ERROR: no document '$doc' for reference '$sys'\n\n"
319 | unless defined $ref_data{$sys}{$doc};
320 | my $nseg = @{$ref_data{$sys}{$doc}{SEGS}};
321 | die "\n\nFATAL ERROR: translated documents must contain the same # of segments as the source, but\n"
322 | ." document '$doc' for system '$sys' contains $nseg segments, while\n"
323 | ." the source document contains $nseg_source segments.\n\n"
324 | unless $nseg == $nseg_source;
325 | }
326 | }
327 | }
328 |
329 | #################################
330 |
331 | sub compute_ngram_info {
332 |
333 | my ($ref, $doc, $seg);
334 | my (@wrds, $tot_wrds, %ngrams, $ngram, $mgram);
335 | my (%ngram_count, @tot_ngrams);
336 |
337 | foreach $ref (keys %ref_data) {
338 | foreach $doc (keys %{$ref_data{$ref}}) {
339 | foreach $seg (@{$ref_data{$ref}{$doc}{SEGS}}) {
340 | @wrds = split /\s+/, $seg;
341 | $tot_wrds += @wrds;
342 | %ngrams = %{Words2Ngrams (@wrds)};
343 | foreach $ngram (keys %ngrams) {
344 | $ngram_count{$ngram} += $ngrams{$ngram};
345 | }
346 | }
347 | }
348 | }
349 |
350 | foreach $ngram (keys %ngram_count) {
351 | @wrds = split / /, $ngram;
352 | pop @wrds, $mgram = join " ", @wrds;
353 | $ngram_info{$ngram} = - log
354 | ($mgram ? $ngram_count{$ngram}/$ngram_count{$mgram}
355 | : $ngram_count{$ngram}/$tot_wrds) / log 2;
356 | if (defined $opt_x and $opt_x eq "ngram info") {
357 | @wrds = split / /, $ngram;
358 | printf "ngram info:%9.4f%6d%6d%8d%3d %s\n", $ngram_info{$ngram}, $ngram_count{$ngram},
359 | $mgram ? $ngram_count{$mgram} : $tot_wrds, $tot_wrds, scalar @wrds, $ngram;
360 | }
361 | }
362 | }
363 |
364 | #################################
365 |
366 | sub score_system {
367 |
368 | my ($sys, $ref, $doc, %SCOREmt);
369 | ($sys, %SCOREmt) = @_;
370 | my ($shortest_ref_length, $match_cnt, $tst_cnt, $ref_cnt, $tst_info, $ref_info);
371 | my ($cum_ref_length, @cum_match, @cum_tst_cnt, @cum_ref_cnt, @cum_tst_info, @cum_ref_info);
372 |
373 | $cum_ref_length = 0;
374 | for (my $j=1; $j<=$max_Ngram; $j++) {
375 | $cum_match[$j] = $cum_tst_cnt[$j] = $cum_ref_cnt[$j] = $cum_tst_info[$j] = $cum_ref_info[$j] = 0;
376 | }
377 |
378 | foreach $doc (sort keys %eval_docs) {
379 | ($shortest_ref_length, $match_cnt, $tst_cnt, $ref_cnt, $tst_info, $ref_info) = score_document ($sys, $doc);
380 |
381 | #output document summary score
382 | if (($detail >= 1 ) && ($METHOD eq "NIST")) {
383 | my %DOCmt = ();
384 | printf "$method score using 5-grams = %.4f for system \"$sys\" on document \"$doc\" (%d segments, %d words)\n",
385 | nist_score (scalar @ref_sys, $match_cnt, $tst_cnt, $ref_cnt, $tst_info, $ref_info, $sys, %DOCmt),
386 | scalar @{$tst_data{$sys}{$doc}{SEGS}}, $tst_cnt->[1];
387 | }
388 | if (($detail >= 1 ) && ($METHOD eq "BLEU")) {
389 | my %DOCmt = ();
390 | printf "$method score using 4-grams = %.4f for system \"$sys\" on document \"$doc\" (%d segments, %d words)\n",
391 | bleu_score($shortest_ref_length, $match_cnt, $tst_cnt, $sys, %DOCmt),
392 | scalar @{$tst_data{$sys}{$doc}{SEGS}}, $tst_cnt->[1];
393 | }
394 |
395 | $cum_ref_length += $shortest_ref_length;
396 | for (my $j=1; $j<=$max_Ngram; $j++) {
397 | $cum_match[$j] += $match_cnt->[$j];
398 | $cum_tst_cnt[$j] += $tst_cnt->[$j];
399 | $cum_ref_cnt[$j] += $ref_cnt->[$j];
400 | $cum_tst_info[$j] += $tst_info->[$j];
401 | $cum_ref_info[$j] += $ref_info->[$j];
402 | printf "document info: $sys $doc %d-gram %d %d %d %9.4f %9.4f\n", $j, $match_cnt->[$j],
403 | $tst_cnt->[$j], $ref_cnt->[$j], $tst_info->[$j], $ref_info->[$j]
404 | if (defined $opt_x and $opt_x eq "document info");
405 | }
406 | }
407 |
408 | #x #output system summary score
409 | #x printf "$method score = %.4f for system \"$sys\"\n",
410 | #x $method eq "BLEU" ? bleu_score($cum_ref_length, \@cum_match, \@cum_tst_cnt) :
411 | #x nist_score (scalar @ref_sys, \@cum_match, \@cum_tst_cnt, \@cum_ref_cnt, \@cum_tst_info, \@cum_ref_info, $sys, %SCOREmt);
412 | if ($method eq "BLEU") {
413 | bleu_score($cum_ref_length, \@cum_match, \@cum_tst_cnt, $sys, %SCOREmt);
414 | }
415 | if ($method eq "NIST") {
416 | nist_score (scalar @ref_sys, \@cum_match, \@cum_tst_cnt, \@cum_ref_cnt, \@cum_tst_info, \@cum_ref_info, $sys, %SCOREmt);
417 | }
418 | }
419 |
420 | #################################
421 |
422 | sub score_document {
423 |
424 | my ($sys, $ref, $doc);
425 | ($sys, $doc) = @_;
426 | my ($shortest_ref_length, $match_cnt, $tst_cnt, $ref_cnt, $tst_info, $ref_info);
427 | my ($cum_ref_length, @cum_match, @cum_tst_cnt, @cum_ref_cnt, @cum_tst_info, @cum_ref_info);
428 |
429 | $cum_ref_length = 0;
430 | for (my $j=1; $j<=$max_Ngram; $j++) {
431 | $cum_match[$j] = $cum_tst_cnt[$j] = $cum_ref_cnt[$j] = $cum_tst_info[$j] = $cum_ref_info[$j] = 0;
432 | }
433 |
434 | #score each segment
435 | for (my $jseg=0; $jseg<@{$tst_data{$sys}{$doc}{SEGS}}; $jseg++) {
436 | my @ref_segments = ();
437 | foreach $ref (@ref_sys) {
438 | push @ref_segments, $ref_data{$ref}{$doc}{SEGS}[$jseg];
439 | printf "ref '$ref', seg %d: %s\n", $jseg+1, $ref_data{$ref}{$doc}{SEGS}[$jseg]
440 | if $detail >= 3;
441 | }
442 | printf "sys '$sys', seg %d: %s\n", $jseg+1, $tst_data{$sys}{$doc}{SEGS}[$jseg]
443 | if $detail >= 3;
444 | ($shortest_ref_length, $match_cnt, $tst_cnt, $ref_cnt, $tst_info, $ref_info) =
445 | score_segment ($tst_data{$sys}{$doc}{SEGS}[$jseg], @ref_segments);
446 |
447 | #output segment summary score
448 | #x printf "$method score = %.4f for system \"$sys\" on segment %d of document \"$doc\" (%d words)\n",
449 | #x $method eq "BLEU" ? bleu_score($shortest_ref_length, $match_cnt, $tst_cnt) :
450 | #x nist_score (scalar @ref_sys, $match_cnt, $tst_cnt, $ref_cnt, $tst_info, $ref_info),
451 | #x $jseg+1, $tst_cnt->[1]
452 | #x if $detail >= 2;
453 | if (($detail >=2) && ($METHOD eq "BLEU")) {
454 | my %DOCmt = ();
455 | printf " $method score using 4-grams = %.4f for system \"$sys\" on segment %d of document \"$doc\" (%d words)\n",
456 | bleu_score($shortest_ref_length, $match_cnt, $tst_cnt, $sys, %DOCmt), $jseg+1, $tst_cnt->[1];
457 | }
458 | if (($detail >=2) && ($METHOD eq "NIST")) {
459 | my %DOCmt = ();
460 | printf " $method score using 5-grams = %.4f for system \"$sys\" on segment %d of document \"$doc\" (%d words)\n",
461 | nist_score (scalar @ref_sys, $match_cnt, $tst_cnt, $ref_cnt, $tst_info, $ref_info, $sys, %DOCmt), $jseg+1, $tst_cnt->[1];
462 | }
463 |
464 |
465 | $cum_ref_length += $shortest_ref_length;
466 | for (my $j=1; $j<=$max_Ngram; $j++) {
467 | $cum_match[$j] += $match_cnt->[$j];
468 | $cum_tst_cnt[$j] += $tst_cnt->[$j];
469 | $cum_ref_cnt[$j] += $ref_cnt->[$j];
470 | $cum_tst_info[$j] += $tst_info->[$j];
471 | $cum_ref_info[$j] += $ref_info->[$j];
472 | }
473 | }
474 | return ($cum_ref_length, [@cum_match], [@cum_tst_cnt], [@cum_ref_cnt], [@cum_tst_info], [@cum_ref_info]);
475 | }
476 |
477 | #################################
478 |
479 | sub score_segment {
480 |
481 | my ($tst_seg, @ref_segs) = @_;
482 | my (@tst_wrds, %tst_ngrams, @match_count, @tst_count, @tst_info);
483 | my (@ref_wrds, $ref_seg, %ref_ngrams, %ref_ngrams_max, @ref_count, @ref_info);
484 | my ($ngram);
485 | my (@nwrds_ref);
486 | my $shortest_ref_length;
487 |
488 | for (my $j=1; $j<= $max_Ngram; $j++) {
489 | $match_count[$j] = $tst_count[$j] = $ref_count[$j] = $tst_info[$j] = $ref_info[$j] = 0;
490 | }
491 |
492 | # get the ngram counts for the test segment
493 | @tst_wrds = split /\s+/, $tst_seg;
494 | %tst_ngrams = %{Words2Ngrams (@tst_wrds)};
495 | for (my $j=1; $j<=$max_Ngram; $j++) { # compute ngram counts
496 | $tst_count[$j] = $j<=@tst_wrds ? (@tst_wrds - $j + 1) : 0;
497 | }
498 |
499 | # get the ngram counts for the reference segments
500 | foreach $ref_seg (@ref_segs) {
501 | @ref_wrds = split /\s+/, $ref_seg;
502 | %ref_ngrams = %{Words2Ngrams (@ref_wrds)};
503 | foreach $ngram (keys %ref_ngrams) { # find the maximum # of occurrences
504 | my @wrds = split / /, $ngram;
505 | $ref_info[@wrds] += $ngram_info{$ngram};
506 | $ref_ngrams_max{$ngram} = defined $ref_ngrams_max{$ngram} ?
507 | max ($ref_ngrams_max{$ngram}, $ref_ngrams{$ngram}) :
508 | $ref_ngrams{$ngram};
509 | }
510 | for (my $j=1; $j<=$max_Ngram; $j++) { # update ngram counts
511 | $ref_count[$j] += $j<=@ref_wrds ? (@ref_wrds - $j + 1) : 0;
512 | }
513 | $shortest_ref_length = scalar @ref_wrds # find the shortest reference segment
514 | if (not defined $shortest_ref_length) or @ref_wrds < $shortest_ref_length;
515 | }
516 |
517 | # accumulate scoring stats for tst_seg ngrams that match ref_seg ngrams
518 | foreach $ngram (keys %tst_ngrams) {
519 | next unless defined $ref_ngrams_max{$ngram};
520 | my @wrds = split / /, $ngram;
521 | $tst_info[@wrds] += $ngram_info{$ngram} * min($tst_ngrams{$ngram},$ref_ngrams_max{$ngram});
522 | $match_count[@wrds] += my $count = min($tst_ngrams{$ngram},$ref_ngrams_max{$ngram});
523 | printf "%.2f info for each of $count %d-grams = '%s'\n", $ngram_info{$ngram}, scalar @wrds, $ngram
524 | if $detail >= 3;
525 | }
526 |
527 | return ($shortest_ref_length, [@match_count], [@tst_count], [@ref_count], [@tst_info], [@ref_info]);
528 | }
529 |
530 | #################################
531 |
532 | sub bleu_score {
533 |
534 | my ($shortest_ref_length, $matching_ngrams, $tst_ngrams, $sys, %SCOREmt) = @_;
535 |
536 | my $score = 0;
537 | my $iscore = 0;
538 | my $len_score = min (0, 1-$shortest_ref_length/$tst_ngrams->[1]);
539 |
540 | for (my $j=1; $j<=$max_Ngram; $j++) {
541 | if ($matching_ngrams->[$j] == 0) {
542 | $SCOREmt{$j}{$sys}{cum}=0;
543 | } else {
544 | # Cumulative N-Gram score
545 | $score += log ($matching_ngrams->[$j]/$tst_ngrams->[$j]);
546 | $SCOREmt{$j}{$sys}{cum} = exp($score/$j + $len_score);
547 | # Individual N-Gram score
548 | $iscore = log ($matching_ngrams->[$j]/$tst_ngrams->[$j]);
549 | $SCOREmt{$j}{$sys}{ind} = exp($iscore);
550 | }
551 | }
552 | return $SCOREmt{4}{$sys}{cum};
553 | }
554 |
555 | #################################
556 |
557 | sub nist_score {
558 |
559 | my ($nsys, $matching_ngrams, $tst_ngrams, $ref_ngrams, $tst_info, $ref_info, $sys, %SCOREmt) = @_;
560 |
561 | my $score = 0;
562 | my $iscore = 0;
563 |
564 |
565 | for (my $n=1; $n<=$max_Ngram; $n++) {
566 | $score += $tst_info->[$n]/max($tst_ngrams->[$n],1);
567 | $SCOREmt{$n}{$sys}{cum} = $score * nist_length_penalty($tst_ngrams->[1]/($ref_ngrams->[1]/$nsys));
568 |
569 | $iscore = $tst_info->[$n]/max($tst_ngrams->[$n],1);
570 | $SCOREmt{$n}{$sys}{ind} = $iscore * nist_length_penalty($tst_ngrams->[1]/($ref_ngrams->[1]/$nsys));
571 | }
572 | return $SCOREmt{5}{$sys}{cum};
573 | }
574 |
575 | #################################
576 |
577 | sub Words2Ngrams { #convert a string of words to an Ngram count hash
578 |
579 | my %count = ();
580 |
581 | for (; @_; shift) {
582 | my ($j, $ngram, $word);
583 | for ($j=0; $j<$max_Ngram and defined($word=$_[$j]); $j++) {
584 | $ngram .= defined $ngram ? " $word" : $word;
585 | $count{$ngram}++;
586 | }
587 | }
588 | return {%count};
589 | }
590 |
591 | #################################
592 |
593 | sub NormalizeText {
594 | my ($norm_text) = @_;
595 |
596 | # language-independent part:
597 | $norm_text =~ s/<skipped>//g; # strip "skipped" tags
598 | $norm_text =~ s/-\n//g; # strip end-of-line hyphenation and join lines
599 | $norm_text =~ s/\n/ /g; # join lines
600 | $norm_text =~ s/&quot;/"/g; # convert SGML tag for quote to "
601 | $norm_text =~ s/&amp;/&/g; # convert SGML tag for ampersand to &
602 | $norm_text =~ s/&lt;/</g; # convert SGML tag for less-than to <
603 | $norm_text =~ s/&gt;/>/g; # convert SGML tag for greater-than to >
604 |
605 | # language-dependent part (assuming Western languages):
606 | $norm_text = " $norm_text ";
607 | $norm_text =~ tr/[A-Z]/[a-z]/ unless $preserve_case;
608 | $norm_text =~ s/([\{-\~\[-\` -\&\(-\+\:-\@\/])/ $1 /g; # tokenize punctuation
609 | $norm_text =~ s/([^0-9])([\.,])/$1 $2 /g; # tokenize period and comma unless preceded by a digit
610 | $norm_text =~ s/([\.,])([^0-9])/ $1 $2/g; # tokenize period and comma unless followed by a digit
611 | $norm_text =~ s/([0-9])(-)/$1 $2 /g; # tokenize dash when preceded by a digit
612 | $norm_text =~ s/\s+/ /g; # one space only between words
613 | $norm_text =~ s/^\s+//; # no leading space
614 | $norm_text =~ s/\s+$//; # no trailing space
615 |
616 | return $norm_text;
617 | }
618 |
619 | #################################
620 |
621 | sub nist_length_penalty {
622 |
623 | my ($ratio) = @_;
624 | return 1 if $ratio >= 1;
625 | return 0 if $ratio <= 0;
626 | my $ratio_x = 1.5;
627 | my $score_x = 0.5;
628 | my $beta = -log($score_x)/log($ratio_x)/log($ratio_x);
629 | return exp (-$beta*log($ratio)*log($ratio));
630 | }
631 |
632 | #################################
633 |
634 | sub date_time_stamp {
635 |
636 | my ($sec, $min, $hour, $mday, $mon, $year, $wday, $yday, $isdst) = localtime();
637 | my @months = qw(Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec);
638 | my ($date, $time);
639 |
640 | $time = sprintf "%2.2d:%2.2d:%2.2d", $hour, $min, $sec;
641 | $date = sprintf "%4.4s %3.3s %s", 1900+$year, $months[$mon], $mday;
642 | return ($date, $time);
643 | }
644 |
645 | #################################
646 |
647 | sub extract_sgml_tag_and_span {
648 |
649 | my ($name, $data) = @_;
650 |
651 | ($data =~ m|<$name\s*([^>]*)>(.*?)</$name\s*>(.*)|si) ? ($1, $2, $3) : ();
652 | }
653 |
654 | #################################
655 |
656 | sub extract_sgml_tag_attribute {
657 |
658 | my ($name, $data) = @_;
659 |
660 | ($data =~ m|$name\s*=\s*\"([^\"]*)\"|si) ? ($1) : ();
661 | }
662 |
663 | #################################
664 |
665 | sub max {
666 |
667 | my ($max, $next);
668 |
669 | return unless defined ($max=pop);
670 | while (defined ($next=pop)) {
671 | $max = $next if $next > $max;
672 | }
673 | return $max;
674 | }
675 |
676 | #################################
677 |
678 | sub min {
679 |
680 | my ($min, $next);
681 |
682 | return unless defined ($min=pop);
683 | while (defined ($next=pop)) {
684 | $min = $next if $next < $min;
685 | }
686 | return $min;
687 | }
688 |
689 | #################################
690 |
691 | sub printout_report
692 | {
693 |
694 | if ( $METHOD eq "BOTH" ) {
695 | foreach my $sys (sort @tst_sys) {
696 | printf "NIST score = %2.4f BLEU score = %.4f for system \"$sys\"\n",$NISTmt{5}{$sys}{cum},$BLEUmt{4}{$sys}{cum};
697 | }
698 | } elsif ($METHOD eq "NIST" ) {
699 | foreach my $sys (sort @tst_sys) {
700 | printf "NIST score = %2.4f for system \"$sys\"\n",$NISTmt{5}{$sys}{cum};
701 | }
702 | } elsif ($METHOD eq "BLEU" ) {
703 | foreach my $sys (sort @tst_sys) {
704 | printf "\nBLEU score = %.4f for system \"$sys\"\n",$BLEUmt{4}{$sys}{cum};
705 | }
706 | }
707 |
708 |
709 | printf "\n# ------------------------------------------------------------------------\n\n";
710 | printf "Individual N-gram scoring\n";
711 | printf " 1-gram 2-gram 3-gram 4-gram 5-gram 6-gram 7-gram 8-gram 9-gram\n";
712 | printf " ------ ------ ------ ------ ------ ------ ------ ------ ------\n";
713 |
714 | if (( $METHOD eq "BOTH" ) || ($METHOD eq "NIST")) {
715 | foreach my $sys (sort @tst_sys) {
716 | printf " NIST:";
717 | for (my $i=1; $i<=$max_Ngram; $i++) {
718 | printf " %2.4f ",$NISTmt{$i}{$sys}{ind}
719 | }
720 | printf " \"$sys\"\n";
721 | }
722 | printf "\n";
723 | }
724 |
725 | if (( $METHOD eq "BOTH" ) || ($METHOD eq "BLEU")) {
726 | foreach my $sys (sort @tst_sys) {
727 | printf " BLEU:";
728 | for (my $i=1; $i<=$max_Ngram; $i++) {
729 | printf " %2.4f ",$BLEUmt{$i}{$sys}{ind}
730 | }
731 | printf " \"$sys\"\n";
732 | }
733 | }
734 |
735 | printf "\n# ------------------------------------------------------------------------\n";
736 | printf "Cumulative N-gram scoring\n";
737 | printf " 1-gram 2-gram 3-gram 4-gram 5-gram 6-gram 7-gram 8-gram 9-gram\n";
738 | printf " ------ ------ ------ ------ ------ ------ ------ ------ ------\n";
739 |
740 | if (( $METHOD eq "BOTH" ) || ($METHOD eq "NIST")) {
741 | foreach my $sys (sort @tst_sys) {
742 | printf " NIST:";
743 | for (my $i=1; $i<=$max_Ngram; $i++) {
744 | printf " %2.4f ",$NISTmt{$i}{$sys}{cum}
745 | }
746 | printf " \"$sys\"\n";
747 | }
748 | }
749 | printf "\n";
750 |
751 |
752 | if (( $METHOD eq "BOTH" ) || ($METHOD eq "BLEU")) {
753 | foreach my $sys (sort @tst_sys) {
754 | printf " BLEU:";
755 | for (my $i=1; $i<=$max_Ngram; $i++) {
756 | printf " %2.4f ",$BLEUmt{$i}{$sys}{cum}
757 | }
758 | printf " \"$sys\"\n";
759 | }
760 | }
761 | }
762 |
--------------------------------------------------------------------------------
/code/eval/multi-bleu.perl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env perl
2 | #
3 | # This file is part of moses. Its use is licensed under the GNU Lesser General
4 | # Public License version 2.1 or, at your option, any later version.
5 |
6 | # $Id$
7 | use warnings;
8 | use strict;
9 |
10 | my $lowercase = 0;
11 | if ($ARGV[0] eq "-lc") {
12 | $lowercase = 1;
13 | shift;
14 | }
15 |
16 | my $stem = $ARGV[0];
17 | if (!defined $stem) {
18 | print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n";
19 | print STDERR "Reads the references from reference or reference0, reference1, ...\n";
20 | exit(1);
21 | }
22 |
23 | $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0";
24 |
25 | my @REF;
26 | my $ref=0;
27 | while(-e "$stem$ref") {
28 | &add_to_ref("$stem$ref",\@REF);
29 | $ref++;
30 | }
31 | &add_to_ref($stem,\@REF) if -e $stem;
32 | die("ERROR: could not find reference file $stem") unless scalar @REF;
33 |
34 | # add additional references explicitly specified on the command line
35 | shift;
36 | foreach my $stem (@ARGV) {
37 | &add_to_ref($stem,\@REF) if -e $stem;
38 | }
39 |
40 |
41 |
42 | sub add_to_ref {
43 | my ($file,$REF) = @_;
44 | my $s=0;
45 | if ($file =~ /.gz$/) {
46 | open(REF,"gzip -dc $file|") or die "Can't read $file";
47 | } else {
48 | open(REF,$file) or die "Can't read $file";
49 | }
50 | while(<REF>) {
51 | chop;
52 | push @{$$REF[$s++]}, $_;
53 | }
54 | close(REF);
55 | }
56 |
57 | my(@CORRECT,@TOTAL,$length_translation,$length_reference);
58 | my $s=0;
59 | while(<STDIN>) {
60 | chop;
61 | $_ = lc if $lowercase;
62 | my @WORD = split;
63 | my %REF_NGRAM = ();
64 | my $length_translation_this_sentence = scalar(@WORD);
65 | my ($closest_diff,$closest_length) = (9999,9999);
66 | foreach my $reference (@{$REF[$s]}) {
67 | # print "$s $_ <=> $reference\n";
68 | $reference = lc($reference) if $lowercase;
69 | my @WORD = split(' ',$reference);
70 | my $length = scalar(@WORD);
71 | my $diff = abs($length_translation_this_sentence-$length);
72 | if ($diff < $closest_diff) {
73 | $closest_diff = $diff;
74 | $closest_length = $length;
75 | # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n";
76 | } elsif ($diff == $closest_diff) {
77 | $closest_length = $length if $length < $closest_length;
78 | # from two references with the same closeness to me
79 | # take the *shorter* into account, not the "first" one.
80 | }
81 | for(my $n=1;$n<=4;$n++) {
82 | my %REF_NGRAM_N = ();
83 | for(my $start=0;$start<=$#WORD-($n-1);$start++) {
84 | my $ngram = "$n";
85 | for(my $w=0;$w<$n;$w++) {
86 | $ngram .= " ".$WORD[$start+$w];
87 | }
88 | $REF_NGRAM_N{$ngram}++;
89 | }
90 | foreach my $ngram (keys %REF_NGRAM_N) {
91 | if (!defined($REF_NGRAM{$ngram}) ||
92 | $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) {
93 | $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram};
94 | # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}<BR>\n";
95 | }
96 | }
97 | }
98 | }
99 | $length_translation += $length_translation_this_sentence;
100 | $length_reference += $closest_length;
101 | for(my $n=1;$n<=4;$n++) {
102 | my %T_NGRAM = ();
103 | for(my $start=0;$start<=$#WORD-($n-1);$start++) {
104 | my $ngram = "$n";
105 | for(my $w=0;$w<$n;$w++) {
106 | $ngram .= " ".$WORD[$start+$w];
107 | }
108 | $T_NGRAM{$ngram}++;
109 | }
110 | foreach my $ngram (keys %T_NGRAM) {
111 | $ngram =~ /^(\d+) /;
112 | my $n = $1;
113 | # my $corr = 0;
114 | # print "$i e $ngram $T_NGRAM{$ngram}<BR>\n";
115 | $TOTAL[$n] += $T_NGRAM{$ngram};
116 | if (defined($REF_NGRAM{$ngram})) {
117 | if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) {
118 | $CORRECT[$n] += $T_NGRAM{$ngram};
119 | # $corr = $T_NGRAM{$ngram};
120 | # print "$i e correct1 $T_NGRAM{$ngram}<BR>\n";
121 | }
122 | else {
123 | $CORRECT[$n] += $REF_NGRAM{$ngram};
124 | # $corr = $REF_NGRAM{$ngram};
125 | # print "$i e correct2 $REF_NGRAM{$ngram}<BR>\n";
126 | }
127 | }
128 | # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram};
129 | # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n"
130 | }
131 | }
132 | $s++;
133 | }
134 | my $brevity_penalty = 1;
135 | my $bleu = 0;
136 |
137 | my @bleu=();
138 |
139 | for(my $n=1;$n<=4;$n++) {
140 | if (defined ($TOTAL[$n])){
141 | $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0;
142 | # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n";
143 | }else{
144 | $bleu[$n]=0;
145 | }
146 | }
147 |
148 | if ($length_reference==0){
149 | printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n";
150 | exit(1);
151 | }
152 |
153 | if ($length_translation<$length_reference) {
154 | $brevity_penalty = exp(1-$length_reference/$length_translation);
155 | }
156 | $bleu = $brevity_penalty * exp((my_log( $bleu[1] ) +
157 | my_log( $bleu[2] ) +
158 | my_log( $bleu[3] ) +
159 | my_log( $bleu[4] ) ) / 4) ;
160 | printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n",
161 | 100*$bleu,
162 | 100*$bleu[1],
163 | 100*$bleu[2],
164 | 100*$bleu[3],
165 | 100*$bleu[4],
166 | $brevity_penalty,
167 | $length_translation / $length_reference,
168 | $length_translation,
169 | $length_reference;
170 |
171 | sub my_log {
172 | return -9999999999 unless $_[0];
173 | return log($_[0]);
174 | }
175 |
--------------------------------------------------------------------------------
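
For reference, the final `BLEU = ...` line printed by the script above combines the accumulated clipped n-gram counts with a brevity penalty and a geometric mean of the four precisions. The small Python sketch below (not part of the repository) mirrors that computation; `correct` and `total` stand for the `@CORRECT` / `@TOTAL` arrays accumulated over the whole test set.

```python
import math

def combine_bleu(correct, total, hyp_len, ref_len):
    """Mirror of multi-bleu.perl's final score combination (sketch only).
    correct[n] / total[n] for n = 1..4 are the clipped n-gram matches and the
    hypothesis n-gram counts; hyp_len is assumed to be positive."""
    precisions = [float(correct[n]) / total[n] if total[n] else 0.0 for n in range(1, 5)]
    # brevity penalty: applied only when the hypothesis is shorter than the reference
    bp = 1.0 if hyp_len >= ref_len else math.exp(1.0 - float(ref_len) / hyp_len)
    # geometric mean of the four precisions; a zero precision is clamped to a very
    # large negative log, as in the script's my_log helper
    log_sum = sum(math.log(p) if p > 0 else -9999999999 for p in precisions)
    return bp * math.exp(log_sum / 4.0)
```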
/code/nmt.py:
--------------------------------------------------------------------------------
1 | '''
2 | Build a neural machine translation model with soft attention
3 | '''
4 |
5 | '''
6 | The main framework is adapted from `dl4mt`. The proposed model `CAEncoder` was developed by Biao Zhang.
7 | For any questions, feel free to contact me (zb@stu.xmu.edu.cn)
8 | '''
9 |
10 | import theano
11 | import theano.tensor as tensor
12 | from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
13 |
14 | import cPickle as pkl
15 | import ipdb
16 | import numpy
17 | import copy
18 |
19 | import os
20 | import warnings
21 | import sys
22 | import time
23 |
24 | from collections import OrderedDict
25 |
26 | from data import TextIterator
27 |
28 | profile = False
29 | numpy.random.seed(1234)
30 |
31 | from eval.evaluate_nist import eval_trans as eval_nist
32 | from eval.evaluate_moses import eval_trans as eval_moses
33 | run = os.system
34 | root_path = os.path.dirname(os.path.realpath(__file__))
35 |
36 | # push parameters to Theano shared variables
37 | def zipp(params, tparams):
38 | for kk, vv in params.iteritems():
39 | tparams[kk].set_value(vv)
40 |
41 |
42 | # pull parameters from Theano shared variables
43 | def unzip(zipped):
44 | new_params = OrderedDict()
45 | for kk, vv in zipped.iteritems():
46 | new_params[kk] = vv.get_value()
47 | return new_params
48 |
49 |
50 | # get the list of parameters: Note that tparams must be OrderedDict
51 | def itemlist(tparams):
52 | return [vv for kk, vv in tparams.iteritems()]
53 |
54 |
55 | # dropout
56 | def dropout_layer(state_before, use_noise, trng):
57 | proj = tensor.switch(
58 | use_noise,
59 | state_before * trng.binomial(state_before.shape, p=0.5, n=1,
60 | dtype=state_before.dtype),
61 | state_before * 0.5)
62 | return proj
63 |
64 |
65 | # make prefix-appended name
66 | def _p(pp, name):
67 | return '%s_%s' % (pp, name)
68 |
69 |
70 | # initialize Theano shared variables according to the initial parameters
71 | def init_tparams(params):
72 | tparams = OrderedDict()
73 | for kk, pp in params.iteritems():
74 | tparams[kk] = theano.shared(params[kk], name=kk)
75 | return tparams
76 |
77 |
78 | # load parameters
79 | def load_params(path, params):
80 | pp = numpy.load(path)
81 | for kk, vv in params.iteritems():
82 | if kk not in pp:
83 | warnings.warn('%s is not in the archive' % kk)
84 | continue
85 | if vv.shape != pp[kk].shape:
86 | warnings.warn('%s does not have the same shape' % kk)
87 | continue
88 | params[kk] = pp[kk]
89 |
90 | return params
91 |
92 | # layers: 'name': ('parameter initializer', 'feedforward')
93 | layers = {'ff': ('param_init_fflayer', 'fflayer'),
94 | 'gru': ('param_init_gru', 'gru_layer'),
95 | 'gru_cond': ('param_init_gru_cond', 'gru_cond_layer'),
96 | 'gru_cae': ('param_init_gru_cae', 'gru_cae_layer'),
97 | }
98 |
99 |
100 | def get_layer(name):
101 | fns = layers[name]
102 | return (eval(fns[0]), eval(fns[1]))
103 |
104 |
105 | # some utilities
106 | def ortho_weight(ndim):
107 | W = numpy.random.randn(ndim, ndim)
108 | u, s, v = numpy.linalg.svd(W)
109 | return u.astype('float32')
110 |
111 |
112 | def norm_weight(nin, nout=None, scale=0.01, ortho=True):
113 | if nout is None:
114 | nout = nin
115 | if nout == nin and ortho:
116 | W = ortho_weight(nin)
117 | else:
118 | W = scale * numpy.random.randn(nin, nout)
119 | return W.astype('float32')
120 |
121 |
122 | def tanh(x):
123 | return tensor.tanh(x)
124 |
125 |
126 | def linear(x):
127 | return x
128 |
129 |
130 | def concatenate(tensor_list, axis=0):
131 | """
132 | Alternative implementation of `theano.tensor.concatenate`.
133 | This function does exactly the same thing, but contrary to Theano's own
134 | implementation, the gradient is implemented on the GPU.
135 | Backpropagating through `theano.tensor.concatenate` yields slowdowns
136 | because the inverse operation (splitting) needs to be done on the CPU.
137 | This implementation does not have that problem.
138 | :usage:
139 | >>> x, y = theano.tensor.matrices('x', 'y')
140 | >>> c = concatenate([x, y], axis=1)
141 | :parameters:
142 | - tensor_list : list
143 | list of Theano tensor expressions that should be concatenated.
144 | - axis : int
145 | the tensors will be joined along this axis.
146 | :returns:
147 | - out : tensor
148 | the concatenated tensor expression.
149 | """
150 | concat_size = sum(tt.shape[axis] for tt in tensor_list)
151 |
152 | output_shape = ()
153 | for k in range(axis):
154 | output_shape += (tensor_list[0].shape[k],)
155 | output_shape += (concat_size,)
156 | for k in range(axis + 1, tensor_list[0].ndim):
157 | output_shape += (tensor_list[0].shape[k],)
158 |
159 | out = tensor.zeros(output_shape)
160 | offset = 0
161 | for tt in tensor_list:
162 | indices = ()
163 | for k in range(axis):
164 | indices += (slice(None),)
165 | indices += (slice(offset, offset + tt.shape[axis]),)
166 | for k in range(axis + 1, tensor_list[0].ndim):
167 | indices += (slice(None),)
168 |
169 | out = tensor.set_subtensor(out[indices], tt)
170 | offset += tt.shape[axis]
171 |
172 | return out
173 |
174 |
175 | # batch preparation
176 | def prepare_data(seqs_x, seqs_y, maxlen=None, n_words_src=30000,
177 | n_words=30000):
178 | # x: a list of sentences
179 | lengths_x = [len(s) for s in seqs_x]
180 | lengths_y = [len(s) for s in seqs_y]
181 |
182 | if maxlen is not None:
183 | new_seqs_x = []
184 | new_seqs_y = []
185 | new_lengths_x = []
186 | new_lengths_y = []
187 | for l_x, s_x, l_y, s_y in zip(lengths_x, seqs_x, lengths_y, seqs_y):
188 | if l_x < maxlen and l_y < maxlen:
189 | new_seqs_x.append(s_x)
190 | new_lengths_x.append(l_x)
191 | new_seqs_y.append(s_y)
192 | new_lengths_y.append(l_y)
193 | lengths_x = new_lengths_x
194 | seqs_x = new_seqs_x
195 | lengths_y = new_lengths_y
196 | seqs_y = new_seqs_y
197 |
198 | if len(lengths_x) < 1 or len(lengths_y) < 1:
199 | return None, None, None, None
200 |
201 | n_samples = len(seqs_x)
202 | maxlen_x = numpy.max(lengths_x) + 1
203 | maxlen_y = numpy.max(lengths_y) + 1
204 |
205 | x = numpy.zeros((maxlen_x, n_samples)).astype('int64')
206 | y = numpy.zeros((maxlen_y, n_samples)).astype('int64')
207 | x_mask = numpy.zeros((maxlen_x, n_samples)).astype('float32')
208 | y_mask = numpy.zeros((maxlen_y, n_samples)).astype('float32')
209 | for idx, [s_x, s_y] in enumerate(zip(seqs_x, seqs_y)):
210 | x[:lengths_x[idx], idx] = s_x
211 | x_mask[:lengths_x[idx]+1, idx] = 1.
212 | y[:lengths_y[idx], idx] = s_y
213 | y_mask[:lengths_y[idx]+1, idx] = 1.
214 |
215 | return x, x_mask, y, y_mask
216 |
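# Note on prepare_data above: batches are time-major, so x and y have shape
# (maxlen_x, n_samples) / (maxlen_y, n_samples) with int64 token ids and zero padding;
# x_mask and y_mask are float32 of the same shapes and mark the real tokens plus one
# extra position each, i.e. the trailing zero entry reserved for the end-of-sentence id.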
217 |
218 | # feedforward layer: affine transformation + point-wise nonlinearity
219 | def param_init_fflayer(options, params, prefix='ff', nin=None, nout=None,
220 | ortho=True):
221 | if nin is None:
222 | nin = options['dim_proj']
223 | if nout is None:
224 | nout = options['dim_proj']
225 | params[_p(prefix, 'W')] = norm_weight(nin, nout, scale=0.01, ortho=ortho)
226 | params[_p(prefix, 'b')] = numpy.zeros((nout,)).astype('float32')
227 |
228 | return params
229 |
230 |
231 | def fflayer(tparams, state_below, options, prefix='rconv',
232 | activ='lambda x: tensor.tanh(x)', **kwargs):
233 | return eval(activ)(
234 | tensor.dot(state_below, tparams[_p(prefix, 'W')]) +
235 | tparams[_p(prefix, 'b')])
236 |
237 |
238 | # GRU layer
239 | def param_init_gru(options, params, prefix='gru', nin=None, dim=None):
240 | if nin is None:
241 | nin = options['dim_proj']
242 | if dim is None:
243 | dim = options['dim_proj']
244 |
245 | # embedding to gates transformation weights, biases
246 | W = numpy.concatenate([norm_weight(nin, dim),
247 | norm_weight(nin, dim)], axis=1)
248 | params[_p(prefix, 'W')] = W
249 | params[_p(prefix, 'b')] = numpy.zeros((2 * dim,)).astype('float32')
250 |
251 | # recurrent transformation weights for gates
252 | U = numpy.concatenate([ortho_weight(dim),
253 | ortho_weight(dim)], axis=1)
254 | params[_p(prefix, 'U')] = U
255 |
256 | # embedding to hidden state proposal weights, biases
257 | Wx = norm_weight(nin, dim)
258 | params[_p(prefix, 'Wx')] = Wx
259 | params[_p(prefix, 'bx')] = numpy.zeros((dim,)).astype('float32')
260 |
261 | # recurrent transformation weights for hidden state proposal
262 | Ux = ortho_weight(dim)
263 | params[_p(prefix, 'Ux')] = Ux
264 |
265 | return params
266 |
267 |
268 | def gru_layer(tparams, state_below, options, prefix='gru', mask=None,
269 | **kwargs):
270 | nsteps = state_below.shape[0]
271 | if state_below.ndim == 3:
272 | n_samples = state_below.shape[1]
273 | else:
274 | n_samples = 1
275 |
276 | dim = tparams[_p(prefix, 'Ux')].shape[1]
277 |
278 | if mask is None:
279 | mask = tensor.alloc(1., state_below.shape[0], 1)
280 |
281 | # utility function to slice a tensor
282 | def _slice(_x, n, dim):
283 | if _x.ndim == 3:
284 | return _x[:, :, n*dim:(n+1)*dim]
285 | return _x[:, n*dim:(n+1)*dim]
286 |
287 | # state_below is the input word embeddings
288 | # input to the gates, concatenated
289 | state_below_ = tensor.dot(state_below, tparams[_p(prefix, 'W')]) + \
290 | tparams[_p(prefix, 'b')]
291 | # input to compute the hidden state proposal
292 | state_belowx = tensor.dot(state_below, tparams[_p(prefix, 'Wx')]) + \
293 | tparams[_p(prefix, 'bx')]
294 |
295 | # step function to be used by scan
296 | # arguments | sequences |outputs-info| non-seqs
297 | def _step_slice(m_, x_, xx_, h_, U, Ux):
298 | preact = tensor.dot(h_, U)
299 | preact += x_
300 |
301 | # reset and update gates
302 | r = tensor.nnet.sigmoid(_slice(preact, 0, dim))
303 | u = tensor.nnet.sigmoid(_slice(preact, 1, dim))
304 |
305 | # compute the hidden state proposal
306 | preactx = tensor.dot(h_, Ux)
307 | preactx = preactx * r
308 | preactx = preactx + xx_
309 |
310 | # hidden state proposal
311 | h = tensor.tanh(preactx)
312 |
313 | # leaky integrate and obtain next hidden state
314 | h = u * h_ + (1. - u) * h
315 | h = m_[:, None] * h + (1. - m_)[:, None] * h_
316 |
317 | return h
318 |
319 | # prepare scan arguments
320 | seqs = [mask, state_below_, state_belowx]
321 | init_states = [tensor.alloc(0., n_samples, dim)]
322 | _step = _step_slice
323 | shared_vars = [tparams[_p(prefix, 'U')],
324 | tparams[_p(prefix, 'Ux')]]
325 |
326 | rval, updates = theano.scan(_step,
327 | sequences=seqs,
328 | outputs_info=init_states,
329 | non_sequences=shared_vars,
330 | name=_p(prefix, '_layers'),
331 | n_steps=nsteps,
332 | profile=profile,
333 | strict=True)
334 | rval = [rval]
335 | return rval
336 |
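# Summary of the recurrence implemented by _step_slice in gru_layer above: with
# x_ = W x_t + b and xx_ = Wx x_t + bx precomputed outside the scan,
#   r_t  = sigmoid(x_ + U h_{t-1})  [first half]    reset gate
#   u_t  = sigmoid(x_ + U h_{t-1})  [second half]   update gate
#   h~_t = tanh(xx_ + r_t * (Ux h_{t-1}))           candidate state
#   h_t  = u_t * h_{t-1} + (1 - u_t) * h~_t         gated interpolation
# The mask m_ keeps h_{t-1} unchanged at padded positions.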
337 | # GRU Deep layer
338 | def param_init_gru_cae(options, params, prefix='gru', nin=None, dim=None, dimctx=None):
339 | if nin is None:
340 | nin = options['dim_proj']
341 | if dim is None:
342 | dim = options['dim_proj']
343 | if dimctx is None:
344 | dimctx = options['dim_proj']
345 |
346 | # embedding to gates transformation weights, biases
347 | W = numpy.concatenate([norm_weight(nin, dim),
348 | norm_weight(nin, dim)], axis=1)
349 | params[_p(prefix, 'W')] = W
350 | params[_p(prefix, 'b')] = numpy.zeros((2 * dim,)).astype('float32')
351 |
352 | # recurrent transformation weights for gates
353 | U = numpy.concatenate([ortho_weight(dim),
354 | ortho_weight(dim)], axis=1)
355 | params[_p(prefix, 'U')] = U
356 |
357 | # embedding to hidden state proposal weights, biases
358 | Wx = norm_weight(nin, dim)
359 | params[_p(prefix, 'Wx')] = Wx
360 | params[_p(prefix, 'bx')] = numpy.zeros((dim,)).astype('float32')
361 |
362 | # recurrent transformation weights for hidden state proposal
363 | Ux = ortho_weight(dim)
364 | params[_p(prefix, 'Ux')] = Ux
365 |
366 | U_nl = numpy.concatenate([ortho_weight(dim),
367 | ortho_weight(dim)], axis=1)
368 | params[_p(prefix, 'U_nl')] = U_nl
369 | params[_p(prefix, 'b_nl')] = numpy.zeros((2 * dim,)).astype('float32')
370 |
371 | Ux_nl = ortho_weight(dim)
372 | params[_p(prefix, 'Ux_nl')] = Ux_nl
373 | params[_p(prefix, 'bx_nl')] = numpy.zeros((dim,)).astype('float32')
374 |
375 | # context to LSTM
376 | Wc = norm_weight(dimctx, dim*2)
377 | params[_p(prefix, 'Wc')] = Wc
378 |
379 | Wcx = norm_weight(dimctx, dim)
380 | params[_p(prefix, 'Wcx')] = Wcx
381 |
382 | return params
383 |
384 |
385 | def gru_cae_layer(tparams, state_below, options, prefix='gru', mask=None, context_below=None,
386 | **kwargs):
387 | nsteps = state_below.shape[0]
388 | if state_below.ndim == 3:
389 |         if context_below is not None:
390 |             assert context_below.ndim == 3, 'context_below must be 3-d, matching state_below'
391 | n_samples = state_below.shape[1]
392 | else:
393 | n_samples = 1
394 | if context_below is None:
395 |         context_below = tensor.alloc(0., nsteps, n_samples, tparams[_p(prefix, 'Wc')].shape[0])
396 |
397 | dim = tparams[_p(prefix, 'Ux')].shape[1]
398 |
399 | if mask is None:
400 | mask = tensor.alloc(1., state_below.shape[0], 1)
401 |
402 | # utility function to slice a tensor
403 | def _slice(_x, n, dim):
404 | if _x.ndim == 3:
405 | return _x[:, :, n*dim:(n+1)*dim]
406 | return _x[:, n*dim:(n+1)*dim]
407 |
408 | # state_below is the input word embeddings
409 | # input to the gates, concatenated
410 | state_below_ = tensor.dot(state_below, tparams[_p(prefix, 'W')]) + \
411 | tparams[_p(prefix, 'b')]
412 | # input to compute the hidden state proposal
413 | state_belowx = tensor.dot(state_below, tparams[_p(prefix, 'Wx')]) + \
414 | tparams[_p(prefix, 'bx')]
415 |
416 | # step function to be used by scan
417 | # arguments | sequences |outputs-info| non-seqs
418 | def _step_slice(m_, x_, xx_, ctx_, h_, U, Ux, Wc, Wcx, U_nl, Ux_nl, b_nl, bx_nl):
419 | preact = tensor.dot(h_, U)
420 | preact += x_
421 |
422 | # reset and update gates
423 | r = tensor.nnet.sigmoid(_slice(preact, 0, dim))
424 | u = tensor.nnet.sigmoid(_slice(preact, 1, dim))
425 |
426 | # compute the hidden state proposal
427 | preactx = tensor.dot(h_, Ux)
428 | preactx = preactx * r
429 | preactx = preactx + xx_
430 |
431 | # hidden state proposal
432 | h = tensor.tanh(preactx)
433 |
434 | # leaky integrate and obtain next hidden state
435 | h = u * h_ + (1. - u) * h
436 | h1 = m_[:, None] * h + (1. - m_)[:, None] * h_
437 |
438 | preact2 = tensor.dot(h1, U_nl)+b_nl
439 | preact2 += tensor.dot(ctx_, Wc)
440 | preact2 = tensor.nnet.sigmoid(preact2)
441 |
442 | r2 = _slice(preact2, 0, dim)
443 | u2 = _slice(preact2, 1, dim)
444 |
445 | preactx2 = tensor.dot(h1, Ux_nl)+bx_nl
446 | preactx2 *= r2
447 | preactx2 += tensor.dot(ctx_, Wcx)
448 |
449 | h2 = tensor.tanh(preactx2)
450 |
451 | h2 = u2 * h1 + (1. - u2) * h2
452 | h2 = m_[:, None] * h2 + (1. - m_)[:, None] * h1
453 |
454 | return h2
455 |
456 | # prepare scan arguments
457 | seqs = [mask, state_below_, state_belowx, context_below]
458 | init_states = [tensor.alloc(0., n_samples, dim)]
459 | _step = _step_slice
460 | shared_vars = [tparams[_p(prefix, 'U')],
461 | tparams[_p(prefix, 'Ux')],
462 | tparams[_p(prefix, 'Wc')],
463 | tparams[_p(prefix, 'Wcx')],
464 | tparams[_p(prefix, 'U_nl')],
465 | tparams[_p(prefix, 'Ux_nl')],
466 | tparams[_p(prefix, 'b_nl')],
467 | tparams[_p(prefix, 'bx_nl')]]
468 |
469 | rval, updates = theano.scan(_step,
470 | sequences=seqs,
471 | outputs_info=init_states,
472 | non_sequences=shared_vars,
473 | name=_p(prefix, '_layers'),
474 | n_steps=nsteps,
475 | profile=profile,
476 | strict=True)
477 | rval = [rval]
478 | return rval
479 |
480 |
481 | # Conditional GRU layer with Attention
482 | def param_init_gru_cond(options, params, prefix='gru_cond',
483 | nin=None, dim=None, dimctx=None,
484 | nin_nonlin=None, dim_nonlin=None):
485 | if nin is None:
486 | nin = options['dim']
487 | if dim is None:
488 | dim = options['dim']
489 | if dimctx is None:
490 | dimctx = options['dim']
491 | if nin_nonlin is None:
492 | nin_nonlin = nin
493 | if dim_nonlin is None:
494 | dim_nonlin = dim
495 |
496 | W = numpy.concatenate([norm_weight(nin, dim),
497 | norm_weight(nin, dim)], axis=1)
498 | params[_p(prefix, 'W')] = W
499 | params[_p(prefix, 'b')] = numpy.zeros((2 * dim,)).astype('float32')
500 | U = numpy.concatenate([ortho_weight(dim_nonlin),
501 | ortho_weight(dim_nonlin)], axis=1)
502 | params[_p(prefix, 'U')] = U
503 |
504 | Wx = norm_weight(nin_nonlin, dim_nonlin)
505 | params[_p(prefix, 'Wx')] = Wx
506 | Ux = ortho_weight(dim_nonlin)
507 | params[_p(prefix, 'Ux')] = Ux
508 | params[_p(prefix, 'bx')] = numpy.zeros((dim_nonlin,)).astype('float32')
509 |
510 | U_nl = numpy.concatenate([ortho_weight(dim_nonlin),
511 | ortho_weight(dim_nonlin)], axis=1)
512 | params[_p(prefix, 'U_nl')] = U_nl
513 | params[_p(prefix, 'b_nl')] = numpy.zeros((2 * dim_nonlin,)).astype('float32')
514 |
515 | Ux_nl = ortho_weight(dim_nonlin)
516 | params[_p(prefix, 'Ux_nl')] = Ux_nl
517 | params[_p(prefix, 'bx_nl')] = numpy.zeros((dim_nonlin,)).astype('float32')
518 |
519 |     # context to GRU
520 | Wc = norm_weight(dimctx, dim*2)
521 | params[_p(prefix, 'Wc')] = Wc
522 |
523 | Wcx = norm_weight(dimctx, dim)
524 | params[_p(prefix, 'Wcx')] = Wcx
525 |
526 | # attention: combined -> hidden
527 | W_comb_att = norm_weight(dim, dimctx)
528 | params[_p(prefix, 'W_comb_att')] = W_comb_att
529 |
530 | # attention: context -> hidden
531 | Wc_att = norm_weight(dimctx)
532 | params[_p(prefix, 'Wc_att')] = Wc_att
533 |
534 | # attention: hidden bias
535 | b_att = numpy.zeros((dimctx,)).astype('float32')
536 | params[_p(prefix, 'b_att')] = b_att
537 |
538 |     # attention: hidden -> scalar alignment score
539 |     U_att = norm_weight(dimctx, 1)
540 |     params[_p(prefix, 'U_att')] = U_att
541 |     c_att = numpy.zeros((1,)).astype('float32')
542 |     params[_p(prefix, 'c_tt')] = c_att  # note: stored under the key 'c_tt', which the scan step expects
543 |
544 | return params
545 |
546 |
547 | def gru_cond_layer(tparams, state_below, options, prefix='gru',
548 | mask=None, context=None, one_step=False,
549 | init_memory=None, init_state=None,
550 | context_mask=None,
551 | **kwargs):
552 |
553 | assert context, 'Context must be provided'
554 |
555 | if one_step:
556 | assert init_state, 'previous state must be provided'
557 |
558 | nsteps = state_below.shape[0]
559 | if state_below.ndim == 3:
560 | n_samples = state_below.shape[1]
561 | else:
562 | n_samples = 1
563 |
564 | # mask
565 | if mask is None:
566 | mask = tensor.alloc(1., state_below.shape[0], 1)
567 |
568 | dim = tparams[_p(prefix, 'Wcx')].shape[1]
569 |
570 | # initial/previous state
571 | if init_state is None:
572 | init_state = tensor.alloc(0., n_samples, dim)
573 |
574 | # projected context
575 | assert context.ndim == 3, \
576 | 'Context must be 3-d: #annotation x #sample x dim'
577 | pctx_ = tensor.dot(context, tparams[_p(prefix, 'Wc_att')]) +\
578 | tparams[_p(prefix, 'b_att')]
579 |
580 | def _slice(_x, n, dim):
581 | if _x.ndim == 3:
582 | return _x[:, :, n*dim:(n+1)*dim]
583 | return _x[:, n*dim:(n+1)*dim]
584 |
585 | # projected x
586 | state_belowx = tensor.dot(state_below, tparams[_p(prefix, 'Wx')]) +\
587 | tparams[_p(prefix, 'bx')]
588 | state_below_ = tensor.dot(state_below, tparams[_p(prefix, 'W')]) +\
589 | tparams[_p(prefix, 'b')]
590 |
591 | def _step_slice(m_, x_, xx_, h_, ctx_, alpha_, pctx_, cc_,
592 | U, Wc, W_comb_att, U_att, c_tt, Ux, Wcx,
593 | U_nl, Ux_nl, b_nl, bx_nl):
594 | preact1 = tensor.dot(h_, U)
595 | preact1 += x_
596 | preact1 = tensor.nnet.sigmoid(preact1)
597 |
598 | r1 = _slice(preact1, 0, dim)
599 | u1 = _slice(preact1, 1, dim)
600 |
601 | preactx1 = tensor.dot(h_, Ux)
602 | preactx1 *= r1
603 | preactx1 += xx_
604 |
605 | h1 = tensor.tanh(preactx1)
606 |
607 | h1 = u1 * h_ + (1. - u1) * h1
608 | h1 = m_[:, None] * h1 + (1. - m_)[:, None] * h_
609 |
610 | # attention
611 | pstate_ = tensor.dot(h1, W_comb_att)
612 | pctx__ = pctx_ + pstate_[None, :, :]
613 | #pctx__ += xc_
614 | pctx__ = tensor.tanh(pctx__)
615 | alpha = tensor.dot(pctx__, U_att)+c_tt
616 | alpha = alpha.reshape([alpha.shape[0], alpha.shape[1]])
617 | alpha = tensor.exp(alpha)
618 | if context_mask:
619 | alpha = alpha * context_mask
620 | alpha = alpha / alpha.sum(0, keepdims=True)
621 | ctx_ = (cc_ * alpha[:, :, None]).sum(0) # current context
622 |
623 | preact2 = tensor.dot(h1, U_nl)+b_nl
624 | preact2 += tensor.dot(ctx_, Wc)
625 | preact2 = tensor.nnet.sigmoid(preact2)
626 |
627 | r2 = _slice(preact2, 0, dim)
628 | u2 = _slice(preact2, 1, dim)
629 |
630 | preactx2 = tensor.dot(h1, Ux_nl)+bx_nl
631 | preactx2 *= r2
632 | preactx2 += tensor.dot(ctx_, Wcx)
633 |
634 | h2 = tensor.tanh(preactx2)
635 |
636 | h2 = u2 * h1 + (1. - u2) * h2
637 | h2 = m_[:, None] * h2 + (1. - m_)[:, None] * h1
638 |
639 | return h2, ctx_, alpha.T # pstate_, preact, preactx, r, u
640 |
641 | seqs = [mask, state_below_, state_belowx]
642 | #seqs = [mask, state_below_, state_belowx, state_belowc]
643 | _step = _step_slice
644 |
645 | shared_vars = [tparams[_p(prefix, 'U')],
646 | tparams[_p(prefix, 'Wc')],
647 | tparams[_p(prefix, 'W_comb_att')],
648 | tparams[_p(prefix, 'U_att')],
649 | tparams[_p(prefix, 'c_tt')],
650 | tparams[_p(prefix, 'Ux')],
651 | tparams[_p(prefix, 'Wcx')],
652 | tparams[_p(prefix, 'U_nl')],
653 | tparams[_p(prefix, 'Ux_nl')],
654 | tparams[_p(prefix, 'b_nl')],
655 | tparams[_p(prefix, 'bx_nl')]]
656 |
657 | if one_step:
658 | rval = _step(*(seqs + [init_state, None, None, pctx_, context] +
659 | shared_vars))
660 | else:
661 | rval, updates = theano.scan(_step,
662 | sequences=seqs,
663 | outputs_info=[init_state,
664 | tensor.alloc(0., n_samples,
665 | context.shape[2]),
666 | tensor.alloc(0., n_samples,
667 | context.shape[0])],
668 | non_sequences=[pctx_, context]+shared_vars,
669 | name=_p(prefix, '_layers'),
670 | n_steps=nsteps,
671 | profile=profile,
672 | strict=True)
673 | return rval
674 |
675 |
676 | # initialize all parameters
677 | def init_params(options):
678 | params = OrderedDict()
679 |
680 | # embedding
681 | params['Wemb'] = norm_weight(options['n_words_src'], options['dim_word'])
682 | params['Wemb_dec'] = norm_weight(options['n_words'], options['dim_word'])
683 |
684 |     # encoder: forward GRU + backward context-aware GRU (CAEncoder)
685 | params = get_layer('gru')[0](options, params,
686 | prefix='encoder',
687 | nin=options['dim_word'],
688 | dim=options['dim'])
689 | params = get_layer('gru_cae')[0](options, params,
690 | prefix='encoder_r',
691 | nin=options['dim_word'],
692 | dim=options['dim'],
693 | dimctx=options['dim'])
694 | ctxdim = options['dim']
695 |
696 |     # decoder initial state (feed-forward layer over the mean context)
697 | params = get_layer('ff')[0](options, params, prefix='ff_state',
698 | nin=ctxdim, nout=options['dim'])
699 | # decoder
700 | params = get_layer(options['decoder'])[0](options, params,
701 | prefix='decoder',
702 | nin=options['dim_word'],
703 | dim=options['dim'],
704 | dimctx=ctxdim)
705 | # readout
706 | params = get_layer('ff')[0](options, params, prefix='ff_logit_lstm',
707 | nin=options['dim'], nout=options['dim_word'],
708 | ortho=False)
709 | params = get_layer('ff')[0](options, params, prefix='ff_logit_prev',
710 | nin=options['dim_word'],
711 | nout=options['dim_word'], ortho=False)
712 | params = get_layer('ff')[0](options, params, prefix='ff_logit_ctx',
713 | nin=ctxdim, nout=options['dim_word'],
714 | ortho=False)
715 | params = get_layer('ff')[0](options, params, prefix='ff_logit',
716 | nin=options['dim_word'],
717 | nout=options['n_words'])
718 |
719 | return params
720 |
721 |
722 | # build a training model
723 | def build_model(tparams, options):
724 | opt_ret = dict()
725 |
726 | trng = RandomStreams(1234)
727 | use_noise = theano.shared(numpy.float32(0.))
728 |
729 | # description string: #words x #samples
730 | x = tensor.matrix('x', dtype='int64')
731 | x_mask = tensor.matrix('x_mask', dtype='float32')
732 | y = tensor.matrix('y', dtype='int64')
733 | y_mask = tensor.matrix('y_mask', dtype='float32')
734 |
735 | # for the backward rnn, we just need to invert x and x_mask
736 | xr = x[::-1]
737 | xr_mask = x_mask[::-1]
738 |
739 | n_timesteps = x.shape[0]
740 | n_timesteps_trg = y.shape[0]
741 | n_samples = x.shape[1]
742 |
743 | # word embedding for forward rnn (source)
744 | emb = tparams['Wemb'][x.flatten()]
745 | emb = emb.reshape([n_timesteps, n_samples, options['dim_word']])
746 | proj = get_layer('gru')[1](tparams, emb, options,
747 | prefix='encoder',
748 | mask=x_mask)
749 | # word embedding for backward rnn (source)
750 | embr = tparams['Wemb'][xr.flatten()]
751 | embr = embr.reshape([n_timesteps, n_samples, options['dim_word']])
752 | projr = get_layer('gru_cae')[1](tparams, embr, options,
753 | prefix='encoder_r',
754 | context_below=proj[0][::-1],
755 | mask=xr_mask)
756 |
757 |     # context: the re-reversed states of the backward context-aware encoder
758 |     ctx = projr[0][::-1]  # no forward/backward concatenation as in the standard bidirectional encoder
759 |
760 | # mean of the context (across time) will be used to initialize decoder rnn
761 | ctx_mean = (ctx * x_mask[:, :, None]).sum(0) / x_mask.sum(0)[:, None]
762 |
763 | # or you can use the last state of forward + backward encoder rnns
764 | # ctx_mean = concatenate([proj[0][-1], projr[0][-1]], axis=proj[0].ndim-2)
765 |
766 | # initial decoder state
767 | init_state = get_layer('ff')[1](tparams, ctx_mean, options,
768 | prefix='ff_state', activ='tanh')
769 |
770 | # word embedding (target), we will shift the target sequence one time step
771 | # to the right. This is done because of the bi-gram connections in the
772 | # readout and decoder rnn. The first target will be all zeros and we will
773 | # not condition on the last output.
774 | emb = tparams['Wemb_dec'][y.flatten()]
775 | emb = emb.reshape([n_timesteps_trg, n_samples, options['dim_word']])
776 | emb_shifted = tensor.zeros_like(emb)
777 | emb_shifted = tensor.set_subtensor(emb_shifted[1:], emb[:-1])
778 | emb = emb_shifted
779 |
780 | # decoder - pass through the decoder conditional gru with attention
781 | proj = get_layer(options['decoder'])[1](tparams, emb, options,
782 | prefix='decoder',
783 | mask=y_mask, context=ctx,
784 | context_mask=x_mask,
785 | one_step=False,
786 | init_state=init_state)
787 | # hidden states of the decoder gru
788 | proj_h = proj[0]
789 |
790 | # weighted averages of context, generated by attention module
791 | ctxs = proj[1]
792 |
793 | # weights (alignment matrix)
794 | opt_ret['dec_alphas'] = proj[2]
795 |
796 | # compute word probabilities
797 | logit_lstm = get_layer('ff')[1](tparams, proj_h, options,
798 | prefix='ff_logit_lstm', activ='linear')
799 | logit_prev = get_layer('ff')[1](tparams, emb, options,
800 | prefix='ff_logit_prev', activ='linear')
801 | logit_ctx = get_layer('ff')[1](tparams, ctxs, options,
802 | prefix='ff_logit_ctx', activ='linear')
803 | logit = tensor.tanh(logit_lstm+logit_prev+logit_ctx)
804 | if options['use_dropout']:
805 | logit = dropout_layer(logit, use_noise, trng)
806 | logit = get_layer('ff')[1](tparams, logit, options,
807 | prefix='ff_logit', activ='linear')
808 | logit_shp = logit.shape
809 | probs = tensor.nnet.softmax(logit.reshape([logit_shp[0]*logit_shp[1],
810 | logit_shp[2]]))
811 |
812 | # cost
813 | y_flat = y.flatten()
814 | y_flat_idx = tensor.arange(y_flat.shape[0]) * options['n_words'] + y_flat
815 | cost = -tensor.log(probs.flatten()[y_flat_idx])
816 | cost = cost.reshape([y.shape[0], y.shape[1]])
817 | cost = (cost * y_mask).sum(0)
818 |
819 | return trng, use_noise, x, x_mask, y, y_mask, opt_ret, cost
820 |
821 |
822 | # build a sampler
823 | def build_sampler(tparams, options, trng, use_noise):
824 | x = tensor.matrix('x', dtype='int64')
825 | xr = x[::-1]
826 | n_timesteps = x.shape[0]
827 | n_samples = x.shape[1]
828 |
829 | # word embedding (source), forward and backward
830 | emb = tparams['Wemb'][x.flatten()]
831 | emb = emb.reshape([n_timesteps, n_samples, options['dim_word']])
832 | embr = tparams['Wemb'][xr.flatten()]
833 | embr = embr.reshape([n_timesteps, n_samples, options['dim_word']])
834 |
835 | # encoder
836 | proj = get_layer('gru')[1](tparams, emb, options,
837 | prefix='encoder')
838 | projr = get_layer('gru_cae')[1](tparams, embr, options,
839 | prefix='encoder_r',
840 | context_below=proj[0][::-1])
841 |
842 |     # context: the re-reversed states of the backward context-aware encoder
843 |     ctx = projr[0][::-1]  # no forward/backward concatenation as in the standard bidirectional encoder
844 |
845 | # get the input for decoder rnn initializer mlp
846 | ctx_mean = ctx.mean(0)
847 | # ctx_mean = concatenate([proj[0][-1],projr[0][-1]], axis=proj[0].ndim-2)
848 | init_state = get_layer('ff')[1](tparams, ctx_mean, options,
849 | prefix='ff_state', activ='tanh')
850 |
851 | print 'Building f_init...',
852 | outs = [init_state, ctx]
853 | f_init = theano.function([x], outs, name='f_init', profile=profile)
854 | print 'Done'
855 |
856 | # x: 1 x 1
857 | y = tensor.vector('y_sampler', dtype='int64')
858 | init_state = tensor.matrix('init_state', dtype='float32')
859 |
860 | # if it's the first word, emb should be all zero and it is indicated by -1
861 | emb = tensor.switch(y[:, None] < 0,
862 | tensor.alloc(0., 1, tparams['Wemb_dec'].shape[1]),
863 | tparams['Wemb_dec'][y])
864 |
865 | # apply one step of conditional gru with attention
866 | proj = get_layer(options['decoder'])[1](tparams, emb, options,
867 | prefix='decoder',
868 | mask=None, context=ctx,
869 | one_step=True,
870 | init_state=init_state)
871 | # get the next hidden state
872 | next_state = proj[0]
873 |
874 | # get the weighted averages of context for this target word y
875 | ctxs = proj[1]
876 |
877 | logit_lstm = get_layer('ff')[1](tparams, next_state, options,
878 | prefix='ff_logit_lstm', activ='linear')
879 | logit_prev = get_layer('ff')[1](tparams, emb, options,
880 | prefix='ff_logit_prev', activ='linear')
881 | logit_ctx = get_layer('ff')[1](tparams, ctxs, options,
882 | prefix='ff_logit_ctx', activ='linear')
883 | logit = tensor.tanh(logit_lstm+logit_prev+logit_ctx)
884 | if options['use_dropout']:
885 | logit = dropout_layer(logit, use_noise, trng)
886 | logit = get_layer('ff')[1](tparams, logit, options,
887 | prefix='ff_logit', activ='linear')
888 |
889 | # compute the softmax probability
890 | next_probs = tensor.nnet.softmax(logit)
891 |
892 | # sample from softmax distribution to get the sample
893 | next_sample = trng.multinomial(pvals=next_probs).argmax(1)
894 |
895 | # compile a function to do the whole thing above, next word probability,
896 | # sampled word for the next target, next hidden state to be used
897 | print 'Building f_next..',
898 | inps = [y, ctx, init_state]
899 | outs = [next_probs, next_sample, next_state]
900 | f_next = theano.function(inps, outs, name='f_next', profile=profile)
901 | print 'Done'
902 |
903 | return f_init, f_next
904 |
905 |
906 | # generate a sample, either with stochastic sampling or beam search. Note that
907 | # this function iteratively calls f_init and f_next.
908 | def gen_sample(tparams, f_init, f_next, x, options, trng=None, k=1, maxlen=30,
909 | stochastic=True, argmax=False):
910 |
911 | # k is the beam size we have
912 | if k > 1:
913 | assert not stochastic, \
914 | 'Beam search does not support stochastic sampling'
915 |
916 | sample = []
917 | sample_score = []
918 | if stochastic:
919 | sample_score = 0
920 |
921 | live_k = 1
922 | dead_k = 0
923 |
924 | hyp_samples = [[]] * live_k
925 | hyp_scores = numpy.zeros(live_k).astype('float32')
926 | hyp_states = []
927 |
928 | # get initial state of decoder rnn and encoder context
929 | ret = f_init(x)
930 | next_state, ctx0 = ret[0], ret[1]
931 | next_w = -1 * numpy.ones((1,)).astype('int64') # bos indicator
932 |
933 | for ii in xrange(maxlen):
934 | ctx = numpy.tile(ctx0, [live_k, 1])
935 | inps = [next_w, ctx, next_state]
936 | ret = f_next(*inps)
937 | next_p, next_w, next_state = ret[0], ret[1], ret[2]
938 |
939 | if stochastic:
940 | if argmax:
941 | nw = next_p[0].argmax()
942 | else:
943 | nw = next_w[0]
944 | sample.append(nw)
945 | sample_score -= numpy.log(next_p[0, nw])
946 | if nw == 0:
947 | break
948 | else:
949 | cand_scores = hyp_scores[:, None] - numpy.log(next_p)
950 | cand_flat = cand_scores.flatten()
951 | ranks_flat = cand_flat.argsort()[:(k-dead_k)]
952 |
953 | voc_size = next_p.shape[1]
954 | trans_indices = ranks_flat / voc_size
955 | word_indices = ranks_flat % voc_size
956 | costs = cand_flat[ranks_flat]
957 |
958 | new_hyp_samples = []
959 | new_hyp_scores = numpy.zeros(k-dead_k).astype('float32')
960 | new_hyp_states = []
961 |
962 | for idx, [ti, wi] in enumerate(zip(trans_indices, word_indices)):
963 | new_hyp_samples.append(hyp_samples[ti]+[wi])
964 | new_hyp_scores[idx] = copy.copy(costs[idx])
965 | new_hyp_states.append(copy.copy(next_state[ti]))
966 |
967 | # check the finished samples
968 | new_live_k = 0
969 | hyp_samples = []
970 | hyp_scores = []
971 | hyp_states = []
972 |
973 | for idx in xrange(len(new_hyp_samples)):
974 | if new_hyp_samples[idx][-1] == 0:
975 | sample.append(new_hyp_samples[idx])
976 | sample_score.append(new_hyp_scores[idx])
977 | dead_k += 1
978 | else:
979 | new_live_k += 1
980 | hyp_samples.append(new_hyp_samples[idx])
981 | hyp_scores.append(new_hyp_scores[idx])
982 | hyp_states.append(new_hyp_states[idx])
983 | hyp_scores = numpy.array(hyp_scores)
984 | live_k = new_live_k
985 |
986 | if new_live_k < 1:
987 | break
988 | if dead_k >= k:
989 | break
990 |
991 | next_w = numpy.array([w[-1] for w in hyp_samples])
992 | next_state = numpy.array(hyp_states)
993 |
994 | if not stochastic:
995 | # dump every remaining one
996 | if live_k > 0:
997 | for idx in xrange(live_k):
998 | sample.append(hyp_samples[idx])
999 | sample_score.append(hyp_scores[idx])
1000 |
1001 | return sample, sample_score
1002 |
1003 | # calculate the BLEU score on a given corpus using the translation model (returned negated, so lower is better)
1004 | def pred_bleus(source, refer, dic_src, dic_tgt, model, saveto="hypo.trans.plain",
1005 | k_=5, n_=False, c_=False, b_=1, is_nist=True):
1006 | # step 1. run the translation script
1007 | cmd = "python %s/sample.py -b %d -k %d" % (root_path, b_, k_)
1008 | if n_:
1009 | cmd += ' -n'
1010 | if c_:
1011 | cmd += ' -c'
1012 | cmd += ' %s %s %s %s %s' % (model, dic_src, dic_tgt, source, saveto+".bpe")
1013 | print cmd
1014 | run(cmd)
1015 |
1016 | cmd = 'sed "s/@@ //g" < %s > %s' % (saveto+".bpe", saveto)
1017 | print cmd
1018 | run(cmd)
1019 |
1020 |     # step 2. run the evaluation
1021 | if is_nist:
1022 | dev_bleu = eval_nist(source, refer, saveto)
1023 | else:
1024 | dev_bleu = eval_moses(source, refer, saveto)
1025 |
1026 | return -1*dev_bleu
1027 |
1028 | # calculate the log probabilities on a given corpus using the translation model
1029 | def pred_probs(f_log_probs, prepare_data, options, iterator, verbose=True):
1030 | probs = []
1031 |
1032 | n_done = 0
1033 |
1034 | for x, y in iterator:
1035 | n_done += len(x)
1036 |
1037 | x, x_mask, y, y_mask = prepare_data(x, y,
1038 | n_words_src=options['n_words_src'],
1039 | n_words=options['n_words'])
1040 |
1041 | pprobs = f_log_probs(x, x_mask, y, y_mask)
1042 | for pp in pprobs:
1043 | probs.append(pp)
1044 |
1045 | if numpy.isnan(numpy.mean(probs)):
1046 | ipdb.set_trace()
1047 |
1048 | if verbose:
1049 | print >>sys.stderr, '%d samples computed' % (n_done)
1050 |
1051 | return numpy.array(probs)
1052 |
1053 |
1054 | # optimizers
1055 | # name(hyperp, tparams, grads, inputs (list), cost) = f_grad_shared, f_update
1056 | def adam(lr, tparams, grads, inp, cost, beta1=0.9, beta2=0.999, e=1e-8):
1057 |
1058 | gshared = [theano.shared(p.get_value() * 0., name='%s_grad' % k)
1059 | for k, p in tparams.iteritems()]
1060 | gsup = [(gs, g) for gs, g in zip(gshared, grads)]
1061 |
1062 | f_grad_shared = theano.function(inp, cost, updates=gsup, profile=profile)
1063 |
1064 | updates = []
1065 |
1066 | t_prev = theano.shared(numpy.float32(0.))
1067 | t = t_prev + 1.
1068 | lr_t = lr * tensor.sqrt(1. - beta2**t) / (1. - beta1**t)
1069 |
1070 | for p, g in zip(tparams.values(), gshared):
1071 | m = theano.shared(p.get_value() * 0., p.name + '_mean')
1072 | v = theano.shared(p.get_value() * 0., p.name + '_variance')
1073 | m_t = beta1 * m + (1. - beta1) * g
1074 | v_t = beta2 * v + (1. - beta2) * g**2
1075 | step = lr_t * m_t / (tensor.sqrt(v_t) + e)
1076 | p_t = p - step
1077 | updates.append((m, m_t))
1078 | updates.append((v, v_t))
1079 | updates.append((p, p_t))
1080 | updates.append((t_prev, t))
1081 |
1082 | f_update = theano.function([lr], [], updates=updates,
1083 | on_unused_input='ignore', profile=profile)
1084 |
1085 | return f_grad_shared, f_update
1086 |
1087 |
1088 | def adadelta(lr, tparams, grads, inp, cost):
1089 | zipped_grads = [theano.shared(p.get_value() * numpy.float32(0.),
1090 | name='%s_grad' % k)
1091 | for k, p in tparams.iteritems()]
1092 | running_up2 = [theano.shared(p.get_value() * numpy.float32(0.),
1093 | name='%s_rup2' % k)
1094 | for k, p in tparams.iteritems()]
1095 | running_grads2 = [theano.shared(p.get_value() * numpy.float32(0.),
1096 | name='%s_rgrad2' % k)
1097 | for k, p in tparams.iteritems()]
1098 |
1099 | zgup = [(zg, g) for zg, g in zip(zipped_grads, grads)]
1100 | rg2up = [(rg2, 0.95 * rg2 + 0.05 * (g ** 2))
1101 | for rg2, g in zip(running_grads2, grads)]
1102 |
1103 | f_grad_shared = theano.function(inp, cost, updates=zgup+rg2up,
1104 | profile=profile)
1105 |
1106 | updir = [-tensor.sqrt(ru2 + 1e-6) / tensor.sqrt(rg2 + 1e-6) * zg
1107 | for zg, ru2, rg2 in zip(zipped_grads, running_up2,
1108 | running_grads2)]
1109 | ru2up = [(ru2, 0.95 * ru2 + 0.05 * (ud ** 2))
1110 | for ru2, ud in zip(running_up2, updir)]
1111 | param_up = [(p, p + lr*ud) for p, ud in zip(itemlist(tparams), updir)]
1112 |
1113 | f_update = theano.function([lr], [], updates=ru2up+param_up,
1114 | on_unused_input='ignore', profile=profile)
1115 |
1116 | return f_grad_shared, f_update
1117 |
1118 |
1119 | def rmsprop(lr, tparams, grads, inp, cost):
1120 | zipped_grads = [theano.shared(p.get_value() * numpy.float32(0.),
1121 | name='%s_grad' % k)
1122 | for k, p in tparams.iteritems()]
1123 | running_grads = [theano.shared(p.get_value() * numpy.float32(0.),
1124 | name='%s_rgrad' % k)
1125 | for k, p in tparams.iteritems()]
1126 | running_grads2 = [theano.shared(p.get_value() * numpy.float32(0.),
1127 | name='%s_rgrad2' % k)
1128 | for k, p in tparams.iteritems()]
1129 |
1130 | zgup = [(zg, g) for zg, g in zip(zipped_grads, grads)]
1131 |     # decay rate rho changed from 0.95 (original, commented out below) to 0.99
1132 | ### rgup = [(rg, 0.95 * rg + 0.05 * g) for rg, g in zip(running_grads, grads)]
1133 | ### rg2up = [(rg2, 0.95 * rg2 + 0.05 * (g ** 2))
1134 | ### for rg2, g in zip(running_grads2, grads)]
1135 | rgup = [(rg, 0.99 * rg + 0.01 * g) for rg, g in zip(running_grads, grads)]
1136 | rg2up = [(rg2, 0.99 * rg2 + 0.01 * (g ** 2))
1137 | for rg2, g in zip(running_grads2, grads)]
1138 |
1139 | f_grad_shared = theano.function(inp, cost, updates=zgup+rgup+rg2up,
1140 | profile=profile)
1141 |
1142 | updir = [theano.shared(p.get_value() * numpy.float32(0.),
1143 | name='%s_updir' % k)
1144 | for k, p in tparams.iteritems()]
1145 |     # momentum changed from 0.9 to 0; fixed step size 1e-4 replaced by lr
1146 | ### updir_new = [(ud, 0.9 * ud - 1e-4 * zg / tensor.sqrt(rg2 - rg ** 2 + 1e-4))
1147 | ### for ud, zg, rg, rg2 in zip(updir, zipped_grads, running_grads,
1148 | ### running_grads2)]
1149 | updir_new = [(ud, 0. * ud - lr * zg / tensor.sqrt(rg2 - rg ** 2 + 1e-4))
1150 | for ud, zg, rg, rg2 in zip(updir, zipped_grads, running_grads,
1151 | running_grads2)]
1152 | param_up = [(p, p + udn[1])
1153 | for p, udn in zip(itemlist(tparams), updir_new)]
1154 | f_update = theano.function([lr], [], updates=updir_new+param_up,
1155 | on_unused_input='ignore', profile=profile)
1156 |
1157 | return f_grad_shared, f_update
1158 |
1159 |
1160 | def sgd(lr, tparams, grads, x, mask, y, cost):
1161 | gshared = [theano.shared(p.get_value() * 0.,
1162 | name='%s_grad' % k)
1163 | for k, p in tparams.iteritems()]
1164 | gsup = [(gs, g) for gs, g in zip(gshared, grads)]
1165 |
1166 | f_grad_shared = theano.function([x, mask, y], cost, updates=gsup,
1167 | profile=profile)
1168 |
1169 | pup = [(p, p - lr * g) for p, g in zip(itemlist(tparams), gshared)]
1170 | f_update = theano.function([lr], [], updates=pup, profile=profile)
1171 |
1172 | return f_grad_shared, f_update
1173 |
1174 |
1175 | def train(dim_word=100, # word vector dimensionality
1176 |           dim=1000, # the number of GRU units
1177 | max_epochs=5000,
1178 | finish_after=10000000, # finish after this many updates
1179 | dispFreq=100,
1180 | decay_c=0., # L2 regularization penalty
1181 | alpha_c=0., # alignment regularization
1182 | clip_c=-1., # gradient clipping threshold
1183 | lrate=0.01, # learning rate
1184 | n_words_src=100000, # source vocabulary size
1185 | n_words=100000, # target vocabulary size
1186 |           maxlen=100, # maximum sentence length
1187 | optimizer='rmsprop',
1188 | batch_size=16,
1189 | valid_batch_size=16,
1190 | saveto='model.npz',
1191 | validFreq=1000,
1192 |           validFreqLeast=10000, # do not validate before this many updates
1193 |           validFreqFires=10000, # after this many updates, switch to the refined validation frequency
1194 |           validFreqRefine=1000, # validation frequency used once validFreqFires is reached
1195 | saveFreq=1000, # save the parameters after every saveFreq updates
1196 | sampleFreq=100, # generate some samples after every sampleFreq
1197 | datasets=[
1198 | '/data/lisatmp3/chokyun/europarl/europarl-v7.fr-en.en.tok',
1199 | '/data/lisatmp3/chokyun/europarl/europarl-v7.fr-en.fr.tok'],
1200 | valid_datasets=['../data/dev/newstest2011.en.tok',
1201 | '../data/dev/newstest2011.fr.tok'],
1202 | dictionaries=[
1203 | '/data/lisatmp3/chokyun/europarl/europarl-v7.fr-en.en.tok.pkl',
1204 | '/data/lisatmp3/chokyun/europarl/europarl-v7.fr-en.fr.tok.pkl'],
1205 | use_dropout=False,
1206 | reload_=False,
1207 | overwrite=False,
1208 | use_bleueval=True,
1209 | save_devscore_to='score.log',
1210 | save_devtrans_to='trans.txt',
1211 | beam_size=10,
1212 | normalize=False,
1213 | output_nbest=1,
1214 | shuffle_train=1,
1215 | is_eval_nist=True,
1216 |           decoder='gru_cond', **known_args):  # 'decoder' is read by init_params/build_model; unknown config keys are absorbed by known_args
1217 |
1218 | # Model options
1219 | model_options = locals().copy()
1220 |
1221 | # load dictionaries and invert them
1222 | worddicts = [None] * len(dictionaries)
1223 | worddicts_r = [None] * len(dictionaries)
1224 | for ii, dd in enumerate(dictionaries):
1225 | with open(dd, 'rb') as f:
1226 | worddicts[ii] = pkl.load(f)
1227 | worddicts_r[ii] = dict()
1228 | for kk, vv in worddicts[ii].iteritems():
1229 | worddicts_r[ii][vv] = kk
1230 |
1231 | # reload options
1232 | if reload_ and os.path.exists(saveto):
1233 | print 'Reloading model options'
1234 | with open('%s.pkl' % saveto, 'rb') as f:
1235 | model_options.update(pkl.load(f))
1236 |
1237 | print 'Loading data'
1238 | train = TextIterator(datasets[0], datasets[1],
1239 | dictionaries[0], dictionaries[1],
1240 | n_words_source=n_words_src, n_words_target=n_words,
1241 | batch_size=batch_size,
1242 | maxlen=maxlen,shuffle_prob=shuffle_train)
1243 | try:
1244 | valid = TextIterator(valid_datasets[0], valid_datasets[1],
1245 | dictionaries[0], dictionaries[1],
1246 | n_words_source=n_words_src, n_words_target=n_words,
1247 | batch_size=valid_batch_size,
1248 | maxlen=maxlen,shuffle_prob=1)
1249 | except IOError:
1250 | if use_bleueval:
1251 | pass
1252 | else:
1253 |             print 'Could not load the validation data (needed when BLEU evaluation is off); please check your validation datasets!'
1254 | raise
1255 |
1256 | print 'Building model'
1257 | params = init_params(model_options)
1258 | # reload parameters
1259 | if reload_ and os.path.exists(saveto):
1260 | print 'Reloading model parameters'
1261 | params = load_params(saveto, params)
1262 |
1263 | tparams = init_tparams(params)
1264 |
1265 | trng, use_noise, \
1266 | x, x_mask, y, y_mask, \
1267 | opt_ret, \
1268 | cost = \
1269 | build_model(tparams, model_options)
1270 | inps = [x, x_mask, y, y_mask]
1271 |
1272 | print 'Building sampler'
1273 | f_init, f_next = build_sampler(tparams, model_options, trng, use_noise)
1274 |
1275 | # before any regularizer
1276 | print 'Building f_log_probs...',
1277 | f_log_probs = theano.function(inps, cost, profile=profile)
1278 | print 'Done'
1279 |
1280 | cost = cost.mean()
1281 |
1282 | # apply L2 regularization on weights
1283 | if decay_c > 0.:
1284 | decay_c = theano.shared(numpy.float32(decay_c), name='decay_c')
1285 | weight_decay = 0.
1286 | for kk, vv in tparams.iteritems():
1287 | weight_decay += (vv ** 2).sum()
1288 | weight_decay *= decay_c
1289 | cost += weight_decay
1290 |
1291 | # regularize the alpha weights
1292 | if alpha_c > 0. and not model_options['decoder'].endswith('simple'):
1293 | alpha_c = theano.shared(numpy.float32(alpha_c), name='alpha_c')
1294 | alpha_reg = alpha_c * (
1295 | (tensor.cast(y_mask.sum(0)//x_mask.sum(0), 'float32')[:, None] -
1296 | opt_ret['dec_alphas'].sum(0))**2).sum(1).mean()
1297 | cost += alpha_reg
1298 |
1299 | # after all regularizers - compile the computational graph for cost
1300 | print 'Building f_cost...',
1301 | f_cost = theano.function(inps, cost, profile=profile)
1302 | print 'Done'
1303 |
1304 | print 'Computing gradient...',
1305 | grads = tensor.grad(cost, wrt=itemlist(tparams))
1306 | print 'Done'
1307 |
1308 | # apply gradient clipping here
1309 | if clip_c > 0.:
1310 | g2 = 0.
1311 | for g in grads:
1312 | g2 += (g**2).sum()
1313 | new_grads = []
1314 | for g in grads:
1315 | new_grads.append(tensor.switch(g2 > (clip_c**2),
1316 | g / tensor.sqrt(g2) * clip_c,
1317 | g))
1318 | grads = new_grads
1319 |
1320 | # compile the optimizer, the actual computational graph is compiled here
1321 | lr = tensor.scalar(name='lr')
1322 | print 'Building optimizers...',
1323 | f_grad_shared, f_update = eval(optimizer)(lr, tparams, grads, inps, cost)
1324 | print 'Done'
1325 |
1326 | print 'Optimization'
1327 |
1328 | best_p = None
1329 | uidx = 0
1330 | estop = False
1331 | history_errs = []
1332 | # reload history
1333 | if reload_ and os.path.exists(saveto):
1334 | rmodel = numpy.load(saveto)
1335 | history_errs = list(rmodel['history_errs'])
1336 | if 'uidx' in rmodel:
1337 | uidx = rmodel['uidx']
1338 |
1339 | if validFreq == -1:
1340 | validFreq = len(train[0])/batch_size
1341 | if saveFreq == -1:
1342 | saveFreq = len(train[0])/batch_size
1343 | if sampleFreq == -1:
1344 | sampleFreq = len(train[0])/batch_size
1345 |
1346 | for eidx in xrange(max_epochs):
1347 | if eidx >= 1:
1348 | lrate = lrate / 2.
1349 | print 'learning rate decay from %s to %s' % (lrate * 2., lrate)
1350 | n_samples = 0
1351 |
1352 | for x, y in train:
1353 | n_samples += len(x)
1354 | uidx += 1
1355 | use_noise.set_value(1.)
1356 |
1357 | x, x_mask, y, y_mask = prepare_data(x, y, maxlen=maxlen,
1358 | n_words_src=n_words_src,
1359 | n_words=n_words)
1360 |
1361 | if x is None:
1362 | print 'Minibatch with zero sample under length ', maxlen
1363 | uidx -= 1
1364 | continue
1365 |
1366 | ud_start = time.time()
1367 |
1368 | # compute cost, grads and copy grads to shared variables
1369 | cost = f_grad_shared(x, x_mask, y, y_mask)
1370 |
1371 | # do the update on parameters
1372 | f_update(lrate)
1373 |
1374 | ud = time.time() - ud_start
1375 |
1376 | # check for bad numbers, usually we remove non-finite elements
1377 | # and continue training - but not done here
1378 | if numpy.isnan(cost) or numpy.isinf(cost):
1379 | print 'NaN detected'
1380 | return 1., 1., 1.
1381 |
1382 | # verbose
1383 | if numpy.mod(uidx, dispFreq) == 0:
1384 | print 'Epoch ', eidx, 'Update ', uidx, 'Cost ', cost, 'UD ', ud, 's'
1385 |
1386 | # save the best model so far, in addition, save the latest model
1387 | # into a separate file with the iteration number for external eval
1388 | if numpy.mod(uidx, saveFreq) == 0:
1389 | print 'Saving the best model...',
1390 | if best_p is not None:
1391 | params = best_p
1392 | else:
1393 | params = unzip(tparams)
1394 | numpy.savez("best_"+saveto, history_errs=history_errs, uidx=uidx, **params)
1395 | pkl.dump(model_options, open('%s.pkl' % ("best_"+saveto), 'wb'))
1396 | print 'Done'
1397 |
1398 | # save with uidx
1399 | if not overwrite:
1400 | print 'Saving the model at iteration {}...'.format(uidx),
1401 | saveto_uidx = '{}.iter{}.npz'.format(
1402 | os.path.splitext(saveto)[0], uidx)
1403 | numpy.savez(saveto_uidx, history_errs=history_errs,
1404 | uidx=uidx, **unzip(tparams))
1405 | print 'Done'
1406 |
1407 |
1408 | # generate some samples with the model and display them
1409 | if numpy.mod(uidx, sampleFreq) == 0:
1410 | # FIXME: random selection?
1411 | for jj in xrange(numpy.minimum(5, x.shape[1])):
1412 | stochastic = True
1413 | sample, score = gen_sample(tparams, f_init, f_next,
1414 | x[:, jj][:, None],
1415 | model_options, trng=trng, k=1,
1416 | maxlen=30,
1417 | stochastic=stochastic,
1418 | argmax=False)
1419 | print 'Source ', jj, ': ',
1420 | for vv in x[:, jj]:
1421 | if vv == 0:
1422 | break
1423 | if vv in worddicts_r[0]:
1424 | print worddicts_r[0][vv],
1425 | else:
1426 | print 'UNK',
1427 | print
1428 | print 'Truth ', jj, ' : ',
1429 | for vv in y[:, jj]:
1430 | if vv == 0:
1431 | break
1432 | if vv in worddicts_r[1]:
1433 | print worddicts_r[1][vv],
1434 | else:
1435 | print 'UNK',
1436 | print
1437 | print 'Sample ', jj, ': ',
1438 | if stochastic:
1439 | ss = sample
1440 | else:
1441 | score = score / numpy.array([len(s) for s in sample])
1442 | ss = sample[score.argmin()]
1443 | for vv in ss:
1444 | if vv == 0:
1445 | break
1446 | if vv in worddicts_r[1]:
1447 | print worddicts_r[1][vv],
1448 | else:
1449 | print 'UNK',
1450 | print
1451 |
1452 | # validate model on validation set and early stop if necessary
1453 | if (uidx >= validFreqLeast) and \
1454 | ((uidx < validFreqFires and numpy.mod(uidx, validFreq) == 0) \
1455 | or (uidx >= validFreqFires and numpy.mod(uidx, validFreqRefine) == 0)):
1456 | use_noise.set_value(0.)
1457 | if not use_bleueval:
1458 | valid_errs = pred_probs(f_log_probs, prepare_data,
1459 | model_options, valid)
1460 | valid_err = valid_errs.mean()
1461 | else:
1462 | params = unzip(tparams)
1463 | numpy.savez(saveto, history_errs=history_errs, uidx=uidx, **params)
1464 | pkl.dump(model_options, open('%s.pkl' % saveto, 'wb'))
1465 | valid_err = pred_bleus(valid_datasets[0], valid_datasets[1],
1466 | dictionaries[0], dictionaries[1], saveto,
1467 | saveto=save_devtrans_to, k_=beam_size, n_=normalize,
1468 | b_=output_nbest, is_nist=is_eval_nist)
1469 | print "development set bleu:\t", valid_err * -1
1470 | run('echo "Step:%s Development set bleu:%s" | cat >> %s' \
1471 | % (uidx, valid_err * -1, save_devscore_to))
1472 |
1473 | history_errs.append(valid_err)
1474 | if valid_err <= numpy.array(history_errs).min():
1475 | best_p = unzip(tparams)
1476 |                     numpy.savez("best_"+saveto, history_errs=history_errs, uidx=uidx, **best_p)
1477 | pkl.dump(model_options, open('%s.pkl' % ("best_"+saveto), 'wb'))
1478 | os.system('cp %s best_%s' % (save_devtrans_to, save_devtrans_to))
1479 | os.system('cp %s.eval.nmt best_%s.eval.nmt' % (save_devtrans_to, save_devtrans_to))
1480 |
1481 | if numpy.isnan(valid_err):
1482 | ipdb.set_trace()
1483 |
1484 | print 'Valid ', valid_err
1485 |
1486 | # finish after this many updates
1487 | if uidx >= finish_after:
1488 | print 'Finishing after %d iterations!' % uidx
1489 | estop = True
1490 | break
1491 |
1492 | print 'Seen %d samples' % n_samples
1493 |
1494 | if estop:
1495 | break
1496 |
1497 | if best_p is not None:
1498 | zipp(best_p, tparams)
1499 |
1500 | use_noise.set_value(0.)
1501 | valid_err = pred_probs(f_log_probs, prepare_data,
1502 | model_options, valid).mean()
1503 |
1504 | print 'Valid ', valid_err
1505 |
1506 | params = copy.copy(best_p)
1507 | numpy.savez(saveto, zipped_params=best_p,
1508 | history_errs=history_errs,
1509 | uidx=uidx,
1510 | **params)
1511 |
1512 | return valid_err
1513 |
1514 |
1515 | if __name__ == '__main__':
1516 | pass
1517 |
--------------------------------------------------------------------------------
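The distinctive part of `nmt.py` above is the two-stage transition inside `gru_cae_layer`: a plain GRU update over the reversed source embeddings, followed by a second GRU-style update conditioned on the corresponding forward-encoder state (`context_below`). The following is a minimal NumPy sketch of one such step, written only to illustrate the equations; the parameter dictionary, shapes, and toy inputs are assumptions made for the example and are not part of the repository.

```
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def cae_gru_step(h_prev, x_gates, x_cand, ctx, P, dim):
    """One context-aware GRU step (mirrors _step_slice in gru_cae_layer; mask handling omitted).

    h_prev  : (batch, dim)    previous backward-encoder state
    x_gates : (batch, 2*dim)  embedding projected by W (+ b)
    x_cand  : (batch, dim)    embedding projected by Wx (+ bx)
    ctx     : (batch, dim)    forward-encoder state at the same source position
    P       : parameter dict with U, Ux, Wc, Wcx, U_nl, Ux_nl, b_nl, bx_nl
    """
    # stage 1: plain GRU update over the (reversed) source embedding
    preact = h_prev.dot(P['U']) + x_gates
    r = sigmoid(preact[:, :dim])                     # reset gate
    u = sigmoid(preact[:, dim:])                     # update gate
    h_tilde = np.tanh(h_prev.dot(P['Ux']) * r + x_cand)
    h1 = u * h_prev + (1.0 - u) * h_tilde

    # stage 2: second GRU-style update conditioned on the forward state
    gates2 = sigmoid(h1.dot(P['U_nl']) + P['b_nl'] + ctx.dot(P['Wc']))
    r2, u2 = gates2[:, :dim], gates2[:, dim:]
    h2_tilde = np.tanh((h1.dot(P['Ux_nl']) + P['bx_nl']) * r2 + ctx.dot(P['Wcx']))
    return u2 * h1 + (1.0 - u2) * h2_tilde

# toy usage with random parameters
batch, dim = 2, 4
rng = np.random.RandomState(0)
P = {'U': rng.randn(dim, 2 * dim), 'Ux': rng.randn(dim, dim),
     'Wc': rng.randn(dim, 2 * dim), 'Wcx': rng.randn(dim, dim),
     'U_nl': rng.randn(dim, 2 * dim), 'Ux_nl': rng.randn(dim, dim),
     'b_nl': np.zeros(2 * dim), 'bx_nl': np.zeros(dim)}
h = cae_gru_step(np.zeros((batch, dim)), rng.randn(batch, 2 * dim),
                 rng.randn(batch, dim), rng.randn(batch, dim), P, dim)
print(h.shape)  # (2, 4)
```

Running this step right-to-left over the source, with the forward GRU's state at each position supplied as `ctx`, produces the context-aware representation that `build_model` uses directly as `ctx` instead of the usual forward/backward concatenation.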
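Similarly, `gen_sample` above grows the beam by flattening the (hypotheses x vocabulary) matrix of candidate costs, keeping the `k - dead_k` cheapest entries, and recovering which hypothesis and which word each entry corresponds to. A toy NumPy illustration of that bookkeeping, with a made-up two-hypothesis beam over a five-word vocabulary:

```
import numpy as np

k, dead_k = 3, 0
hyp_scores = np.array([1.2, 0.7])                      # cumulative costs of two live hypotheses
next_p = np.array([[0.10, 0.20, 0.30, 0.30, 0.10],     # P(next word | hypothesis 0)
                   [0.05, 0.50, 0.20, 0.20, 0.05]])    # P(next word | hypothesis 1)

cand_scores = hyp_scores[:, None] - np.log(next_p)     # extend every hypothesis by every word
cand_flat = cand_scores.flatten()
ranks_flat = cand_flat.argsort()[:(k - dead_k)]        # the k - dead_k cheapest candidates

voc_size = next_p.shape[1]
trans_indices = ranks_flat // voc_size                 # which hypothesis each candidate extends
word_indices = ranks_flat % voc_size                   # which word it appends
costs = cand_flat[ranks_flat]

print(trans_indices, word_indices, costs)
```

In the Python 2 code above, `ranks_flat / voc_size` performs the same integer division written here explicitly as `//`.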
/code/sample.py:
--------------------------------------------------------------------------------
1 | '''
2 | Translates a source file using a translation model.
3 | '''
4 | import argparse
5 |
6 | import numpy
7 | import cPickle as pkl
8 |
9 | from nmt import (build_sampler, gen_sample, load_params,
10 | init_params, init_tparams)
11 |
12 | def translate_model(model, options, k, normalize, n_best):
13 |
14 | from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
15 | from theano import shared
16 | trng = RandomStreams(1234)
17 | use_noise = shared(numpy.float32(0.))
18 |
19 | # allocate model parameters
20 | params = init_params(options)
21 |
22 | # load model parameters and set theano shared variables
23 | params = load_params(model, params)
24 | tparams = init_tparams(params)
25 |
26 | # word index
27 | f_init, f_next = build_sampler(tparams, options, trng, use_noise)
28 |
29 | def _translate(seq):
30 | # sample given an input sequence and obtain scores
31 | sample, score = gen_sample(tparams, f_init, f_next,
32 | numpy.array(seq).reshape([len(seq), 1]),
33 | options, trng=trng, k=k, maxlen=200,
34 | stochastic=False, argmax=False)
35 |
36 | # normalize scores according to sequence lengths
37 | if normalize:
38 | lengths = numpy.array([len(s) for s in sample])
39 | score = score / lengths
40 | if n_best > 1:
41 | sidx = numpy.argsort(score)[:n_best]
42 | else:
43 | sidx = numpy.argmin(score)
44 | return numpy.array(sample)[sidx], numpy.array(score)[sidx]
45 |
46 | return _translate
47 |
48 |
49 | def main(model, dictionary, dictionary_target, source_file, saveto, k=5,
50 | normalize=False, n_process=5, chr_level=False, n_best=1):
51 |
52 | # load model model_options
53 | with open('%s.pkl' % model, 'rb') as f:
54 | options = pkl.load(f)
55 |
56 | # load source dictionary and invert
57 | with open(dictionary, 'rb') as f:
58 | word_dict = pkl.load(f)
59 | word_idict = dict()
60 | for kk, vv in word_dict.iteritems():
61 | word_idict[vv] = kk
62 | word_idict[0] = ''
63 | word_idict[1] = 'UNK'
64 |
65 | # load target dictionary and invert
66 | with open(dictionary_target, 'rb') as f:
67 | word_dict_trg = pkl.load(f)
68 | word_idict_trg = dict()
69 | for kk, vv in word_dict_trg.iteritems():
70 | word_idict_trg[vv] = kk
71 | word_idict_trg[0] = ''
72 | word_idict_trg[1] = 'UNK'
73 |
74 |     # build the translation function
75 | trser = translate_model(model, options, k, normalize, n_best)
76 |
77 | # utility function
78 | def _seqs2words(caps):
79 | capsw = []
80 | for cc in caps:
81 | ww = []
82 | for w in cc:
83 | if w == 0:
84 | break
85 | ww.append(word_idict_trg[w])
86 | trs = ' '.join(ww)
87 | if trs.strip() == '':
88 | trs = 'UNK'
89 | capsw.append(trs)
90 | return capsw
91 |
92 | xs = []
93 | srcs = []
94 | with open(source_file, 'r') as f:
95 | for idx, line in enumerate(f):
96 | if chr_level:
97 | words = list(line.decode('utf-8').strip())
98 | else:
99 | words = line.strip().split()
100 | x = map(lambda w: word_dict[w] if w in word_dict else 1, words)
101 |             x = map(lambda ii: ii if ii < options['n_words_src'] else 1, x)  # cap to the source vocabulary
102 | x += [0]
103 | xs.append((idx, x))
104 | srcs.append(line.strip())
105 | print 'Data loading over'
106 |
107 | print 'Translating ', source_file, '...'
108 | trans = []
109 | scores = []
110 | for req in xs:
111 | idx, x = req[0], req[1]
112 | tran, score = trser(x)
113 | trans.append(tran)
114 | scores.append(score)
115 | print 'the %d-th sentence' % idx
116 | print 'source side:\t%s' % srcs[idx]
117 | print 'target translation:\t%s' % ''.join(_seqs2words([trans[-1]]))
118 |
119 | if n_best == 1:
120 | trans = _seqs2words(trans)
121 | else:
122 | n_best_trans = []
123 | for idx, (n_best_tr, score_) in enumerate(zip(trans, scores)):
124 | sentences = _seqs2words(n_best_tr)
125 | for ids, trans_ in enumerate(sentences):
126 | n_best_trans.append(
127 | '|||'.join(
128 | ['{}'.format(idx), trans_,
129 | '{}'.format(score_[ids])]))
130 | trans = n_best_trans
131 |
132 | with open(saveto, 'w') as f:
133 | print >>f, '\n'.join(trans)
134 | print 'Done'
135 |
136 |
137 | if __name__ == "__main__":
138 | parser = argparse.ArgumentParser()
139 | parser.add_argument('-k', type=int, default=5, help="Beam size")
140 | parser.add_argument('-n', action="store_true", default=False,
141 | help="Normalize wrt sequence length")
142 | parser.add_argument('-c', action="store_true", default=False,
143 | help="Character level")
144 | parser.add_argument('-b', type=int, default=1, help="Output n-best list")
145 | parser.add_argument('model', type=str)
146 | parser.add_argument('dictionary', type=str)
147 | parser.add_argument('dictionary_target', type=str)
148 | parser.add_argument('source', type=str)
149 | parser.add_argument('saveto', type=str)
150 |
151 | args = parser.parse_args()
152 |
153 | main(args.model, args.dictionary, args.dictionary_target, args.source,
154 | args.saveto, k=args.k, normalize=args.n,
155 | chr_level=args.c, n_best=args.b)
156 |
--------------------------------------------------------------------------------
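`_translate` above optionally length-normalizes the beam scores (negative log-probabilities, so lower is better) before selecting the single best or the `n_best` cheapest hypotheses. A small NumPy sketch of that selection step, with made-up hypotheses and scores:

```
import numpy as np

# three finished hypotheses (word id sequences ending in 0) and their cumulative costs
samples = [[12, 5, 0], [7, 3, 9, 4, 0], [2, 0]]
score = np.array([4.2, 6.0, 3.9])

normalize = True
n_best = 2

if normalize:
    lengths = np.array([len(s) for s in samples])
    score = score / lengths                  # per-word cost

if n_best > 1:
    sidx = np.argsort(score)[:n_best]        # indices of the n_best cheapest hypotheses
else:
    sidx = np.argmin(score)                  # index of the single best hypothesis

print(sidx, score[sidx])
```

Dividing by length (the `-n` flag) keeps longer hypotheses from being penalized merely for containing more words.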
/code/train.py:
--------------------------------------------------------------------------------
1 | import numpy
2 | import os
3 | import sys
4 |
5 | from nmt import train
6 |
7 | def main(job_id, params):
8 | print params
9 | validerr = train(**params)
10 | return validerr
11 |
12 | if __name__ == '__main__':
13 | if len(sys.argv) < 2:
14 |         print 'Usage: %s config.py' % sys.argv[0]
15 | sys.exit(1)
16 |
17 | options = eval(open(sys.argv[-1]).read())
18 | numpy.random.seed(options['seed'])
19 |
20 | main(0, options)
21 |
--------------------------------------------------------------------------------
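As `train.py` above shows, the configuration file is simply evaluated as a Python expression and must yield a `dict` that at least defines the `seed` key; the rest is unpacked into `nmt.train`, and any keys it does not recognize are absorbed by `**known_args`. A stripped-down illustrative config (placeholder paths only; see `work/german.py` below for the full demo configuration):

```
# minimal_config.py -- an illustrative sketch, not the shipped configuration
dict(
    seed=1234,                 # required: train.py calls numpy.random.seed(options['seed'])
    dim_word=620,              # word embedding dimensionality
    dim=1000,                  # GRU hidden size
    saveto='model.npz',
    datasets=['/path/to/train.src.bpe', '/path/to/train.trg.bpe'],
    valid_datasets=['/path/to/dev.src.bpe', '/path/to/dev.trg'],
    dictionaries=['/path/to/vocab.src.pkl', '/path/to/vocab.trg.pkl'],
)
```

It would be run as `python train.py minimal_config.py`; only `sys.argv[-1]` is read, which is why the extra flags passed in `work/run.sh` are harmless.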
/work/german.py:
--------------------------------------------------------------------------------
1 | dict(
2 | # network structure
3 | dim_word=620, # word vector dimensionality
4 |     dim=1000, # the number of GRU units
5 | n_words_src=40000, # source vocabulary size
6 | n_words=40000, # target vocabulary size
7 |     maxlen=80, # maximum sentence length
8 |
9 | # process control
10 | max_epochs=10,
11 | finish_after=100000000, # finish after this many updates
12 | dispFreq=1,
13 | saveto='search_model.npz',
14 | validFreq=5000,
15 | validFreqLeast=100000,
16 | validFreqFires=150000,
17 | validFreqRefine=3000,
18 | saveFreq=1000, # save the parameters after every saveFreq updates
19 | sampleFreq=1000, # generate some samples after every sampleFreq
20 | reload_=True,
21 | overwrite=True,
22 | is_eval_nist=False,
23 |
24 | # optimization
25 | decay_c=0., # L2 regularization penalty
26 | alpha_c=0., # alignment regularization
27 | clip_c=5., # gradient clipping threshold
28 | lrate=1.0, # learning rate
29 | optimizer='adadelta',
30 | batch_size=80,
31 | valid_batch_size=80,
32 | use_dropout=False,
33 | shuffle_train=0.999,
34 | seed=1234,
35 |
36 | # development evaluation
37 | use_bleueval=True,
38 | save_devscore_to='search_bleu.log',
39 | save_devtrans_to='search_trans.txt',
40 | beam_size=10,
41 | proc_num=1,
42 | normalize=False,
43 | output_nbest=1,
44 |
45 | # datasets
46 | use_bpe=True,
47 | datasets=[
48 | '/Path-to-training-data/train.en.bpe',
49 | '/Path-to-training-data/train.de.bpe'],
50 | valid_datasets=['/Path-to-dev-data/dev.en.plain.bpe',
51 | '/Path-to-dev-data/dev.de'],
52 | dictionaries=[
53 | '/Vocabulary/vocab.en.pkl',
54 | '/Vocabulary/vocab.de.pkl'],
55 | )
56 |
--------------------------------------------------------------------------------
/work/run.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | THEANO_FLAGS='floatX=float32,device=gpu3,nvcc.fastmath=True' python train.py --proto=english-german-caencoder-nmt-bzhang --state german.py
4 |
--------------------------------------------------------------------------------
/work/train.py:
--------------------------------------------------------------------------------
1 | ../code/train.py
--------------------------------------------------------------------------------