├── CONTRIBUTING.md ├── LICENSE ├── README.md └── nmt ├── .gitignore ├── __init__.py ├── attention_model.py ├── g3doc └── img │ ├── attention_equation_0.jpg │ ├── attention_equation_1.jpg │ ├── attention_mechanism.jpg │ ├── attention_vis.jpg │ ├── encdec.jpg │ ├── greedy_dec.jpg │ └── seq2seq.jpg ├── gnmt_model.py ├── inference.py ├── inference_test.py ├── model.py ├── model_helper.py ├── model_test.py ├── nmt.py ├── nmt_test.py ├── scripts ├── __init__.py ├── bleu.py ├── download_iwslt15.sh ├── rouge.py └── wmt16_en_de.sh ├── standard_hparams ├── iwslt15.json ├── wmt16.json ├── wmt16_gnmt_4_layer.json └── wmt16_gnmt_8_layer.json ├── testdata ├── deen_output ├── deen_ref_bpe ├── deen_ref_spm ├── iwslt15.tst2013.100.en ├── iwslt15.tst2013.100.vi ├── iwslt15.vocab.100.en ├── iwslt15.vocab.100.vi ├── label_ref ├── pred_output ├── test_embed.txt ├── test_embed_with_header.txt ├── test_infer_file ├── test_infer_vocab.src └── test_infer_vocab.tgt ├── train.py └── utils ├── __init__.py ├── common_test_utils.py ├── evaluation_utils.py ├── evaluation_utils_test.py ├── iterator_utils.py ├── iterator_utils_test.py ├── misc_utils.py ├── misc_utils_test.py ├── nmt_utils.py ├── standard_hparams_utils.py ├── vocab_utils.py └── vocab_utils_test.py /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | Want to contribute? Great! First, read this page (including the small print at the end). 2 | 3 | ### Before you contribute 4 | 5 | Before we can use your code, you must sign the 6 | [Google Individual Contributor License Agreement] 7 | (https://cla.developers.google.com/about/google-individual) 8 | (CLA), which you can do online. The CLA is necessary mainly because you own the 9 | copyright to your changes, even after your contribution becomes part of our 10 | codebase, so we need your permission to use and distribute your code. We also 11 | need to be sure of various other things—for instance that you'll tell us if you 12 | know that your code infringes on other people's patents. You don't have to sign 13 | the CLA until after you've submitted your code for review and a member has 14 | approved it, but you must do it before we can put your code into our codebase. 15 | Before you start working on a larger contribution, you should get in touch with 16 | us first through the issue tracker with your idea so that we can help out and 17 | possibly guide you. Coordinating up front makes it much easier to avoid 18 | frustration later on. 19 | 20 | ### Code reviews 21 | 22 | All submissions, including submissions by project members, require review. We 23 | use Github pull requests for this purpose. 24 | 25 | ### The small print 26 | 27 | Contributions made by corporations are covered by a different agreement than 28 | the one above, the 29 | [Software Grant and Corporate Contributor License Agreement] 30 | (https://cla.developers.google.com/about/google-corporate). 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /nmt/.gitignore: -------------------------------------------------------------------------------- 1 | bazel-bin 2 | bazel-genfiles 3 | bazel-out 4 | bazel-testlogs 5 | -------------------------------------------------------------------------------- /nmt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/nmt/0be864257a76c151eef20ea689755f08bc1faf4e/nmt/__init__.py -------------------------------------------------------------------------------- /nmt/attention_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Attention-based sequence-to-sequence model with dynamic RNN support.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | 22 | from . import model 23 | from . import model_helper 24 | 25 | __all__ = ["AttentionModel"] 26 | 27 | 28 | class AttentionModel(model.Model): 29 | """Sequence-to-sequence dynamic model with attention. 30 | 31 | This class implements a multi-layer recurrent neural network as encoder, 32 | and an attention-based decoder. This is the same as the model described in 33 | (Luong et al., EMNLP'2015) paper: https://arxiv.org/pdf/1508.04025v5.pdf. 34 | This class also allows to use GRU cells in addition to LSTM cells with 35 | support for dropout. 36 | """ 37 | 38 | def __init__(self, 39 | hparams, 40 | mode, 41 | iterator, 42 | source_vocab_table, 43 | target_vocab_table, 44 | reverse_target_vocab_table=None, 45 | scope=None, 46 | extra_args=None): 47 | self.has_attention = hparams.attention_architecture and hparams.attention 48 | 49 | # Set attention_mechanism_fn 50 | if self.has_attention: 51 | if extra_args and extra_args.attention_mechanism_fn: 52 | self.attention_mechanism_fn = extra_args.attention_mechanism_fn 53 | else: 54 | self.attention_mechanism_fn = create_attention_mechanism 55 | 56 | super(AttentionModel, self).__init__( 57 | hparams=hparams, 58 | mode=mode, 59 | iterator=iterator, 60 | source_vocab_table=source_vocab_table, 61 | target_vocab_table=target_vocab_table, 62 | reverse_target_vocab_table=reverse_target_vocab_table, 63 | scope=scope, 64 | extra_args=extra_args) 65 | 66 | def _prepare_beam_search_decoder_inputs( 67 | self, beam_width, memory, source_sequence_length, encoder_state): 68 | memory = tf.contrib.seq2seq.tile_batch( 69 | memory, multiplier=beam_width) 70 | source_sequence_length = tf.contrib.seq2seq.tile_batch( 71 | source_sequence_length, multiplier=beam_width) 72 | encoder_state = tf.contrib.seq2seq.tile_batch( 73 | encoder_state, multiplier=beam_width) 74 | batch_size = self.batch_size * beam_width 75 | return memory, source_sequence_length, encoder_state, batch_size 76 | 77 | def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, 78 | source_sequence_length): 79 | """Build a RNN cell with attention mechanism that can be used by decoder.""" 80 | # No Attention 81 | if not self.has_attention: 82 | return super(AttentionModel, self)._build_decoder_cell( 83 | hparams, encoder_outputs, encoder_state, source_sequence_length) 84 | elif hparams.attention_architecture != "standard": 85 | raise ValueError( 86 | "Unknown attention architecture %s" % hparams.attention_architecture) 87 | 88 | num_units = hparams.num_units 89 | num_layers = self.num_decoder_layers 90 | num_residual_layers = self.num_decoder_residual_layers 91 | infer_mode = hparams.infer_mode 92 | 93 | dtype = tf.float32 94 | 95 | # Ensure memory is batch-major 96 | if self.time_major: 97 | memory = tf.transpose(encoder_outputs, [1, 0, 2]) 98 | else: 99 | memory = encoder_outputs 100 | 101 | if (self.mode == tf.contrib.learn.ModeKeys.INFER and 102 | infer_mode == "beam_search"): 103 | memory, source_sequence_length, encoder_state, batch_size = ( 104 | self._prepare_beam_search_decoder_inputs( 105 | hparams.beam_width, memory, source_sequence_length, 106 | encoder_state)) 107 | else: 108 | batch_size = self.batch_size 109 | 110 | # Attention 111 | attention_mechanism = self.attention_mechanism_fn( 112 | hparams.attention, num_units, memory, source_sequence_length, self.mode) 113 | 114 | cell = model_helper.create_rnn_cell( 115 | unit_type=hparams.unit_type, 116 | num_units=num_units, 117 | num_layers=num_layers, 118 | num_residual_layers=num_residual_layers, 119 | forget_bias=hparams.forget_bias, 120 | dropout=hparams.dropout, 121 | num_gpus=self.num_gpus, 122 | mode=self.mode, 123 | single_cell_fn=self.single_cell_fn) 124 | 125 | # Only generate alignment in greedy INFER mode. 126 | alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER and 127 | infer_mode != "beam_search") 128 | cell = tf.contrib.seq2seq.AttentionWrapper( 129 | cell, 130 | attention_mechanism, 131 | attention_layer_size=num_units, 132 | alignment_history=alignment_history, 133 | output_attention=hparams.output_attention, 134 | name="attention") 135 | 136 | # TODO(thangluong): do we need num_layers, num_gpus? 137 | cell = tf.contrib.rnn.DeviceWrapper(cell, 138 | model_helper.get_device_str( 139 | num_layers - 1, self.num_gpus)) 140 | 141 | if hparams.pass_hidden_state: 142 | decoder_initial_state = cell.zero_state(batch_size, dtype).clone( 143 | cell_state=encoder_state) 144 | else: 145 | decoder_initial_state = cell.zero_state(batch_size, dtype) 146 | 147 | return cell, decoder_initial_state 148 | 149 | def _get_infer_summary(self, hparams): 150 | if not self.has_attention or hparams.infer_mode == "beam_search": 151 | return tf.no_op() 152 | return _create_attention_images_summary(self.final_context_state) 153 | 154 | 155 | def create_attention_mechanism(attention_option, num_units, memory, 156 | source_sequence_length, mode): 157 | """Create attention mechanism based on the attention_option.""" 158 | del mode # unused 159 | 160 | # Mechanism 161 | if attention_option == "luong": 162 | attention_mechanism = tf.contrib.seq2seq.LuongAttention( 163 | num_units, memory, memory_sequence_length=source_sequence_length) 164 | elif attention_option == "scaled_luong": 165 | attention_mechanism = tf.contrib.seq2seq.LuongAttention( 166 | num_units, 167 | memory, 168 | memory_sequence_length=source_sequence_length, 169 | scale=True) 170 | elif attention_option == "bahdanau": 171 | attention_mechanism = tf.contrib.seq2seq.BahdanauAttention( 172 | num_units, memory, memory_sequence_length=source_sequence_length) 173 | elif attention_option == "normed_bahdanau": 174 | attention_mechanism = tf.contrib.seq2seq.BahdanauAttention( 175 | num_units, 176 | memory, 177 | memory_sequence_length=source_sequence_length, 178 | normalize=True) 179 | else: 180 | raise ValueError("Unknown attention option %s" % attention_option) 181 | 182 | return attention_mechanism 183 | 184 | 185 | def _create_attention_images_summary(final_context_state): 186 | """create attention image and attention summary.""" 187 | attention_images = (final_context_state.alignment_history.stack()) 188 | # Reshape to (batch, src_seq_len, tgt_seq_len,1) 189 | attention_images = tf.expand_dims( 190 | tf.transpose(attention_images, [1, 2, 0]), -1) 191 | # Scale to range [0, 255] 192 | attention_images *= 255 193 | attention_summary = tf.summary.image("attention_images", attention_images) 194 | return attention_summary 195 | -------------------------------------------------------------------------------- /nmt/g3doc/img/attention_equation_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/nmt/0be864257a76c151eef20ea689755f08bc1faf4e/nmt/g3doc/img/attention_equation_0.jpg -------------------------------------------------------------------------------- /nmt/g3doc/img/attention_equation_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/nmt/0be864257a76c151eef20ea689755f08bc1faf4e/nmt/g3doc/img/attention_equation_1.jpg -------------------------------------------------------------------------------- /nmt/g3doc/img/attention_mechanism.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/nmt/0be864257a76c151eef20ea689755f08bc1faf4e/nmt/g3doc/img/attention_mechanism.jpg -------------------------------------------------------------------------------- /nmt/g3doc/img/attention_vis.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/nmt/0be864257a76c151eef20ea689755f08bc1faf4e/nmt/g3doc/img/attention_vis.jpg -------------------------------------------------------------------------------- /nmt/g3doc/img/encdec.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/nmt/0be864257a76c151eef20ea689755f08bc1faf4e/nmt/g3doc/img/encdec.jpg -------------------------------------------------------------------------------- /nmt/g3doc/img/greedy_dec.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/nmt/0be864257a76c151eef20ea689755f08bc1faf4e/nmt/g3doc/img/greedy_dec.jpg -------------------------------------------------------------------------------- /nmt/g3doc/img/seq2seq.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/nmt/0be864257a76c151eef20ea689755f08bc1faf4e/nmt/g3doc/img/seq2seq.jpg -------------------------------------------------------------------------------- /nmt/gnmt_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """GNMT attention sequence-to-sequence model with dynamic RNN support.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from . import attention_model 24 | from . import model_helper 25 | from .utils import misc_utils as utils 26 | from .utils import vocab_utils 27 | 28 | __all__ = ["GNMTModel"] 29 | 30 | 31 | class GNMTModel(attention_model.AttentionModel): 32 | """Sequence-to-sequence dynamic model with GNMT attention architecture. 33 | """ 34 | 35 | def __init__(self, 36 | hparams, 37 | mode, 38 | iterator, 39 | source_vocab_table, 40 | target_vocab_table, 41 | reverse_target_vocab_table=None, 42 | scope=None, 43 | extra_args=None): 44 | self.is_gnmt_attention = ( 45 | hparams.attention_architecture in ["gnmt", "gnmt_v2"]) 46 | 47 | super(GNMTModel, self).__init__( 48 | hparams=hparams, 49 | mode=mode, 50 | iterator=iterator, 51 | source_vocab_table=source_vocab_table, 52 | target_vocab_table=target_vocab_table, 53 | reverse_target_vocab_table=reverse_target_vocab_table, 54 | scope=scope, 55 | extra_args=extra_args) 56 | 57 | def _build_encoder(self, hparams): 58 | """Build a GNMT encoder.""" 59 | if hparams.encoder_type == "uni" or hparams.encoder_type == "bi": 60 | return super(GNMTModel, self)._build_encoder(hparams) 61 | 62 | if hparams.encoder_type != "gnmt": 63 | raise ValueError("Unknown encoder_type %s" % hparams.encoder_type) 64 | 65 | # Build GNMT encoder. 66 | num_bi_layers = 1 67 | num_uni_layers = self.num_encoder_layers - num_bi_layers 68 | utils.print_out("# Build a GNMT encoder") 69 | utils.print_out(" num_bi_layers = %d" % num_bi_layers) 70 | utils.print_out(" num_uni_layers = %d" % num_uni_layers) 71 | 72 | iterator = self.iterator 73 | source = iterator.source 74 | if self.time_major: 75 | source = tf.transpose(source) 76 | 77 | with tf.variable_scope("encoder") as scope: 78 | dtype = scope.dtype 79 | 80 | self.encoder_emb_inp = self.encoder_emb_lookup_fn( 81 | self.embedding_encoder, source) 82 | 83 | # Execute _build_bidirectional_rnn from Model class 84 | bi_encoder_outputs, bi_encoder_state = self._build_bidirectional_rnn( 85 | inputs=self.encoder_emb_inp, 86 | sequence_length=iterator.source_sequence_length, 87 | dtype=dtype, 88 | hparams=hparams, 89 | num_bi_layers=num_bi_layers, 90 | num_bi_residual_layers=0, # no residual connection 91 | ) 92 | 93 | # Build unidirectional layers 94 | if self.extract_encoder_layers: 95 | encoder_state, encoder_outputs = self._build_individual_encoder_layers( 96 | bi_encoder_outputs, num_uni_layers, dtype, hparams) 97 | else: 98 | encoder_state, encoder_outputs = self._build_all_encoder_layers( 99 | bi_encoder_outputs, num_uni_layers, dtype, hparams) 100 | 101 | # Pass all encoder states to the decoder 102 | # except the first bi-directional layer 103 | encoder_state = (bi_encoder_state[1],) + ( 104 | (encoder_state,) if num_uni_layers == 1 else encoder_state) 105 | 106 | return encoder_outputs, encoder_state 107 | 108 | def _build_all_encoder_layers(self, bi_encoder_outputs, 109 | num_uni_layers, dtype, hparams): 110 | """Build encoder layers all at once.""" 111 | uni_cell = model_helper.create_rnn_cell( 112 | unit_type=hparams.unit_type, 113 | num_units=hparams.num_units, 114 | num_layers=num_uni_layers, 115 | num_residual_layers=self.num_encoder_residual_layers, 116 | forget_bias=hparams.forget_bias, 117 | dropout=hparams.dropout, 118 | num_gpus=self.num_gpus, 119 | base_gpu=1, 120 | mode=self.mode, 121 | single_cell_fn=self.single_cell_fn) 122 | encoder_outputs, encoder_state = tf.nn.dynamic_rnn( 123 | uni_cell, 124 | bi_encoder_outputs, 125 | dtype=dtype, 126 | sequence_length=self.iterator.source_sequence_length, 127 | time_major=self.time_major) 128 | 129 | # Use the top layer for now 130 | self.encoder_state_list = [encoder_outputs] 131 | 132 | return encoder_state, encoder_outputs 133 | 134 | def _build_individual_encoder_layers(self, bi_encoder_outputs, 135 | num_uni_layers, dtype, hparams): 136 | """Run each of the encoder layer separately, not used in general seq2seq.""" 137 | uni_cell_lists = model_helper._cell_list( 138 | unit_type=hparams.unit_type, 139 | num_units=hparams.num_units, 140 | num_layers=num_uni_layers, 141 | num_residual_layers=self.num_encoder_residual_layers, 142 | forget_bias=hparams.forget_bias, 143 | dropout=hparams.dropout, 144 | num_gpus=self.num_gpus, 145 | base_gpu=1, 146 | mode=self.mode, 147 | single_cell_fn=self.single_cell_fn) 148 | 149 | encoder_inp = bi_encoder_outputs 150 | encoder_states = [] 151 | self.encoder_state_list = [bi_encoder_outputs[:, :, :hparams.num_units], 152 | bi_encoder_outputs[:, :, hparams.num_units:]] 153 | with tf.variable_scope("rnn/multi_rnn_cell"): 154 | for i, cell in enumerate(uni_cell_lists): 155 | with tf.variable_scope("cell_%d" % i) as scope: 156 | encoder_inp, encoder_state = tf.nn.dynamic_rnn( 157 | cell, 158 | encoder_inp, 159 | dtype=dtype, 160 | sequence_length=self.iterator.source_sequence_length, 161 | time_major=self.time_major, 162 | scope=scope) 163 | encoder_states.append(encoder_state) 164 | self.encoder_state_list.append(encoder_inp) 165 | 166 | encoder_state = tuple(encoder_states) 167 | encoder_outputs = self.encoder_state_list[-1] 168 | return encoder_state, encoder_outputs 169 | 170 | def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, 171 | source_sequence_length): 172 | """Build a RNN cell with GNMT attention architecture.""" 173 | # Standard attention 174 | if not self.is_gnmt_attention: 175 | return super(GNMTModel, self)._build_decoder_cell( 176 | hparams, encoder_outputs, encoder_state, source_sequence_length) 177 | 178 | # GNMT attention 179 | attention_option = hparams.attention 180 | attention_architecture = hparams.attention_architecture 181 | num_units = hparams.num_units 182 | infer_mode = hparams.infer_mode 183 | 184 | dtype = tf.float32 185 | 186 | if self.time_major: 187 | memory = tf.transpose(encoder_outputs, [1, 0, 2]) 188 | else: 189 | memory = encoder_outputs 190 | 191 | if (self.mode == tf.contrib.learn.ModeKeys.INFER and 192 | infer_mode == "beam_search"): 193 | memory, source_sequence_length, encoder_state, batch_size = ( 194 | self._prepare_beam_search_decoder_inputs( 195 | hparams.beam_width, memory, source_sequence_length, 196 | encoder_state)) 197 | else: 198 | batch_size = self.batch_size 199 | 200 | attention_mechanism = self.attention_mechanism_fn( 201 | attention_option, num_units, memory, source_sequence_length, self.mode) 202 | 203 | cell_list = model_helper._cell_list( # pylint: disable=protected-access 204 | unit_type=hparams.unit_type, 205 | num_units=num_units, 206 | num_layers=self.num_decoder_layers, 207 | num_residual_layers=self.num_decoder_residual_layers, 208 | forget_bias=hparams.forget_bias, 209 | dropout=hparams.dropout, 210 | num_gpus=self.num_gpus, 211 | mode=self.mode, 212 | single_cell_fn=self.single_cell_fn, 213 | residual_fn=gnmt_residual_fn 214 | ) 215 | 216 | # Only wrap the bottom layer with the attention mechanism. 217 | attention_cell = cell_list.pop(0) 218 | 219 | # Only generate alignment in greedy INFER mode. 220 | alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER and 221 | infer_mode != "beam_search") 222 | attention_cell = tf.contrib.seq2seq.AttentionWrapper( 223 | attention_cell, 224 | attention_mechanism, 225 | attention_layer_size=None, # don't use attention layer. 226 | output_attention=False, 227 | alignment_history=alignment_history, 228 | name="attention") 229 | 230 | if attention_architecture == "gnmt": 231 | cell = GNMTAttentionMultiCell( 232 | attention_cell, cell_list) 233 | elif attention_architecture == "gnmt_v2": 234 | cell = GNMTAttentionMultiCell( 235 | attention_cell, cell_list, use_new_attention=True) 236 | else: 237 | raise ValueError( 238 | "Unknown attention_architecture %s" % attention_architecture) 239 | 240 | if hparams.pass_hidden_state: 241 | decoder_initial_state = tuple( 242 | zs.clone(cell_state=es) 243 | if isinstance(zs, tf.contrib.seq2seq.AttentionWrapperState) else es 244 | for zs, es in zip( 245 | cell.zero_state(batch_size, dtype), encoder_state)) 246 | else: 247 | decoder_initial_state = cell.zero_state(batch_size, dtype) 248 | 249 | return cell, decoder_initial_state 250 | 251 | def _get_infer_summary(self, hparams): 252 | if hparams.infer_mode == "beam_search": 253 | return tf.no_op() 254 | elif self.is_gnmt_attention: 255 | return attention_model._create_attention_images_summary( 256 | self.final_context_state[0]) 257 | else: 258 | return super(GNMTModel, self)._get_infer_summary(hparams) 259 | 260 | 261 | class GNMTAttentionMultiCell(tf.nn.rnn_cell.MultiRNNCell): 262 | """A MultiCell with GNMT attention style.""" 263 | 264 | def __init__(self, attention_cell, cells, use_new_attention=False): 265 | """Creates a GNMTAttentionMultiCell. 266 | 267 | Args: 268 | attention_cell: An instance of AttentionWrapper. 269 | cells: A list of RNNCell wrapped with AttentionInputWrapper. 270 | use_new_attention: Whether to use the attention generated from current 271 | step bottom layer's output. Default is False. 272 | """ 273 | cells = [attention_cell] + cells 274 | self.use_new_attention = use_new_attention 275 | super(GNMTAttentionMultiCell, self).__init__(cells, state_is_tuple=True) 276 | 277 | def __call__(self, inputs, state, scope=None): 278 | """Run the cell with bottom layer's attention copied to all upper layers.""" 279 | if not tf.contrib.framework.nest.is_sequence(state): 280 | raise ValueError( 281 | "Expected state to be a tuple of length %d, but received: %s" 282 | % (len(self.state_size), state)) 283 | 284 | with tf.variable_scope(scope or "multi_rnn_cell"): 285 | new_states = [] 286 | 287 | with tf.variable_scope("cell_0_attention"): 288 | attention_cell = self._cells[0] 289 | attention_state = state[0] 290 | cur_inp, new_attention_state = attention_cell(inputs, attention_state) 291 | new_states.append(new_attention_state) 292 | 293 | for i in range(1, len(self._cells)): 294 | with tf.variable_scope("cell_%d" % i): 295 | 296 | cell = self._cells[i] 297 | cur_state = state[i] 298 | 299 | if self.use_new_attention: 300 | cur_inp = tf.concat([cur_inp, new_attention_state.attention], -1) 301 | else: 302 | cur_inp = tf.concat([cur_inp, attention_state.attention], -1) 303 | 304 | cur_inp, new_state = cell(cur_inp, cur_state) 305 | new_states.append(new_state) 306 | 307 | return cur_inp, tuple(new_states) 308 | 309 | 310 | def gnmt_residual_fn(inputs, outputs): 311 | """Residual function that handles different inputs and outputs inner dims. 312 | 313 | Args: 314 | inputs: cell inputs, this is actual inputs concatenated with the attention 315 | vector. 316 | outputs: cell outputs 317 | 318 | Returns: 319 | outputs + actual inputs 320 | """ 321 | def split_input(inp, out): 322 | out_dim = out.get_shape().as_list()[-1] 323 | inp_dim = inp.get_shape().as_list()[-1] 324 | return tf.split(inp, [out_dim, inp_dim - out_dim], axis=-1) 325 | actual_inputs, _ = tf.contrib.framework.nest.map_structure( 326 | split_input, inputs, outputs) 327 | def assert_shape_match(inp, out): 328 | inp.get_shape().assert_is_compatible_with(out.get_shape()) 329 | tf.contrib.framework.nest.assert_same_structure(actual_inputs, outputs) 330 | tf.contrib.framework.nest.map_structure( 331 | assert_shape_match, actual_inputs, outputs) 332 | return tf.contrib.framework.nest.map_structure( 333 | lambda inp, out: inp + out, actual_inputs, outputs) 334 | -------------------------------------------------------------------------------- /nmt/inference.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """To perform inference on test set given a trained model.""" 17 | from __future__ import print_function 18 | 19 | import codecs 20 | import time 21 | 22 | import tensorflow as tf 23 | 24 | from . import attention_model 25 | from . import gnmt_model 26 | from . import model as nmt_model 27 | from . import model_helper 28 | from .utils import misc_utils as utils 29 | from .utils import nmt_utils 30 | 31 | __all__ = ["load_data", "inference", 32 | "single_worker_inference", "multi_worker_inference"] 33 | 34 | 35 | def _decode_inference_indices(model, sess, output_infer, 36 | output_infer_summary_prefix, 37 | inference_indices, 38 | tgt_eos, 39 | subword_option): 40 | """Decoding only a specific set of sentences.""" 41 | utils.print_out(" decoding to output %s , num sents %d." % 42 | (output_infer, len(inference_indices))) 43 | start_time = time.time() 44 | with codecs.getwriter("utf-8")( 45 | tf.gfile.GFile(output_infer, mode="wb")) as trans_f: 46 | trans_f.write("") # Write empty string to ensure file is created. 47 | for decode_id in inference_indices: 48 | nmt_outputs, infer_summary = model.decode(sess) 49 | 50 | # get text translation 51 | assert nmt_outputs.shape[0] == 1 52 | translation = nmt_utils.get_translation( 53 | nmt_outputs, 54 | sent_id=0, 55 | tgt_eos=tgt_eos, 56 | subword_option=subword_option) 57 | 58 | if infer_summary is not None: # Attention models 59 | image_file = output_infer_summary_prefix + str(decode_id) + ".png" 60 | utils.print_out(" save attention image to %s*" % image_file) 61 | image_summ = tf.Summary() 62 | image_summ.ParseFromString(infer_summary) 63 | with tf.gfile.GFile(image_file, mode="w") as img_f: 64 | img_f.write(image_summ.value[0].image.encoded_image_string) 65 | 66 | trans_f.write("%s\n" % translation) 67 | utils.print_out(translation + b"\n") 68 | utils.print_time(" done", start_time) 69 | 70 | 71 | def load_data(inference_input_file, hparams=None): 72 | """Load inference data.""" 73 | with codecs.getreader("utf-8")( 74 | tf.gfile.GFile(inference_input_file, mode="rb")) as f: 75 | inference_data = f.read().splitlines() 76 | 77 | if hparams and hparams.inference_indices: 78 | inference_data = [inference_data[i] for i in hparams.inference_indices] 79 | 80 | return inference_data 81 | 82 | 83 | def get_model_creator(hparams): 84 | """Get the right model class depending on configuration.""" 85 | if (hparams.encoder_type == "gnmt" or 86 | hparams.attention_architecture in ["gnmt", "gnmt_v2"]): 87 | model_creator = gnmt_model.GNMTModel 88 | elif hparams.attention_architecture == "standard": 89 | model_creator = attention_model.AttentionModel 90 | elif not hparams.attention: 91 | model_creator = nmt_model.Model 92 | else: 93 | raise ValueError("Unknown attention architecture %s" % 94 | hparams.attention_architecture) 95 | return model_creator 96 | 97 | 98 | def start_sess_and_load_model(infer_model, ckpt_path): 99 | """Start session and load model.""" 100 | sess = tf.Session( 101 | graph=infer_model.graph, config=utils.get_config_proto()) 102 | with infer_model.graph.as_default(): 103 | loaded_infer_model = model_helper.load_model( 104 | infer_model.model, ckpt_path, sess, "infer") 105 | return sess, loaded_infer_model 106 | 107 | 108 | def inference(ckpt_path, 109 | inference_input_file, 110 | inference_output_file, 111 | hparams, 112 | num_workers=1, 113 | jobid=0, 114 | scope=None): 115 | """Perform translation.""" 116 | if hparams.inference_indices: 117 | assert num_workers == 1 118 | 119 | model_creator = get_model_creator(hparams) 120 | infer_model = model_helper.create_infer_model(model_creator, hparams, scope) 121 | sess, loaded_infer_model = start_sess_and_load_model(infer_model, ckpt_path) 122 | 123 | if num_workers == 1: 124 | single_worker_inference( 125 | sess, 126 | infer_model, 127 | loaded_infer_model, 128 | inference_input_file, 129 | inference_output_file, 130 | hparams) 131 | else: 132 | multi_worker_inference( 133 | sess, 134 | infer_model, 135 | loaded_infer_model, 136 | inference_input_file, 137 | inference_output_file, 138 | hparams, 139 | num_workers=num_workers, 140 | jobid=jobid) 141 | sess.close() 142 | 143 | 144 | def single_worker_inference(sess, 145 | infer_model, 146 | loaded_infer_model, 147 | inference_input_file, 148 | inference_output_file, 149 | hparams): 150 | """Inference with a single worker.""" 151 | output_infer = inference_output_file 152 | 153 | # Read data 154 | infer_data = load_data(inference_input_file, hparams) 155 | 156 | with infer_model.graph.as_default(): 157 | sess.run( 158 | infer_model.iterator.initializer, 159 | feed_dict={ 160 | infer_model.src_placeholder: infer_data, 161 | infer_model.batch_size_placeholder: hparams.infer_batch_size 162 | }) 163 | # Decode 164 | utils.print_out("# Start decoding") 165 | if hparams.inference_indices: 166 | _decode_inference_indices( 167 | loaded_infer_model, 168 | sess, 169 | output_infer=output_infer, 170 | output_infer_summary_prefix=output_infer, 171 | inference_indices=hparams.inference_indices, 172 | tgt_eos=hparams.eos, 173 | subword_option=hparams.subword_option) 174 | else: 175 | nmt_utils.decode_and_evaluate( 176 | "infer", 177 | loaded_infer_model, 178 | sess, 179 | output_infer, 180 | ref_file=None, 181 | metrics=hparams.metrics, 182 | subword_option=hparams.subword_option, 183 | beam_width=hparams.beam_width, 184 | tgt_eos=hparams.eos, 185 | num_translations_per_input=hparams.num_translations_per_input, 186 | infer_mode=hparams.infer_mode) 187 | 188 | 189 | def multi_worker_inference(sess, 190 | infer_model, 191 | loaded_infer_model, 192 | inference_input_file, 193 | inference_output_file, 194 | hparams, 195 | num_workers, 196 | jobid): 197 | """Inference using multiple workers.""" 198 | assert num_workers > 1 199 | 200 | final_output_infer = inference_output_file 201 | output_infer = "%s_%d" % (inference_output_file, jobid) 202 | output_infer_done = "%s_done_%d" % (inference_output_file, jobid) 203 | 204 | # Read data 205 | infer_data = load_data(inference_input_file, hparams) 206 | 207 | # Split data to multiple workers 208 | total_load = len(infer_data) 209 | load_per_worker = int((total_load - 1) / num_workers) + 1 210 | start_position = jobid * load_per_worker 211 | end_position = min(start_position + load_per_worker, total_load) 212 | infer_data = infer_data[start_position:end_position] 213 | 214 | with infer_model.graph.as_default(): 215 | sess.run(infer_model.iterator.initializer, 216 | { 217 | infer_model.src_placeholder: infer_data, 218 | infer_model.batch_size_placeholder: hparams.infer_batch_size 219 | }) 220 | # Decode 221 | utils.print_out("# Start decoding") 222 | nmt_utils.decode_and_evaluate( 223 | "infer", 224 | loaded_infer_model, 225 | sess, 226 | output_infer, 227 | ref_file=None, 228 | metrics=hparams.metrics, 229 | subword_option=hparams.subword_option, 230 | beam_width=hparams.beam_width, 231 | tgt_eos=hparams.eos, 232 | num_translations_per_input=hparams.num_translations_per_input, 233 | infer_mode=hparams.infer_mode) 234 | 235 | # Change file name to indicate the file writing is completed. 236 | tf.gfile.Rename(output_infer, output_infer_done, overwrite=True) 237 | 238 | # Job 0 is responsible for the clean up. 239 | if jobid != 0: return 240 | 241 | # Now write all translations 242 | with codecs.getwriter("utf-8")( 243 | tf.gfile.GFile(final_output_infer, mode="wb")) as final_f: 244 | for worker_id in range(num_workers): 245 | worker_infer_done = "%s_done_%d" % (inference_output_file, worker_id) 246 | while not tf.gfile.Exists(worker_infer_done): 247 | utils.print_out(" waiting job %d to complete." % worker_id) 248 | time.sleep(10) 249 | 250 | with codecs.getreader("utf-8")( 251 | tf.gfile.GFile(worker_infer_done, mode="rb")) as f: 252 | for translation in f: 253 | final_f.write("%s" % translation) 254 | 255 | for worker_id in range(num_workers): 256 | worker_infer_done = "%s_done_%d" % (inference_output_file, worker_id) 257 | tf.gfile.Remove(worker_infer_done) 258 | -------------------------------------------------------------------------------- /nmt/inference_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for model inference.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | import numpy as np 24 | import tensorflow as tf 25 | 26 | from . import inference 27 | from . import model_helper 28 | from .utils import common_test_utils 29 | 30 | float32 = np.float32 31 | int32 = np.int32 32 | array = np.array 33 | 34 | 35 | class InferenceTest(tf.test.TestCase): 36 | 37 | def _createTestInferCheckpoint(self, hparams, name): 38 | # Prepare 39 | hparams.vocab_prefix = ( 40 | "nmt/testdata/test_infer_vocab") 41 | hparams.src_vocab_file = hparams.vocab_prefix + "." + hparams.src 42 | hparams.tgt_vocab_file = hparams.vocab_prefix + "." + hparams.tgt 43 | out_dir = os.path.join(tf.test.get_temp_dir(), name) 44 | os.makedirs(out_dir) 45 | hparams.out_dir = out_dir 46 | 47 | # Create check point 48 | model_creator = inference.get_model_creator(hparams) 49 | infer_model = model_helper.create_infer_model(model_creator, hparams) 50 | with self.test_session(graph=infer_model.graph) as sess: 51 | loaded_model, global_step = model_helper.create_or_load_model( 52 | infer_model.model, out_dir, sess, "infer_name") 53 | ckpt_path = loaded_model.saver.save( 54 | sess, os.path.join(out_dir, "translate.ckpt"), 55 | global_step=global_step) 56 | return ckpt_path 57 | 58 | def testBasicModel(self): 59 | hparams = common_test_utils.create_test_hparams( 60 | encoder_type="uni", 61 | num_layers=1, 62 | attention="", 63 | attention_architecture="", 64 | use_residual=False,) 65 | ckpt_path = self._createTestInferCheckpoint(hparams, "basic_infer") 66 | infer_file = "nmt/testdata/test_infer_file" 67 | output_infer = os.path.join(hparams.out_dir, "output_infer") 68 | inference.inference(ckpt_path, infer_file, output_infer, hparams) 69 | with open(output_infer) as f: 70 | self.assertEqual(5, len(list(f))) 71 | 72 | def testBasicModelWithMultipleTranslations(self): 73 | hparams = common_test_utils.create_test_hparams( 74 | encoder_type="uni", 75 | num_layers=1, 76 | attention="", 77 | attention_architecture="", 78 | use_residual=False, 79 | num_translations_per_input=2, 80 | beam_width=2, 81 | ) 82 | hparams.infer_mode = "beam_search" 83 | 84 | ckpt_path = self._createTestInferCheckpoint(hparams, "multi_basic_infer") 85 | infer_file = "nmt/testdata/test_infer_file" 86 | output_infer = os.path.join(hparams.out_dir, "output_infer") 87 | inference.inference(ckpt_path, infer_file, output_infer, hparams) 88 | with open(output_infer) as f: 89 | self.assertEqual(10, len(list(f))) 90 | 91 | def testAttentionModel(self): 92 | hparams = common_test_utils.create_test_hparams( 93 | encoder_type="uni", 94 | num_layers=1, 95 | attention="scaled_luong", 96 | attention_architecture="standard", 97 | use_residual=False,) 98 | ckpt_path = self._createTestInferCheckpoint(hparams, "attention_infer") 99 | infer_file = "nmt/testdata/test_infer_file" 100 | output_infer = os.path.join(hparams.out_dir, "output_infer") 101 | inference.inference(ckpt_path, infer_file, output_infer, hparams) 102 | with open(output_infer) as f: 103 | self.assertEqual(5, len(list(f))) 104 | 105 | def testMultiWorkers(self): 106 | hparams = common_test_utils.create_test_hparams( 107 | encoder_type="uni", 108 | num_layers=2, 109 | attention="scaled_luong", 110 | attention_architecture="standard", 111 | use_residual=False,) 112 | 113 | num_workers = 3 114 | 115 | # There are 5 examples, make batch_size=3 makes job0 has 3 examples, job1 116 | # has 2 examples, and job2 has 0 example. This helps testing some edge 117 | # cases. 118 | hparams.batch_size = 3 119 | 120 | ckpt_path = self._createTestInferCheckpoint(hparams, "multi_worker_infer") 121 | infer_file = "nmt/testdata/test_infer_file" 122 | output_infer = os.path.join(hparams.out_dir, "output_infer") 123 | inference.inference( 124 | ckpt_path, infer_file, output_infer, hparams, num_workers, jobid=1) 125 | 126 | inference.inference( 127 | ckpt_path, infer_file, output_infer, hparams, num_workers, jobid=2) 128 | 129 | # Note: Need to start job 0 at the end; otherwise, it will block the testing 130 | # thread. 131 | inference.inference( 132 | ckpt_path, infer_file, output_infer, hparams, num_workers, jobid=0) 133 | 134 | with open(output_infer) as f: 135 | self.assertEqual(5, len(list(f))) 136 | 137 | def testBasicModelWithInferIndices(self): 138 | hparams = common_test_utils.create_test_hparams( 139 | encoder_type="uni", 140 | num_layers=1, 141 | attention="", 142 | attention_architecture="", 143 | use_residual=False, 144 | inference_indices=[0]) 145 | ckpt_path = self._createTestInferCheckpoint(hparams, 146 | "basic_infer_with_indices") 147 | infer_file = "nmt/testdata/test_infer_file" 148 | output_infer = os.path.join(hparams.out_dir, "output_infer") 149 | inference.inference(ckpt_path, infer_file, output_infer, hparams) 150 | with open(output_infer) as f: 151 | self.assertEqual(1, len(list(f))) 152 | 153 | def testAttentionModelWithInferIndices(self): 154 | hparams = common_test_utils.create_test_hparams( 155 | encoder_type="uni", 156 | num_layers=1, 157 | attention="scaled_luong", 158 | attention_architecture="standard", 159 | use_residual=False, 160 | inference_indices=[1, 2]) 161 | # TODO(rzhao): Make infer indices support batch_size > 1. 162 | hparams.infer_batch_size = 1 163 | ckpt_path = self._createTestInferCheckpoint(hparams, 164 | "attention_infer_with_indices") 165 | infer_file = "nmt/testdata/test_infer_file" 166 | output_infer = os.path.join(hparams.out_dir, "output_infer") 167 | inference.inference(ckpt_path, infer_file, output_infer, hparams) 168 | with open(output_infer) as f: 169 | self.assertEqual(2, len(list(f))) 170 | self.assertTrue(os.path.exists(output_infer+str(1)+".png")) 171 | self.assertTrue(os.path.exists(output_infer+str(2)+".png")) 172 | 173 | 174 | if __name__ == "__main__": 175 | tf.test.main() 176 | -------------------------------------------------------------------------------- /nmt/model_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utility functions for building models.""" 17 | from __future__ import print_function 18 | 19 | import collections 20 | import os 21 | import time 22 | import numpy as np 23 | import six 24 | import tensorflow as tf 25 | 26 | from tensorflow.python.ops import lookup_ops 27 | from .utils import iterator_utils 28 | from .utils import misc_utils as utils 29 | from .utils import vocab_utils 30 | 31 | __all__ = [ 32 | "get_initializer", "get_device_str", "create_train_model", 33 | "create_eval_model", "create_infer_model", 34 | "create_emb_for_encoder_and_decoder", "create_rnn_cell", "gradient_clip", 35 | "create_or_load_model", "load_model", "avg_checkpoints", 36 | "compute_perplexity" 37 | ] 38 | 39 | # If a vocab size is greater than this value, put the embedding on cpu instead 40 | VOCAB_SIZE_THRESHOLD_CPU = 50000 41 | 42 | 43 | def get_initializer(init_op, seed=None, init_weight=None): 44 | """Create an initializer. init_weight is only for uniform.""" 45 | if init_op == "uniform": 46 | assert init_weight 47 | return tf.random_uniform_initializer( 48 | -init_weight, init_weight, seed=seed) 49 | elif init_op == "glorot_normal": 50 | return tf.keras.initializers.glorot_normal( 51 | seed=seed) 52 | elif init_op == "glorot_uniform": 53 | return tf.keras.initializers.glorot_uniform( 54 | seed=seed) 55 | else: 56 | raise ValueError("Unknown init_op %s" % init_op) 57 | 58 | 59 | def get_device_str(device_id, num_gpus): 60 | """Return a device string for multi-GPU setup.""" 61 | if num_gpus == 0: 62 | return "/cpu:0" 63 | device_str_output = "/gpu:%d" % (device_id % num_gpus) 64 | return device_str_output 65 | 66 | 67 | class ExtraArgs(collections.namedtuple( 68 | "ExtraArgs", ("single_cell_fn", "model_device_fn", 69 | "attention_mechanism_fn", "encoder_emb_lookup_fn"))): 70 | pass 71 | 72 | 73 | class TrainModel( 74 | collections.namedtuple("TrainModel", ("graph", "model", "iterator", 75 | "skip_count_placeholder"))): 76 | pass 77 | 78 | 79 | def create_train_model( 80 | model_creator, hparams, scope=None, num_workers=1, jobid=0, 81 | extra_args=None): 82 | """Create train graph, model, and iterator.""" 83 | src_file = "%s.%s" % (hparams.train_prefix, hparams.src) 84 | tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt) 85 | src_vocab_file = hparams.src_vocab_file 86 | tgt_vocab_file = hparams.tgt_vocab_file 87 | 88 | graph = tf.Graph() 89 | 90 | with graph.as_default(), tf.container(scope or "train"): 91 | src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables( 92 | src_vocab_file, tgt_vocab_file, hparams.share_vocab) 93 | 94 | src_dataset = tf.data.TextLineDataset(tf.gfile.Glob(src_file)) 95 | tgt_dataset = tf.data.TextLineDataset(tf.gfile.Glob(tgt_file)) 96 | skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64) 97 | 98 | iterator = iterator_utils.get_iterator( 99 | src_dataset, 100 | tgt_dataset, 101 | src_vocab_table, 102 | tgt_vocab_table, 103 | batch_size=hparams.batch_size, 104 | sos=hparams.sos, 105 | eos=hparams.eos, 106 | random_seed=hparams.random_seed, 107 | num_buckets=hparams.num_buckets, 108 | src_max_len=hparams.src_max_len, 109 | tgt_max_len=hparams.tgt_max_len, 110 | skip_count=skip_count_placeholder, 111 | num_shards=num_workers, 112 | shard_index=jobid, 113 | use_char_encode=hparams.use_char_encode) 114 | 115 | # Note: One can set model_device_fn to 116 | # `tf.train.replica_device_setter(ps_tasks)` for distributed training. 117 | model_device_fn = None 118 | if extra_args: model_device_fn = extra_args.model_device_fn 119 | with tf.device(model_device_fn): 120 | model = model_creator( 121 | hparams, 122 | iterator=iterator, 123 | mode=tf.contrib.learn.ModeKeys.TRAIN, 124 | source_vocab_table=src_vocab_table, 125 | target_vocab_table=tgt_vocab_table, 126 | scope=scope, 127 | extra_args=extra_args) 128 | 129 | return TrainModel( 130 | graph=graph, 131 | model=model, 132 | iterator=iterator, 133 | skip_count_placeholder=skip_count_placeholder) 134 | 135 | 136 | class EvalModel( 137 | collections.namedtuple("EvalModel", 138 | ("graph", "model", "src_file_placeholder", 139 | "tgt_file_placeholder", "iterator"))): 140 | pass 141 | 142 | 143 | def create_eval_model(model_creator, hparams, scope=None, extra_args=None): 144 | """Create train graph, model, src/tgt file holders, and iterator.""" 145 | src_vocab_file = hparams.src_vocab_file 146 | tgt_vocab_file = hparams.tgt_vocab_file 147 | graph = tf.Graph() 148 | 149 | with graph.as_default(), tf.container(scope or "eval"): 150 | src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables( 151 | src_vocab_file, tgt_vocab_file, hparams.share_vocab) 152 | reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file( 153 | tgt_vocab_file, default_value=vocab_utils.UNK) 154 | 155 | src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string) 156 | tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string) 157 | src_dataset = tf.data.TextLineDataset(src_file_placeholder) 158 | tgt_dataset = tf.data.TextLineDataset(tgt_file_placeholder) 159 | iterator = iterator_utils.get_iterator( 160 | src_dataset, 161 | tgt_dataset, 162 | src_vocab_table, 163 | tgt_vocab_table, 164 | hparams.batch_size, 165 | sos=hparams.sos, 166 | eos=hparams.eos, 167 | random_seed=hparams.random_seed, 168 | num_buckets=hparams.num_buckets, 169 | src_max_len=hparams.src_max_len_infer, 170 | tgt_max_len=hparams.tgt_max_len_infer, 171 | use_char_encode=hparams.use_char_encode) 172 | model = model_creator( 173 | hparams, 174 | iterator=iterator, 175 | mode=tf.contrib.learn.ModeKeys.EVAL, 176 | source_vocab_table=src_vocab_table, 177 | target_vocab_table=tgt_vocab_table, 178 | reverse_target_vocab_table=reverse_tgt_vocab_table, 179 | scope=scope, 180 | extra_args=extra_args) 181 | return EvalModel( 182 | graph=graph, 183 | model=model, 184 | src_file_placeholder=src_file_placeholder, 185 | tgt_file_placeholder=tgt_file_placeholder, 186 | iterator=iterator) 187 | 188 | 189 | class InferModel( 190 | collections.namedtuple("InferModel", 191 | ("graph", "model", "src_placeholder", 192 | "batch_size_placeholder", "iterator"))): 193 | pass 194 | 195 | 196 | def create_infer_model(model_creator, hparams, scope=None, extra_args=None): 197 | """Create inference model.""" 198 | graph = tf.Graph() 199 | src_vocab_file = hparams.src_vocab_file 200 | tgt_vocab_file = hparams.tgt_vocab_file 201 | 202 | with graph.as_default(), tf.container(scope or "infer"): 203 | src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables( 204 | src_vocab_file, tgt_vocab_file, hparams.share_vocab) 205 | reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file( 206 | tgt_vocab_file, default_value=vocab_utils.UNK) 207 | 208 | src_placeholder = tf.placeholder(shape=[None], dtype=tf.string) 209 | batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64) 210 | 211 | src_dataset = tf.data.Dataset.from_tensor_slices( 212 | src_placeholder) 213 | iterator = iterator_utils.get_infer_iterator( 214 | src_dataset, 215 | src_vocab_table, 216 | batch_size=batch_size_placeholder, 217 | eos=hparams.eos, 218 | src_max_len=hparams.src_max_len_infer, 219 | use_char_encode=hparams.use_char_encode) 220 | model = model_creator( 221 | hparams, 222 | iterator=iterator, 223 | mode=tf.contrib.learn.ModeKeys.INFER, 224 | source_vocab_table=src_vocab_table, 225 | target_vocab_table=tgt_vocab_table, 226 | reverse_target_vocab_table=reverse_tgt_vocab_table, 227 | scope=scope, 228 | extra_args=extra_args) 229 | return InferModel( 230 | graph=graph, 231 | model=model, 232 | src_placeholder=src_placeholder, 233 | batch_size_placeholder=batch_size_placeholder, 234 | iterator=iterator) 235 | 236 | 237 | def _get_embed_device(vocab_size): 238 | """Decide on which device to place an embed matrix given its vocab size.""" 239 | if vocab_size > VOCAB_SIZE_THRESHOLD_CPU: 240 | return "/cpu:0" 241 | else: 242 | return "/gpu:0" 243 | 244 | 245 | def _create_pretrained_emb_from_txt( 246 | vocab_file, embed_file, num_trainable_tokens=3, dtype=tf.float32, 247 | scope=None): 248 | """Load pretrain embeding from embed_file, and return an embedding matrix. 249 | 250 | Args: 251 | embed_file: Path to a Glove formated embedding txt file. 252 | num_trainable_tokens: Make the first n tokens in the vocab file as trainable 253 | variables. Default is 3, which is "", "" and "". 254 | """ 255 | vocab, _ = vocab_utils.load_vocab(vocab_file) 256 | trainable_tokens = vocab[:num_trainable_tokens] 257 | 258 | utils.print_out("# Using pretrained embedding: %s." % embed_file) 259 | utils.print_out(" with trainable tokens: ") 260 | 261 | emb_dict, emb_size = vocab_utils.load_embed_txt(embed_file) 262 | for token in trainable_tokens: 263 | utils.print_out(" %s" % token) 264 | if token not in emb_dict: 265 | emb_dict[token] = [0.0] * emb_size 266 | 267 | emb_mat = np.array( 268 | [emb_dict[token] for token in vocab], dtype=dtype.as_numpy_dtype()) 269 | emb_mat = tf.constant(emb_mat) 270 | emb_mat_const = tf.slice(emb_mat, [num_trainable_tokens, 0], [-1, -1]) 271 | with tf.variable_scope(scope or "pretrain_embeddings", dtype=dtype) as scope: 272 | with tf.device(_get_embed_device(num_trainable_tokens)): 273 | emb_mat_var = tf.get_variable( 274 | "emb_mat_var", [num_trainable_tokens, emb_size]) 275 | return tf.concat([emb_mat_var, emb_mat_const], 0) 276 | 277 | 278 | def _create_or_load_embed(embed_name, vocab_file, embed_file, 279 | vocab_size, embed_size, dtype): 280 | """Create a new or load an existing embedding matrix.""" 281 | if vocab_file and embed_file: 282 | embedding = _create_pretrained_emb_from_txt(vocab_file, embed_file) 283 | else: 284 | with tf.device(_get_embed_device(vocab_size)): 285 | embedding = tf.get_variable( 286 | embed_name, [vocab_size, embed_size], dtype) 287 | return embedding 288 | 289 | 290 | def create_emb_for_encoder_and_decoder(share_vocab, 291 | src_vocab_size, 292 | tgt_vocab_size, 293 | src_embed_size, 294 | tgt_embed_size, 295 | dtype=tf.float32, 296 | num_enc_partitions=0, 297 | num_dec_partitions=0, 298 | src_vocab_file=None, 299 | tgt_vocab_file=None, 300 | src_embed_file=None, 301 | tgt_embed_file=None, 302 | use_char_encode=False, 303 | scope=None): 304 | """Create embedding matrix for both encoder and decoder. 305 | 306 | Args: 307 | share_vocab: A boolean. Whether to share embedding matrix for both 308 | encoder and decoder. 309 | src_vocab_size: An integer. The source vocab size. 310 | tgt_vocab_size: An integer. The target vocab size. 311 | src_embed_size: An integer. The embedding dimension for the encoder's 312 | embedding. 313 | tgt_embed_size: An integer. The embedding dimension for the decoder's 314 | embedding. 315 | dtype: dtype of the embedding matrix. Default to float32. 316 | num_enc_partitions: number of partitions used for the encoder's embedding 317 | vars. 318 | num_dec_partitions: number of partitions used for the decoder's embedding 319 | vars. 320 | scope: VariableScope for the created subgraph. Default to "embedding". 321 | 322 | Returns: 323 | embedding_encoder: Encoder's embedding matrix. 324 | embedding_decoder: Decoder's embedding matrix. 325 | 326 | Raises: 327 | ValueError: if use share_vocab but source and target have different vocab 328 | size. 329 | """ 330 | if num_enc_partitions <= 1: 331 | enc_partitioner = None 332 | else: 333 | # Note: num_partitions > 1 is required for distributed training due to 334 | # embedding_lookup tries to colocate single partition-ed embedding variable 335 | # with lookup ops. This may cause embedding variables being placed on worker 336 | # jobs. 337 | enc_partitioner = tf.fixed_size_partitioner(num_enc_partitions) 338 | 339 | if num_dec_partitions <= 1: 340 | dec_partitioner = None 341 | else: 342 | # Note: num_partitions > 1 is required for distributed training due to 343 | # embedding_lookup tries to colocate single partition-ed embedding variable 344 | # with lookup ops. This may cause embedding variables being placed on worker 345 | # jobs. 346 | dec_partitioner = tf.fixed_size_partitioner(num_dec_partitions) 347 | 348 | if src_embed_file and enc_partitioner: 349 | raise ValueError( 350 | "Can't set num_enc_partitions > 1 when using pretrained encoder " 351 | "embedding") 352 | 353 | if tgt_embed_file and dec_partitioner: 354 | raise ValueError( 355 | "Can't set num_dec_partitions > 1 when using pretrained decdoer " 356 | "embedding") 357 | 358 | with tf.variable_scope( 359 | scope or "embeddings", dtype=dtype, partitioner=enc_partitioner) as scope: 360 | # Share embedding 361 | if share_vocab: 362 | if src_vocab_size != tgt_vocab_size: 363 | raise ValueError("Share embedding but different src/tgt vocab sizes" 364 | " %d vs. %d" % (src_vocab_size, tgt_vocab_size)) 365 | assert src_embed_size == tgt_embed_size 366 | utils.print_out("# Use the same embedding for source and target") 367 | vocab_file = src_vocab_file or tgt_vocab_file 368 | embed_file = src_embed_file or tgt_embed_file 369 | 370 | embedding_encoder = _create_or_load_embed( 371 | "embedding_share", vocab_file, embed_file, 372 | src_vocab_size, src_embed_size, dtype) 373 | embedding_decoder = embedding_encoder 374 | else: 375 | if not use_char_encode: 376 | with tf.variable_scope("encoder", partitioner=enc_partitioner): 377 | embedding_encoder = _create_or_load_embed( 378 | "embedding_encoder", src_vocab_file, src_embed_file, 379 | src_vocab_size, src_embed_size, dtype) 380 | else: 381 | embedding_encoder = None 382 | 383 | with tf.variable_scope("decoder", partitioner=dec_partitioner): 384 | embedding_decoder = _create_or_load_embed( 385 | "embedding_decoder", tgt_vocab_file, tgt_embed_file, 386 | tgt_vocab_size, tgt_embed_size, dtype) 387 | 388 | return embedding_encoder, embedding_decoder 389 | 390 | 391 | def _single_cell(unit_type, num_units, forget_bias, dropout, mode, 392 | residual_connection=False, device_str=None, residual_fn=None): 393 | """Create an instance of a single RNN cell.""" 394 | # dropout (= 1 - keep_prob) is set to 0 during eval and infer 395 | dropout = dropout if mode == tf.contrib.learn.ModeKeys.TRAIN else 0.0 396 | 397 | # Cell Type 398 | if unit_type == "lstm": 399 | utils.print_out(" LSTM, forget_bias=%g" % forget_bias, new_line=False) 400 | single_cell = tf.contrib.rnn.BasicLSTMCell( 401 | num_units, 402 | forget_bias=forget_bias) 403 | elif unit_type == "gru": 404 | utils.print_out(" GRU", new_line=False) 405 | single_cell = tf.contrib.rnn.GRUCell(num_units) 406 | elif unit_type == "layer_norm_lstm": 407 | utils.print_out(" Layer Normalized LSTM, forget_bias=%g" % forget_bias, 408 | new_line=False) 409 | single_cell = tf.contrib.rnn.LayerNormBasicLSTMCell( 410 | num_units, 411 | forget_bias=forget_bias, 412 | layer_norm=True) 413 | elif unit_type == "nas": 414 | utils.print_out(" NASCell", new_line=False) 415 | single_cell = tf.contrib.rnn.NASCell(num_units) 416 | else: 417 | raise ValueError("Unknown unit type %s!" % unit_type) 418 | 419 | # Dropout (= 1 - keep_prob) 420 | if dropout > 0.0: 421 | single_cell = tf.contrib.rnn.DropoutWrapper( 422 | cell=single_cell, input_keep_prob=(1.0 - dropout)) 423 | utils.print_out(" %s, dropout=%g " %(type(single_cell).__name__, dropout), 424 | new_line=False) 425 | 426 | # Residual 427 | if residual_connection: 428 | single_cell = tf.contrib.rnn.ResidualWrapper( 429 | single_cell, residual_fn=residual_fn) 430 | utils.print_out(" %s" % type(single_cell).__name__, new_line=False) 431 | 432 | # Device Wrapper 433 | if device_str: 434 | single_cell = tf.contrib.rnn.DeviceWrapper(single_cell, device_str) 435 | utils.print_out(" %s, device=%s" % 436 | (type(single_cell).__name__, device_str), new_line=False) 437 | 438 | return single_cell 439 | 440 | 441 | def _cell_list(unit_type, num_units, num_layers, num_residual_layers, 442 | forget_bias, dropout, mode, num_gpus, base_gpu=0, 443 | single_cell_fn=None, residual_fn=None): 444 | """Create a list of RNN cells.""" 445 | if not single_cell_fn: 446 | single_cell_fn = _single_cell 447 | 448 | # Multi-GPU 449 | cell_list = [] 450 | for i in range(num_layers): 451 | utils.print_out(" cell %d" % i, new_line=False) 452 | single_cell = single_cell_fn( 453 | unit_type=unit_type, 454 | num_units=num_units, 455 | forget_bias=forget_bias, 456 | dropout=dropout, 457 | mode=mode, 458 | residual_connection=(i >= num_layers - num_residual_layers), 459 | device_str=get_device_str(i + base_gpu, num_gpus), 460 | residual_fn=residual_fn 461 | ) 462 | utils.print_out("") 463 | cell_list.append(single_cell) 464 | 465 | return cell_list 466 | 467 | 468 | def create_rnn_cell(unit_type, num_units, num_layers, num_residual_layers, 469 | forget_bias, dropout, mode, num_gpus, base_gpu=0, 470 | single_cell_fn=None): 471 | """Create multi-layer RNN cell. 472 | 473 | Args: 474 | unit_type: string representing the unit type, i.e. "lstm". 475 | num_units: the depth of each unit. 476 | num_layers: number of cells. 477 | num_residual_layers: Number of residual layers from top to bottom. For 478 | example, if `num_layers=4` and `num_residual_layers=2`, the last 2 RNN 479 | cells in the returned list will be wrapped with `ResidualWrapper`. 480 | forget_bias: the initial forget bias of the RNNCell(s). 481 | dropout: floating point value between 0.0 and 1.0: 482 | the probability of dropout. this is ignored if `mode != TRAIN`. 483 | mode: either tf.contrib.learn.TRAIN/EVAL/INFER 484 | num_gpus: The number of gpus to use when performing round-robin 485 | placement of layers. 486 | base_gpu: The gpu device id to use for the first RNN cell in the 487 | returned list. The i-th RNN cell will use `(base_gpu + i) % num_gpus` 488 | as its device id. 489 | single_cell_fn: allow for adding customized cell. 490 | When not specified, we default to model_helper._single_cell 491 | Returns: 492 | An `RNNCell` instance. 493 | """ 494 | cell_list = _cell_list(unit_type=unit_type, 495 | num_units=num_units, 496 | num_layers=num_layers, 497 | num_residual_layers=num_residual_layers, 498 | forget_bias=forget_bias, 499 | dropout=dropout, 500 | mode=mode, 501 | num_gpus=num_gpus, 502 | base_gpu=base_gpu, 503 | single_cell_fn=single_cell_fn) 504 | 505 | if len(cell_list) == 1: # Single layer. 506 | return cell_list[0] 507 | else: # Multi layers 508 | return tf.contrib.rnn.MultiRNNCell(cell_list) 509 | 510 | 511 | def gradient_clip(gradients, max_gradient_norm): 512 | """Clipping gradients of a model.""" 513 | clipped_gradients, gradient_norm = tf.clip_by_global_norm( 514 | gradients, max_gradient_norm) 515 | gradient_norm_summary = [tf.summary.scalar("grad_norm", gradient_norm)] 516 | gradient_norm_summary.append( 517 | tf.summary.scalar("clipped_gradient", tf.global_norm(clipped_gradients))) 518 | 519 | return clipped_gradients, gradient_norm_summary, gradient_norm 520 | 521 | 522 | def print_variables_in_ckpt(ckpt_path): 523 | """Print a list of variables in a checkpoint together with their shapes.""" 524 | utils.print_out("# Variables in ckpt %s" % ckpt_path) 525 | reader = tf.train.NewCheckpointReader(ckpt_path) 526 | variable_map = reader.get_variable_to_shape_map() 527 | for key in sorted(variable_map.keys()): 528 | utils.print_out(" %s: %s" % (key, variable_map[key])) 529 | 530 | 531 | def load_model(model, ckpt_path, session, name): 532 | """Load model from a checkpoint.""" 533 | start_time = time.time() 534 | try: 535 | model.saver.restore(session, ckpt_path) 536 | except tf.errors.NotFoundError as e: 537 | utils.print_out("Can't load checkpoint") 538 | print_variables_in_ckpt(ckpt_path) 539 | utils.print_out("%s" % str(e)) 540 | 541 | session.run(tf.tables_initializer()) 542 | utils.print_out( 543 | " loaded %s model parameters from %s, time %.2fs" % 544 | (name, ckpt_path, time.time() - start_time)) 545 | return model 546 | 547 | 548 | def avg_checkpoints(model_dir, num_last_checkpoints, global_step, 549 | global_step_name): 550 | """Average the last N checkpoints in the model_dir.""" 551 | checkpoint_state = tf.train.get_checkpoint_state(model_dir) 552 | if not checkpoint_state: 553 | utils.print_out("# No checkpoint file found in directory: %s" % model_dir) 554 | return None 555 | 556 | # Checkpoints are ordered from oldest to newest. 557 | checkpoints = ( 558 | checkpoint_state.all_model_checkpoint_paths[-num_last_checkpoints:]) 559 | 560 | if len(checkpoints) < num_last_checkpoints: 561 | utils.print_out( 562 | "# Skipping averaging checkpoints because not enough checkpoints is " 563 | "avaliable." 564 | ) 565 | return None 566 | 567 | avg_model_dir = os.path.join(model_dir, "avg_checkpoints") 568 | if not tf.gfile.Exists(avg_model_dir): 569 | utils.print_out( 570 | "# Creating new directory %s for saving averaged checkpoints." % 571 | avg_model_dir) 572 | tf.gfile.MakeDirs(avg_model_dir) 573 | 574 | utils.print_out("# Reading and averaging variables in checkpoints:") 575 | var_list = tf.contrib.framework.list_variables(checkpoints[0]) 576 | var_values, var_dtypes = {}, {} 577 | for (name, shape) in var_list: 578 | if name != global_step_name: 579 | var_values[name] = np.zeros(shape) 580 | 581 | for checkpoint in checkpoints: 582 | utils.print_out(" %s" % checkpoint) 583 | reader = tf.contrib.framework.load_checkpoint(checkpoint) 584 | for name in var_values: 585 | tensor = reader.get_tensor(name) 586 | var_dtypes[name] = tensor.dtype 587 | var_values[name] += tensor 588 | 589 | for name in var_values: 590 | var_values[name] /= len(checkpoints) 591 | 592 | # Build a graph with same variables in the checkpoints, and save the averaged 593 | # variables into the avg_model_dir. 594 | with tf.Graph().as_default(): 595 | tf_vars = [ 596 | tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[name]) 597 | for v in var_values 598 | ] 599 | 600 | placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars] 601 | assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)] 602 | global_step_var = tf.Variable( 603 | global_step, name=global_step_name, trainable=False) 604 | saver = tf.train.Saver(tf.all_variables()) 605 | 606 | with tf.Session() as sess: 607 | sess.run(tf.initialize_all_variables()) 608 | for p, assign_op, (name, value) in zip(placeholders, assign_ops, 609 | six.iteritems(var_values)): 610 | sess.run(assign_op, {p: value}) 611 | 612 | # Use the built saver to save the averaged checkpoint. Only keep 1 613 | # checkpoint and the best checkpoint will be moved to avg_best_metric_dir. 614 | saver.save( 615 | sess, 616 | os.path.join(avg_model_dir, "translate.ckpt")) 617 | 618 | return avg_model_dir 619 | 620 | 621 | def create_or_load_model(model, model_dir, session, name): 622 | """Create translation model and initialize or load parameters in session.""" 623 | latest_ckpt = tf.train.latest_checkpoint(model_dir) 624 | if latest_ckpt: 625 | model = load_model(model, latest_ckpt, session, name) 626 | else: 627 | start_time = time.time() 628 | session.run(tf.global_variables_initializer()) 629 | session.run(tf.tables_initializer()) 630 | utils.print_out(" created %s model with fresh parameters, time %.2fs" % 631 | (name, time.time() - start_time)) 632 | 633 | global_step = model.global_step.eval(session=session) 634 | return model, global_step 635 | 636 | 637 | def compute_perplexity(model, sess, name): 638 | """Compute perplexity of the output of the model. 639 | 640 | Args: 641 | model: model for compute perplexity. 642 | sess: tensorflow session to use. 643 | name: name of the batch. 644 | 645 | Returns: 646 | The perplexity of the eval outputs. 647 | """ 648 | total_loss = 0 649 | total_predict_count = 0 650 | start_time = time.time() 651 | 652 | while True: 653 | try: 654 | output_tuple = model.eval(sess) 655 | total_loss += output_tuple.eval_loss * output_tuple.batch_size 656 | total_predict_count += output_tuple.predict_count 657 | except tf.errors.OutOfRangeError: 658 | break 659 | 660 | perplexity = utils.safe_exp(total_loss / total_predict_count) 661 | utils.print_time(" eval %s: perplexity %.2f" % (name, perplexity), 662 | start_time) 663 | return perplexity 664 | -------------------------------------------------------------------------------- /nmt/nmt_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for nmt.py, train.py and inference.py.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import argparse 22 | import os 23 | 24 | import tensorflow as tf 25 | 26 | from . import inference 27 | from . import nmt 28 | from . import train 29 | 30 | 31 | def _update_flags(flags, test_name): 32 | """Update flags for basic training.""" 33 | flags.num_train_steps = 100 34 | flags.steps_per_stats = 5 35 | flags.src = "en" 36 | flags.tgt = "vi" 37 | flags.train_prefix = ("nmt/testdata/" 38 | "iwslt15.tst2013.100") 39 | flags.vocab_prefix = ("nmt/testdata/" 40 | "iwslt15.vocab.100") 41 | flags.dev_prefix = ("nmt/testdata/" 42 | "iwslt15.tst2013.100") 43 | flags.test_prefix = ("nmt/testdata/" 44 | "iwslt15.tst2013.100") 45 | flags.out_dir = os.path.join(tf.test.get_temp_dir(), test_name) 46 | 47 | 48 | class NMTTest(tf.test.TestCase): 49 | 50 | def testTrain(self): 51 | """Test the training loop is functional with basic hparams.""" 52 | nmt_parser = argparse.ArgumentParser() 53 | nmt.add_arguments(nmt_parser) 54 | FLAGS, unparsed = nmt_parser.parse_known_args() 55 | 56 | _update_flags(FLAGS, "nmt_train_test") 57 | 58 | default_hparams = nmt.create_hparams(FLAGS) 59 | 60 | train_fn = train.train 61 | nmt.run_main(FLAGS, default_hparams, train_fn, None) 62 | 63 | 64 | def testTrainWithAvgCkpts(self): 65 | """Test the training loop is functional with basic hparams.""" 66 | nmt_parser = argparse.ArgumentParser() 67 | nmt.add_arguments(nmt_parser) 68 | FLAGS, unparsed = nmt_parser.parse_known_args() 69 | 70 | _update_flags(FLAGS, "nmt_train_test_avg_ckpts") 71 | FLAGS.avg_ckpts = True 72 | 73 | default_hparams = nmt.create_hparams(FLAGS) 74 | 75 | train_fn = train.train 76 | nmt.run_main(FLAGS, default_hparams, train_fn, None) 77 | 78 | 79 | def testInference(self): 80 | """Test inference is function with basic hparams.""" 81 | nmt_parser = argparse.ArgumentParser() 82 | nmt.add_arguments(nmt_parser) 83 | FLAGS, unparsed = nmt_parser.parse_known_args() 84 | 85 | _update_flags(FLAGS, "nmt_train_infer") 86 | 87 | # Train one step so we have a checkpoint. 88 | FLAGS.num_train_steps = 1 89 | default_hparams = nmt.create_hparams(FLAGS) 90 | train_fn = train.train 91 | nmt.run_main(FLAGS, default_hparams, train_fn, None) 92 | 93 | # Update FLAGS for inference. 94 | FLAGS.inference_input_file = ("nmt/testdata/" 95 | "iwslt15.tst2013.100.en") 96 | FLAGS.inference_output_file = os.path.join(FLAGS.out_dir, "output") 97 | FLAGS.inference_ref_file = ("nmt/testdata/" 98 | "iwslt15.tst2013.100.vi") 99 | 100 | default_hparams = nmt.create_hparams(FLAGS) 101 | 102 | inference_fn = inference.inference 103 | nmt.run_main(FLAGS, default_hparams, None, inference_fn) 104 | 105 | 106 | if __name__ == "__main__": 107 | tf.test.main() 108 | -------------------------------------------------------------------------------- /nmt/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/nmt/0be864257a76c151eef20ea689755f08bc1faf4e/nmt/scripts/__init__.py -------------------------------------------------------------------------------- /nmt/scripts/bleu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Python implementation of BLEU and smooth-BLEU. 17 | 18 | This module provides a Python implementation of BLEU and smooth-BLEU. 19 | Smooth BLEU is computed following the method outlined in the paper: 20 | Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic 21 | evaluation metrics for machine translation. COLING 2004. 22 | """ 23 | 24 | import collections 25 | import math 26 | 27 | 28 | def _get_ngrams(segment, max_order): 29 | """Extracts all n-grams upto a given maximum order from an input segment. 30 | 31 | Args: 32 | segment: text segment from which n-grams will be extracted. 33 | max_order: maximum length in tokens of the n-grams returned by this 34 | methods. 35 | 36 | Returns: 37 | The Counter containing all n-grams upto max_order in segment 38 | with a count of how many times each n-gram occurred. 39 | """ 40 | ngram_counts = collections.Counter() 41 | for order in range(1, max_order + 1): 42 | for i in range(0, len(segment) - order + 1): 43 | ngram = tuple(segment[i:i+order]) 44 | ngram_counts[ngram] += 1 45 | return ngram_counts 46 | 47 | 48 | def compute_bleu(reference_corpus, translation_corpus, max_order=4, 49 | smooth=False): 50 | """Computes BLEU score of translated segments against one or more references. 51 | 52 | Args: 53 | reference_corpus: list of lists of references for each translation. Each 54 | reference should be tokenized into a list of tokens. 55 | translation_corpus: list of translations to score. Each translation 56 | should be tokenized into a list of tokens. 57 | max_order: Maximum n-gram order to use when computing BLEU score. 58 | smooth: Whether or not to apply Lin et al. 2004 smoothing. 59 | 60 | Returns: 61 | 3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram 62 | precisions and brevity penalty. 63 | """ 64 | matches_by_order = [0] * max_order 65 | possible_matches_by_order = [0] * max_order 66 | reference_length = 0 67 | translation_length = 0 68 | for (references, translation) in zip(reference_corpus, 69 | translation_corpus): 70 | reference_length += min(len(r) for r in references) 71 | translation_length += len(translation) 72 | 73 | merged_ref_ngram_counts = collections.Counter() 74 | for reference in references: 75 | merged_ref_ngram_counts |= _get_ngrams(reference, max_order) 76 | translation_ngram_counts = _get_ngrams(translation, max_order) 77 | overlap = translation_ngram_counts & merged_ref_ngram_counts 78 | for ngram in overlap: 79 | matches_by_order[len(ngram)-1] += overlap[ngram] 80 | for order in range(1, max_order+1): 81 | possible_matches = len(translation) - order + 1 82 | if possible_matches > 0: 83 | possible_matches_by_order[order-1] += possible_matches 84 | 85 | precisions = [0] * max_order 86 | for i in range(0, max_order): 87 | if smooth: 88 | precisions[i] = ((matches_by_order[i] + 1.) / 89 | (possible_matches_by_order[i] + 1.)) 90 | else: 91 | if possible_matches_by_order[i] > 0: 92 | precisions[i] = (float(matches_by_order[i]) / 93 | possible_matches_by_order[i]) 94 | else: 95 | precisions[i] = 0.0 96 | 97 | if min(precisions) > 0: 98 | p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions) 99 | geo_mean = math.exp(p_log_sum) 100 | else: 101 | geo_mean = 0 102 | 103 | ratio = float(translation_length) / reference_length 104 | 105 | if ratio > 1.0: 106 | bp = 1. 107 | else: 108 | bp = math.exp(1 - 1. / ratio) 109 | 110 | bleu = geo_mean * bp 111 | 112 | return (bleu, precisions, bp, ratio, translation_length, reference_length) 113 | -------------------------------------------------------------------------------- /nmt/scripts/download_iwslt15.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # Download small-scale IWSLT15 Vietnames to English translation data for NMT 3 | # model training. 4 | # 5 | # Usage: 6 | # ./download_iwslt15.sh path-to-output-dir 7 | # 8 | # If output directory is not specified, "./iwslt15" will be used as the default 9 | # output directory. 10 | OUT_DIR="${1:-iwslt15}" 11 | SITE_PREFIX="https://nlp.stanford.edu/projects/nmt/data" 12 | 13 | mkdir -v -p $OUT_DIR 14 | 15 | # Download iwslt15 small dataset from standford website. 16 | echo "Download training dataset train.en and train.vi." 17 | curl -o "$OUT_DIR/train.en" "$SITE_PREFIX/iwslt15.en-vi/train.en" 18 | curl -o "$OUT_DIR/train.vi" "$SITE_PREFIX/iwslt15.en-vi/train.vi" 19 | 20 | echo "Download dev dataset tst2012.en and tst2012.vi." 21 | curl -o "$OUT_DIR/tst2012.en" "$SITE_PREFIX/iwslt15.en-vi/tst2012.en" 22 | curl -o "$OUT_DIR/tst2012.vi" "$SITE_PREFIX/iwslt15.en-vi/tst2012.vi" 23 | 24 | echo "Download test dataset tst2013.en and tst2013.vi." 25 | curl -o "$OUT_DIR/tst2013.en" "$SITE_PREFIX/iwslt15.en-vi/tst2013.en" 26 | curl -o "$OUT_DIR/tst2013.vi" "$SITE_PREFIX/iwslt15.en-vi/tst2013.vi" 27 | 28 | echo "Download vocab file vocab.en and vocab.vi." 29 | curl -o "$OUT_DIR/vocab.en" "$SITE_PREFIX/iwslt15.en-vi/vocab.en" 30 | curl -o "$OUT_DIR/vocab.vi" "$SITE_PREFIX/iwslt15.en-vi/vocab.vi" 31 | -------------------------------------------------------------------------------- /nmt/scripts/rouge.py: -------------------------------------------------------------------------------- 1 | """ROUGE metric implementation. 2 | 3 | Copy from tf_seq2seq/seq2seq/metrics/rouge.py. 4 | This is a modified and slightly extended verison of 5 | https://github.com/miso-belica/sumy/blob/dev/sumy/evaluation/rouge.py. 6 | """ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | from __future__ import unicode_literals 12 | 13 | import itertools 14 | import numpy as np 15 | 16 | #pylint: disable=C0103 17 | 18 | 19 | def _get_ngrams(n, text): 20 | """Calcualtes n-grams. 21 | 22 | Args: 23 | n: which n-grams to calculate 24 | text: An array of tokens 25 | 26 | Returns: 27 | A set of n-grams 28 | """ 29 | ngram_set = set() 30 | text_length = len(text) 31 | max_index_ngram_start = text_length - n 32 | for i in range(max_index_ngram_start + 1): 33 | ngram_set.add(tuple(text[i:i + n])) 34 | return ngram_set 35 | 36 | 37 | def _split_into_words(sentences): 38 | """Splits multiple sentences into words and flattens the result""" 39 | return list(itertools.chain(*[_.split(" ") for _ in sentences])) 40 | 41 | 42 | def _get_word_ngrams(n, sentences): 43 | """Calculates word n-grams for multiple sentences. 44 | """ 45 | assert len(sentences) > 0 46 | assert n > 0 47 | 48 | words = _split_into_words(sentences) 49 | return _get_ngrams(n, words) 50 | 51 | 52 | def _len_lcs(x, y): 53 | """ 54 | Returns the length of the Longest Common Subsequence between sequences x 55 | and y. 56 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 57 | 58 | Args: 59 | x: sequence of words 60 | y: sequence of words 61 | 62 | Returns 63 | integer: Length of LCS between x and y 64 | """ 65 | table = _lcs(x, y) 66 | n, m = len(x), len(y) 67 | return table[n, m] 68 | 69 | 70 | def _lcs(x, y): 71 | """ 72 | Computes the length of the longest common subsequence (lcs) between two 73 | strings. The implementation below uses a DP programming algorithm and runs 74 | in O(nm) time where n = len(x) and m = len(y). 75 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 76 | 77 | Args: 78 | x: collection of words 79 | y: collection of words 80 | 81 | Returns: 82 | Table of dictionary of coord and len lcs 83 | """ 84 | n, m = len(x), len(y) 85 | table = dict() 86 | for i in range(n + 1): 87 | for j in range(m + 1): 88 | if i == 0 or j == 0: 89 | table[i, j] = 0 90 | elif x[i - 1] == y[j - 1]: 91 | table[i, j] = table[i - 1, j - 1] + 1 92 | else: 93 | table[i, j] = max(table[i - 1, j], table[i, j - 1]) 94 | return table 95 | 96 | 97 | def _recon_lcs(x, y): 98 | """ 99 | Returns the Longest Subsequence between x and y. 100 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 101 | 102 | Args: 103 | x: sequence of words 104 | y: sequence of words 105 | 106 | Returns: 107 | sequence: LCS of x and y 108 | """ 109 | i, j = len(x), len(y) 110 | table = _lcs(x, y) 111 | 112 | def _recon(i, j): 113 | """private recon calculation""" 114 | if i == 0 or j == 0: 115 | return [] 116 | elif x[i - 1] == y[j - 1]: 117 | return _recon(i - 1, j - 1) + [(x[i - 1], i)] 118 | elif table[i - 1, j] > table[i, j - 1]: 119 | return _recon(i - 1, j) 120 | else: 121 | return _recon(i, j - 1) 122 | 123 | recon_tuple = tuple(map(lambda x: x[0], _recon(i, j))) 124 | return recon_tuple 125 | 126 | 127 | def rouge_n(evaluated_sentences, reference_sentences, n=2): 128 | """ 129 | Computes ROUGE-N of two text collections of sentences. 130 | Sourece: http://research.microsoft.com/en-us/um/people/cyl/download/ 131 | papers/rouge-working-note-v1.3.1.pdf 132 | 133 | Args: 134 | evaluated_sentences: The sentences that have been picked by the summarizer 135 | reference_sentences: The sentences from the referene set 136 | n: Size of ngram. Defaults to 2. 137 | 138 | Returns: 139 | A tuple (f1, precision, recall) for ROUGE-N 140 | 141 | Raises: 142 | ValueError: raises exception if a param has len <= 0 143 | """ 144 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 145 | raise ValueError("Collections must contain at least 1 sentence.") 146 | 147 | evaluated_ngrams = _get_word_ngrams(n, evaluated_sentences) 148 | reference_ngrams = _get_word_ngrams(n, reference_sentences) 149 | reference_count = len(reference_ngrams) 150 | evaluated_count = len(evaluated_ngrams) 151 | 152 | # Gets the overlapping ngrams between evaluated and reference 153 | overlapping_ngrams = evaluated_ngrams.intersection(reference_ngrams) 154 | overlapping_count = len(overlapping_ngrams) 155 | 156 | # Handle edge case. This isn't mathematically correct, but it's good enough 157 | if evaluated_count == 0: 158 | precision = 0.0 159 | else: 160 | precision = overlapping_count / evaluated_count 161 | 162 | if reference_count == 0: 163 | recall = 0.0 164 | else: 165 | recall = overlapping_count / reference_count 166 | 167 | f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8)) 168 | 169 | # return overlapping_count / reference_count 170 | return f1_score, precision, recall 171 | 172 | 173 | def _f_p_r_lcs(llcs, m, n): 174 | """ 175 | Computes the LCS-based F-measure score 176 | Source: http://research.microsoft.com/en-us/um/people/cyl/download/papers/ 177 | rouge-working-note-v1.3.1.pdf 178 | 179 | Args: 180 | llcs: Length of LCS 181 | m: number of words in reference summary 182 | n: number of words in candidate summary 183 | 184 | Returns: 185 | Float. LCS-based F-measure score 186 | """ 187 | r_lcs = llcs / m 188 | p_lcs = llcs / n 189 | beta = p_lcs / (r_lcs + 1e-12) 190 | num = (1 + (beta**2)) * r_lcs * p_lcs 191 | denom = r_lcs + ((beta**2) * p_lcs) 192 | f_lcs = num / (denom + 1e-12) 193 | return f_lcs, p_lcs, r_lcs 194 | 195 | 196 | def rouge_l_sentence_level(evaluated_sentences, reference_sentences): 197 | """ 198 | Computes ROUGE-L (sentence level) of two text collections of sentences. 199 | http://research.microsoft.com/en-us/um/people/cyl/download/papers/ 200 | rouge-working-note-v1.3.1.pdf 201 | 202 | Calculated according to: 203 | R_lcs = LCS(X,Y)/m 204 | P_lcs = LCS(X,Y)/n 205 | F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs) 206 | 207 | where: 208 | X = reference summary 209 | Y = Candidate summary 210 | m = length of reference summary 211 | n = length of candidate summary 212 | 213 | Args: 214 | evaluated_sentences: The sentences that have been picked by the summarizer 215 | reference_sentences: The sentences from the referene set 216 | 217 | Returns: 218 | A float: F_lcs 219 | 220 | Raises: 221 | ValueError: raises exception if a param has len <= 0 222 | """ 223 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 224 | raise ValueError("Collections must contain at least 1 sentence.") 225 | reference_words = _split_into_words(reference_sentences) 226 | evaluated_words = _split_into_words(evaluated_sentences) 227 | m = len(reference_words) 228 | n = len(evaluated_words) 229 | lcs = _len_lcs(evaluated_words, reference_words) 230 | return _f_p_r_lcs(lcs, m, n) 231 | 232 | 233 | def _union_lcs(evaluated_sentences, reference_sentence): 234 | """ 235 | Returns LCS_u(r_i, C) which is the LCS score of the union longest common 236 | subsequence between reference sentence ri and candidate summary C. For example 237 | if r_i= w1 w2 w3 w4 w5, and C contains two sentences: c1 = w1 w2 w6 w7 w8 and 238 | c2 = w1 w3 w8 w9 w5, then the longest common subsequence of r_i and c1 is 239 | "w1 w2" and the longest common subsequence of r_i and c2 is "w1 w3 w5". The 240 | union longest common subsequence of r_i, c1, and c2 is "w1 w2 w3 w5" and 241 | LCS_u(r_i, C) = 4/5. 242 | 243 | Args: 244 | evaluated_sentences: The sentences that have been picked by the summarizer 245 | reference_sentence: One of the sentences in the reference summaries 246 | 247 | Returns: 248 | float: LCS_u(r_i, C) 249 | 250 | ValueError: 251 | Raises exception if a param has len <= 0 252 | """ 253 | if len(evaluated_sentences) <= 0: 254 | raise ValueError("Collections must contain at least 1 sentence.") 255 | 256 | lcs_union = set() 257 | reference_words = _split_into_words([reference_sentence]) 258 | combined_lcs_length = 0 259 | for eval_s in evaluated_sentences: 260 | evaluated_words = _split_into_words([eval_s]) 261 | lcs = set(_recon_lcs(reference_words, evaluated_words)) 262 | combined_lcs_length += len(lcs) 263 | lcs_union = lcs_union.union(lcs) 264 | 265 | union_lcs_count = len(lcs_union) 266 | union_lcs_value = union_lcs_count / combined_lcs_length 267 | return union_lcs_value 268 | 269 | 270 | def rouge_l_summary_level(evaluated_sentences, reference_sentences): 271 | """ 272 | Computes ROUGE-L (summary level) of two text collections of sentences. 273 | http://research.microsoft.com/en-us/um/people/cyl/download/papers/ 274 | rouge-working-note-v1.3.1.pdf 275 | 276 | Calculated according to: 277 | R_lcs = SUM(1, u)[LCS(r_i,C)]/m 278 | P_lcs = SUM(1, u)[LCS(r_i,C)]/n 279 | F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs) 280 | 281 | where: 282 | SUM(i,u) = SUM from i through u 283 | u = number of sentences in reference summary 284 | C = Candidate summary made up of v sentences 285 | m = number of words in reference summary 286 | n = number of words in candidate summary 287 | 288 | Args: 289 | evaluated_sentences: The sentences that have been picked by the summarizer 290 | reference_sentence: One of the sentences in the reference summaries 291 | 292 | Returns: 293 | A float: F_lcs 294 | 295 | Raises: 296 | ValueError: raises exception if a param has len <= 0 297 | """ 298 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 299 | raise ValueError("Collections must contain at least 1 sentence.") 300 | 301 | # total number of words in reference sentences 302 | m = len(_split_into_words(reference_sentences)) 303 | 304 | # total number of words in evaluated sentences 305 | n = len(_split_into_words(evaluated_sentences)) 306 | 307 | union_lcs_sum_across_all_references = 0 308 | for ref_s in reference_sentences: 309 | union_lcs_sum_across_all_references += _union_lcs(evaluated_sentences, 310 | ref_s) 311 | return _f_p_r_lcs(union_lcs_sum_across_all_references, m, n) 312 | 313 | 314 | def rouge(hypotheses, references): 315 | """Calculates average rouge scores for a list of hypotheses and 316 | references""" 317 | 318 | # Filter out hyps that are of 0 length 319 | # hyps_and_refs = zip(hypotheses, references) 320 | # hyps_and_refs = [_ for _ in hyps_and_refs if len(_[0]) > 0] 321 | # hypotheses, references = zip(*hyps_and_refs) 322 | 323 | # Calculate ROUGE-1 F1, precision, recall scores 324 | rouge_1 = [ 325 | rouge_n([hyp], [ref], 1) for hyp, ref in zip(hypotheses, references) 326 | ] 327 | rouge_1_f, rouge_1_p, rouge_1_r = map(np.mean, zip(*rouge_1)) 328 | 329 | # Calculate ROUGE-2 F1, precision, recall scores 330 | rouge_2 = [ 331 | rouge_n([hyp], [ref], 2) for hyp, ref in zip(hypotheses, references) 332 | ] 333 | rouge_2_f, rouge_2_p, rouge_2_r = map(np.mean, zip(*rouge_2)) 334 | 335 | # Calculate ROUGE-L F1, precision, recall scores 336 | rouge_l = [ 337 | rouge_l_sentence_level([hyp], [ref]) 338 | for hyp, ref in zip(hypotheses, references) 339 | ] 340 | rouge_l_f, rouge_l_p, rouge_l_r = map(np.mean, zip(*rouge_l)) 341 | 342 | return { 343 | "rouge_1/f_score": rouge_1_f, 344 | "rouge_1/r_score": rouge_1_r, 345 | "rouge_1/p_score": rouge_1_p, 346 | "rouge_2/f_score": rouge_2_f, 347 | "rouge_2/r_score": rouge_2_r, 348 | "rouge_2/p_score": rouge_2_p, 349 | "rouge_l/f_score": rouge_l_f, 350 | "rouge_l/r_score": rouge_l_r, 351 | "rouge_l/p_score": rouge_l_p, 352 | } 353 | -------------------------------------------------------------------------------- /nmt/scripts/wmt16_en_de.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright 2017 Google Inc. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | set -e 18 | 19 | BASE_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )/.." && pwd )" 20 | 21 | OUTPUT_DIR="${1:-wmt16_de_en}" 22 | echo "Writing to ${OUTPUT_DIR}. To change this, set the OUTPUT_DIR environment variable." 23 | 24 | OUTPUT_DIR_DATA="${OUTPUT_DIR}/data" 25 | mkdir -p $OUTPUT_DIR_DATA 26 | 27 | echo "Downloading Europarl v7. This may take a while..." 28 | curl -o ${OUTPUT_DIR_DATA}/europarl-v7-de-en.tgz \ 29 | http://www.statmt.org/europarl/v7/de-en.tgz 30 | 31 | echo "Downloading Common Crawl corpus. This may take a while..." 32 | curl -o ${OUTPUT_DIR_DATA}/common-crawl.tgz \ 33 | http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz 34 | 35 | echo "Downloading News Commentary v11. This may take a while..." 36 | curl -o ${OUTPUT_DIR_DATA}/nc-v11.tgz \ 37 | http://data.statmt.org/wmt16/translation-task/training-parallel-nc-v11.tgz 38 | 39 | echo "Downloading dev/test sets" 40 | curl -o ${OUTPUT_DIR_DATA}/dev.tgz \ 41 | http://data.statmt.org/wmt16/translation-task/dev.tgz 42 | curl -o ${OUTPUT_DIR_DATA}/test.tgz \ 43 | http://data.statmt.org/wmt16/translation-task/test.tgz 44 | 45 | # Extract everything 46 | echo "Extracting all files..." 47 | mkdir -p "${OUTPUT_DIR_DATA}/europarl-v7-de-en" 48 | tar -xvzf "${OUTPUT_DIR_DATA}/europarl-v7-de-en.tgz" -C "${OUTPUT_DIR_DATA}/europarl-v7-de-en" 49 | mkdir -p "${OUTPUT_DIR_DATA}/common-crawl" 50 | tar -xvzf "${OUTPUT_DIR_DATA}/common-crawl.tgz" -C "${OUTPUT_DIR_DATA}/common-crawl" 51 | mkdir -p "${OUTPUT_DIR_DATA}/nc-v11" 52 | tar -xvzf "${OUTPUT_DIR_DATA}/nc-v11.tgz" -C "${OUTPUT_DIR_DATA}/nc-v11" 53 | mkdir -p "${OUTPUT_DIR_DATA}/dev" 54 | tar -xvzf "${OUTPUT_DIR_DATA}/dev.tgz" -C "${OUTPUT_DIR_DATA}/dev" 55 | mkdir -p "${OUTPUT_DIR_DATA}/test" 56 | tar -xvzf "${OUTPUT_DIR_DATA}/test.tgz" -C "${OUTPUT_DIR_DATA}/test" 57 | 58 | # Concatenate Training data 59 | cat "${OUTPUT_DIR_DATA}/europarl-v7-de-en/europarl-v7.de-en.en" \ 60 | "${OUTPUT_DIR_DATA}/common-crawl/commoncrawl.de-en.en" \ 61 | "${OUTPUT_DIR_DATA}/nc-v11/training-parallel-nc-v11/news-commentary-v11.de-en.en" \ 62 | > "${OUTPUT_DIR}/train.en" 63 | wc -l "${OUTPUT_DIR}/train.en" 64 | 65 | cat "${OUTPUT_DIR_DATA}/europarl-v7-de-en/europarl-v7.de-en.de" \ 66 | "${OUTPUT_DIR_DATA}/common-crawl/commoncrawl.de-en.de" \ 67 | "${OUTPUT_DIR_DATA}/nc-v11/training-parallel-nc-v11/news-commentary-v11.de-en.de" \ 68 | > "${OUTPUT_DIR}/train.de" 69 | wc -l "${OUTPUT_DIR}/train.de" 70 | 71 | # Clone Moses 72 | if [ ! -d "${OUTPUT_DIR}/mosesdecoder" ]; then 73 | echo "Cloning moses for data processing" 74 | git clone https://github.com/moses-smt/mosesdecoder.git "${OUTPUT_DIR}/mosesdecoder" 75 | fi 76 | 77 | # Convert SGM files 78 | # Convert newstest2014 data into raw text format 79 | ${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \ 80 | < ${OUTPUT_DIR_DATA}/dev/dev/newstest2014-deen-src.de.sgm \ 81 | > ${OUTPUT_DIR_DATA}/dev/dev/newstest2014.de 82 | ${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \ 83 | < ${OUTPUT_DIR_DATA}/dev/dev/newstest2014-deen-ref.en.sgm \ 84 | > ${OUTPUT_DIR_DATA}/dev/dev/newstest2014.en 85 | 86 | # Convert newstest2015 data into raw text format 87 | ${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \ 88 | < ${OUTPUT_DIR_DATA}/dev/dev/newstest2015-deen-src.de.sgm \ 89 | > ${OUTPUT_DIR_DATA}/dev/dev/newstest2015.de 90 | ${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \ 91 | < ${OUTPUT_DIR_DATA}/dev/dev/newstest2015-deen-ref.en.sgm \ 92 | > ${OUTPUT_DIR_DATA}/dev/dev/newstest2015.en 93 | 94 | # Convert newstest2016 data into raw text format 95 | ${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \ 96 | < ${OUTPUT_DIR_DATA}/test/test/newstest2016-deen-src.de.sgm \ 97 | > ${OUTPUT_DIR_DATA}/test/test/newstest2016.de 98 | ${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \ 99 | < ${OUTPUT_DIR_DATA}/test/test/newstest2016-deen-ref.en.sgm \ 100 | > ${OUTPUT_DIR_DATA}/test/test/newstest2016.en 101 | 102 | # Copy dev/test data to output dir 103 | cp ${OUTPUT_DIR_DATA}/dev/dev/newstest20*.de ${OUTPUT_DIR} 104 | cp ${OUTPUT_DIR_DATA}/dev/dev/newstest20*.en ${OUTPUT_DIR} 105 | cp ${OUTPUT_DIR_DATA}/test/test/newstest20*.de ${OUTPUT_DIR} 106 | cp ${OUTPUT_DIR_DATA}/test/test/newstest20*.en ${OUTPUT_DIR} 107 | 108 | # Tokenize data 109 | for f in ${OUTPUT_DIR}/*.de; do 110 | echo "Tokenizing $f..." 111 | ${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/tokenizer.perl -q -l de -threads 8 < $f > ${f%.*}.tok.de 112 | done 113 | 114 | for f in ${OUTPUT_DIR}/*.en; do 115 | echo "Tokenizing $f..." 116 | ${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/tokenizer.perl -q -l en -threads 8 < $f > ${f%.*}.tok.en 117 | done 118 | 119 | # Clean train corpora 120 | for f in ${OUTPUT_DIR}/train.tok.en; do 121 | fbase=${f%.*} 122 | echo "Cleaning ${fbase}..." 123 | ${OUTPUT_DIR}/mosesdecoder/scripts/training/clean-corpus-n.perl $fbase de en "${fbase}.clean" 1 80 124 | done 125 | 126 | # Generate Subword Units (BPE) 127 | # Clone Subword NMT 128 | if [ ! -d "${OUTPUT_DIR}/subword-nmt" ]; then 129 | git clone https://github.com/rsennrich/subword-nmt.git "${OUTPUT_DIR}/subword-nmt" 130 | fi 131 | 132 | # Learn Shared BPE 133 | for merge_ops in 32000; do 134 | echo "Learning BPE with merge_ops=${merge_ops}. This may take a while..." 135 | cat "${OUTPUT_DIR}/train.tok.clean.de" "${OUTPUT_DIR}/train.tok.clean.en" | \ 136 | ${OUTPUT_DIR}/subword-nmt/learn_bpe.py -s $merge_ops > "${OUTPUT_DIR}/bpe.${merge_ops}" 137 | 138 | echo "Apply BPE with merge_ops=${merge_ops} to tokenized files..." 139 | for lang in en de; do 140 | for f in ${OUTPUT_DIR}/*.tok.${lang} ${OUTPUT_DIR}/*.tok.clean.${lang}; do 141 | outfile="${f%.*}.bpe.${merge_ops}.${lang}" 142 | ${OUTPUT_DIR}/subword-nmt/apply_bpe.py -c "${OUTPUT_DIR}/bpe.${merge_ops}" < $f > "${outfile}" 143 | echo ${outfile} 144 | done 145 | done 146 | 147 | # Create vocabulary file for BPE 148 | echo -e "\n\n" > "${OUTPUT_DIR}/vocab.bpe.${merge_ops}" 149 | cat "${OUTPUT_DIR}/train.tok.clean.bpe.${merge_ops}.en" "${OUTPUT_DIR}/train.tok.clean.bpe.${merge_ops}.de" | \ 150 | ${OUTPUT_DIR}/subword-nmt/get_vocab.py | cut -f1 -d ' ' >> "${OUTPUT_DIR}/vocab.bpe.${merge_ops}" 151 | 152 | done 153 | 154 | # Duplicate vocab file with language suffix 155 | cp "${OUTPUT_DIR}/vocab.bpe.32000" "${OUTPUT_DIR}/vocab.bpe.32000.en" 156 | cp "${OUTPUT_DIR}/vocab.bpe.32000" "${OUTPUT_DIR}/vocab.bpe.32000.de" 157 | 158 | echo "All done." 159 | -------------------------------------------------------------------------------- /nmt/standard_hparams/iwslt15.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention": "scaled_luong", 3 | "attention_architecture": "standard", 4 | "batch_size": 128, 5 | "colocate_gradients_with_ops": true, 6 | "dropout": 0.2, 7 | "encoder_type": "bi", 8 | "eos": "", 9 | "forget_bias": 1.0, 10 | "infer_batch_size": 32, 11 | "init_weight": 0.1, 12 | "learning_rate": 1.0, 13 | "max_gradient_norm": 5.0, 14 | "metrics": ["bleu"], 15 | "num_buckets": 5, 16 | "num_encoder_layers": 2, 17 | "num_decoder_layers": 2, 18 | "num_train_steps": 12000, 19 | "decay_scheme": "luong234", 20 | "num_units": 512, 21 | "optimizer": "sgd", 22 | "residual": false, 23 | "share_vocab": false, 24 | "subword_option": "", 25 | "sos": "", 26 | "src_max_len": 50, 27 | "src_max_len_infer": null, 28 | "steps_per_external_eval": null, 29 | "steps_per_stats": 100, 30 | "tgt_max_len": 50, 31 | "tgt_max_len_infer": null, 32 | "time_major": true, 33 | "unit_type": "lstm", 34 | "beam_width": 10 35 | } 36 | -------------------------------------------------------------------------------- /nmt/standard_hparams/wmt16.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention": "normed_bahdanau", 3 | "attention_architecture": "standard", 4 | "batch_size": 128, 5 | "colocate_gradients_with_ops": true, 6 | "dropout": 0.2, 7 | "encoder_type": "bi", 8 | "eos": "", 9 | "forget_bias": 1.0, 10 | "infer_batch_size": 32, 11 | "init_weight": 0.1, 12 | "learning_rate": 1.0, 13 | "max_gradient_norm": 5.0, 14 | "metrics": ["bleu"], 15 | "num_buckets": 5, 16 | "num_encoder_layers": 4, 17 | "num_decoder_layers": 4, 18 | "num_train_steps": 340000, 19 | "decay_scheme": "luong10", 20 | "num_units": 1024, 21 | "optimizer": "sgd", 22 | "residual": false, 23 | "share_vocab": false, 24 | "subword_option": "bpe", 25 | "sos": "", 26 | "src_max_len": 50, 27 | "src_max_len_infer": null, 28 | "steps_per_external_eval": null, 29 | "steps_per_stats": 100, 30 | "tgt_max_len": 50, 31 | "tgt_max_len_infer": null, 32 | "time_major": true, 33 | "unit_type": "lstm", 34 | "beam_width": 10 35 | } 36 | -------------------------------------------------------------------------------- /nmt/standard_hparams/wmt16_gnmt_4_layer.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention": "normed_bahdanau", 3 | "attention_architecture": "gnmt_v2", 4 | "batch_size": 128, 5 | "colocate_gradients_with_ops": true, 6 | "dropout": 0.2, 7 | "encoder_type": "gnmt", 8 | "eos": "", 9 | "forget_bias": 1.0, 10 | "infer_batch_size": 32, 11 | "init_weight": 0.1, 12 | "learning_rate": 1.0, 13 | "max_gradient_norm": 5.0, 14 | "metrics": ["bleu"], 15 | "num_buckets": 5, 16 | "num_encoder_layers": 4, 17 | "num_decoder_layers": 4, 18 | "num_train_steps": 340000, 19 | "decay_scheme": "luong10", 20 | "num_units": 1024, 21 | "optimizer": "sgd", 22 | "residual": true, 23 | "share_vocab": false, 24 | "subword_option": "bpe", 25 | "sos": "", 26 | "src_max_len": 50, 27 | "src_max_len_infer": null, 28 | "steps_per_external_eval": null, 29 | "steps_per_stats": 100, 30 | "tgt_max_len": 50, 31 | "tgt_max_len_infer": null, 32 | "time_major": true, 33 | "unit_type": "lstm", 34 | "beam_width": 10, 35 | "length_penalty_weight": 1.0 36 | } 37 | -------------------------------------------------------------------------------- /nmt/standard_hparams/wmt16_gnmt_8_layer.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention": "normed_bahdanau", 3 | "attention_architecture": "gnmt_v2", 4 | "batch_size": 128, 5 | "colocate_gradients_with_ops": true, 6 | "dropout": 0.2, 7 | "encoder_type": "gnmt", 8 | "eos": "", 9 | "forget_bias": 1.0, 10 | "infer_batch_size": 32, 11 | "init_weight": 0.1, 12 | "learning_rate": 1.0, 13 | "max_gradient_norm": 5.0, 14 | "metrics": ["bleu"], 15 | "num_buckets": 5, 16 | "num_encoder_layers": 8, 17 | "num_decoder_layers": 8, 18 | "num_train_steps": 340000, 19 | "decay_scheme": "luong10", 20 | "num_units": 1024, 21 | "optimizer": "sgd", 22 | "residual": true, 23 | "share_vocab": false, 24 | "subword_option": "bpe", 25 | "sos": "", 26 | "src_max_len": 50, 27 | "src_max_len_infer": null, 28 | "steps_per_external_eval": null, 29 | "steps_per_stats": 100, 30 | "tgt_max_len": 50, 31 | "tgt_max_len_infer": null, 32 | "time_major": true, 33 | "unit_type": "lstm", 34 | "beam_width": 10, 35 | "length_penalty_weight": 1.0 36 | } 37 | -------------------------------------------------------------------------------- /nmt/testdata/iwslt15.tst2013.100.en: -------------------------------------------------------------------------------- 1 | When I was little , I thought my country was the best on the planet , and I grew up singing a song called " Nothing To Envy . " 2 | And I was very proud . 3 | In school , we spent a lot of time studying the history of Kim Il-Sung , but we never learned much about the outside world , except that America , South Korea , Japan are the enemies . 4 | Although I often wondered about the outside world , I thought I would spend my entire life in North Korea , until everything suddenly changed . 5 | When I was seven years old , I saw my first public execution , but I thought my life in North Korea was normal . 6 | My family was not poor , and myself , I had never experienced hunger . 7 | But one day , in 1995 , my mom brought home a letter from a coworker 's sister . 8 | It read , " When you read this , all five family members will not exist in this world , because we haven 't eaten for the past two weeks . 9 | We are lying on the floor together , and our bodies are so weak we are ready to die . " 10 | I was so shocked . 11 | This was the first time I heard that people in my country were suffering . 12 | Soon after , when I was walking past a train station , I saw something terrible that I can 't erase from my memory . 13 | A lifeless woman was lying on the ground , while an emaciated child in her arms just stared helplessly at his mother 's face . 14 | But nobody helped them , because they were so focused on taking care of themselves and their families . 15 | A huge famine hit North Korea in the mid-1990s . 16 | Ultimately , more than a million North Koreans died during the famine , and many only survived by eating grass , bugs and tree bark . 17 | Power outages also became more and more frequent , so everything around me was completely dark at night except for the sea of lights in China , just across the river from my home . 18 | I always wondered why they had lights but we didn 't . 19 | This is a satellite picture showing North Korea at night compared to neighbors . 20 | This is the Amrok River , which serves as a part of the border between North Korea and China . 21 | As you can see , the river can be very narrow at certain points , allowing North Koreans to secretly cross . 22 | But many die . 23 | Sometimes , I saw dead bodies floating down the river . 24 | I can 't reveal many details [ about ] how I left North Korea , but I only can say that during the ugly years of the famine I was sent to China to live with distant relatives . 25 | But I only thought that I would be separated from my family for a short time . 26 | I could have never imagined that it would take 14 years to live together . 27 | In China , it was hard living as a young girl without my family . 28 | I had no idea what life was going to be like as a North Korean refugee , but I soon learned it 's not only extremely difficult , it 's also very dangerous , since North Korean refugees are considered in China as illegal migrants . 29 | So I was living in constant fear that my identity could be revealed , and I would be repatriated to a horrible fate back in North Korea . 30 | One day , my worst nightmare came true , when I was caught by the Chinese police and brought to the police station for interrogation . 31 | Someone had accused me of being North Korean , so they tested my Chinese language abilities and asked me tons of questions . 32 | I was so scared , I thought my heart was going to explode . 33 | If anything seemed unnatural , I could be imprisoned and repatriated . 34 | I thought my life was over , but I managed to control all the emotions inside me and answer the questions . 35 | After they finished questioning me , one official said to another , " This was a false report . 36 | She 's not North Korean . " 37 | And they let me go . It was a miracle . 38 | Some North Koreans in China seek asylum in foreign embassies , but many can be caught by the Chinese police and repatriated . 39 | These girls were so lucky . 40 | Even though they were caught , they were eventually released after heavy international pressure . 41 | These North Koreans were not so lucky . 42 | Every year , countless North Koreans are caught in China and repatriated to North Korea , where they can be tortured , imprisoned or publicly executed . 43 | Even though I was really fortunate to get out , many other North Koreans have not been so lucky . 44 | It 's tragic that North Koreans have to hide their identities and struggle so hard just to survive . 45 | Even after learning a new language and getting a job , their whole world can be turned upside down in an instant . 46 | That 's why , after 10 years of hiding my identity , I decided to risk going to South Korea , and I started a new life yet again . 47 | Settling down in South Korea was a lot more challenging than I had expected . 48 | English was so important in South Korea , so I had to start learning my third language . 49 | Also , I realized there was a wide gap between North and South . 50 | We are all Korean , but inside , we have become very different due to 67 years of division . 51 | I even went through an identity crisis . 52 | Am I South Korean or North Korean ? 53 | Where am I from ? Who am I ? 54 | Suddenly , there was no country I could proudly call my own . 55 | Even though adjusting to life in South Korea was not easy , I made a plan . 56 | I started studying for the university entrance exam . 57 | Just as I was starting to get used to my new life , I received a shocking phone call . 58 | The North Korean authorities intercepted some money that I sent to my family , and , as a punishment , my family was going to be forcibly removed to a desolate location in the countryside . 59 | They had to get out quickly , so I started planning how to help them escape . 60 | North Koreans have to travel incredible distances on the path to freedom . 61 | It 's almost impossible to cross the border between North Korea and South Korea , so , ironically , I took a flight back to China and I headed toward the North Korean border . 62 | Since my family couldn 't speak Chinese , I had to guide them , somehow , through more than 2,000 miles in China and then into Southeast Asia . 63 | The journey by bus took one week , and we were almost caught several times . 64 | One time , our bus was stopped and boarded by a Chinese police officer . 65 | He took everyone 's I.D. cards , and he started asking them questions . 66 | Since my family couldn 't understand Chinese , I thought my family was going to be arrested . 67 | As the Chinese officer approached my family , I impulsively stood up , and I told him that these are deaf and dumb people that I was chaperoning . 68 | He looked at me suspiciously , but luckily he believed me . 69 | We made it all the way to the border of Laos , but I had to spend almost all my money to bribe the border guards in Laos . 70 | But even after we got past the border , my family was arrested and jailed for illegal border crossing . 71 | After I paid the fine and bribe , my family was released in one month , but soon after , my family was arrested and jailed again in the capital of Laos . 72 | This was one of the lowest points in my life . 73 | I did everything to get my family to freedom , and we came so close , but my family was thrown in jail just a short distance from the South Korean embassy . 74 | I went back and forth between the immigration office and the police station , desperately trying to get my family out , but I didn 't have enough money to pay a bribe or fine anymore . 75 | I lost all hope . 76 | At that moment , I heard one man 's voice ask me , " What 's wrong ? " 77 | I was so surprised that a total stranger cared enough to ask . 78 | In my broken English , and with a dictionary , I explained the situation , and without hesitating , the man went to the ATM and he paid the rest of the money for my family and two other North Koreans to get out of jail . 79 | I thanked him with all my heart , and I asked him , " Why are you helping me ? " 80 | " I 'm not helping you , " he said . 81 | " I 'm helping the North Korean people . " 82 | I realized that this was a symbolic moment in my life . 83 | The kind stranger symbolized new hope for me and the North Korean people when we needed it most , and he showed me the kindness of strangers and the support of the international community are truly the rays of hope we North Korean people need . 84 | Eventually , after our long journey , my family and I were reunited in South Korea , but getting to freedom is only half the battle . 85 | Many North Koreans are separated from their families , and when they arrive in a new country , they start with little or no money . 86 | So we can benefit from the international community for education , English language training , job training , and more . 87 | We can also act as a bridge between the people inside North Korea and the outside world , because many of us stay in contact with family members still inside , and we send information and money that is helping to change North Korea from inside . 88 | I 've been so lucky , received so much help and inspiration in my life , so I want to help give aspiring North Koreans a chance to prosper with international support . 89 | I 'm confident that you will see more and more North Koreans succeeding all over the world , including the TED stage . 90 | Thank you . 91 | Today I have just one request . 92 | Please don 't tell me I 'm normal . 93 | Now I 'd like to introduce you to my brothers . 94 | Remi is 22 , tall and very handsome . 95 | He 's speechless , but he communicates joy in a way that some of the best orators cannot . 96 | Remi knows what love is . 97 | He shares it unconditionally and he shares it regardless . 98 | He 's not greedy . He doesn 't see skin color . 99 | He doesn 't care about religious differences , and get this : He has never told a lie . 100 | When he sings songs from our childhood , attempting words that not even I could remember , he reminds me of one thing : how little we know about the mind , and how wonderful the unknown must be . 101 | -------------------------------------------------------------------------------- /nmt/testdata/iwslt15.tst2013.100.vi: -------------------------------------------------------------------------------- 1 | Khi tôi còn nhỏ , Tôi nghĩ rằng BắcTriều Tiên là đất nước tốt nhất trên thế giới và tôi thường hát bài " Chúng ta chẳng có gì phải ghen tị . " 2 | Tôi đã rất tự hào về đất nước tôi . 3 | Ở trường , chúng tôi dành rất nhiều thời gian để học về cuộc đời của chủ tịch Kim II- Sung , nhưng lại không học nhiều về thế giới bên ngoài , ngoại trừ việc Hoa Kỳ , Hàn Quốc và Nhật Bản là kẻ thù của chúng tôi . 4 | Mặc dù tôi đã từng tự hỏi không biết thế giới bên ngoài kia như thế nào , nhưng tôi vẫn nghĩ rằng mình sẽ sống cả cuộc đời ở BắcTriều Tiên , cho tới khi tất cả mọi thứ đột nhiên thay đổi . 5 | Khi tôi lên 7 , tôi chứng kiến cảnh người ta xử bắn công khai lần đầu tiên trong đời , nhưng tôi vẫn nghĩ cuộc sống của mình ở đây là hoàn toàn bình thường . 6 | Gia đình của tôi không nghèo , và bản thân tôi thì chưa từng phải chịu đói . 7 | Nhưng vào một ngày của năm 1995 , mẹ tôi mang về nhà một lá thư từ một người chị em cùng chỗ làm với mẹ . 8 | Trong đó có viết : Khi chị đọc được những dòng này thì cả gia đình 5 người của em đã không còn trên cõi đời này nữa , bởi vì cả nhà em đã không có gì để ăn trong hai tuần . 9 | Tất cả cùng nằm trên sàn , và cơ thể chúng tôi yếu đến có thể cảm thấy như cái chết đang đến rất gần . 10 | Tôi đã bị sốc . 11 | Vì đó là lần đầu tiên tôi biết rằng đồng bào của tôi đang phải chịu đựng như vậy . 12 | Không lâu sau đó , khi tôi đi qua một nhà ga , tôi nhìn thấy một cảnh tượng kinh hoàng mà tôi không bao giờ có thể quên 13 | Trên nền nhà ga là xác chết của một người đàn bà hai tay vẫn đang ôm một đứa bé hốc hác và đứa bé chỉ biết nhìn chằm chằm vào khuôn mặt của mẹ nó . 14 | Nhưng không có ai giúp họ , bởi vì tất cả đều đang phải lo cho chính mình và cả gia đình . 15 | Vào giữa những năm 90 , Bắc Triều Tiên trải qua một nạn đói trầm trọng . 16 | Nó khiến hơn một triệu người Triều Tiên chết trong nạn đói , và nhiều người chỉ sống sót phải ăn cỏ , sâu bọ và vỏ cây . 17 | Việc cúp điện ngày càng xảy ra thường xuyên , vì thế mọi thứ xung quanh tôi đều chìm vào bóng tối khi đêm đến ngoại trừ ánh sáng đèn từ phía Trung Quốc chỉ cách nhà tôi một con sông . 18 | Tôi lúc nào cũng tự hỏi là tại sao họ lại có điện còn chúng tôi thì không . 19 | Đây là một bức ảnh từ vệ tinh chụp Bắc Triều Tiên vào ban đêm trong tương quan với các nước xung quanh . 20 | Đây là sông Áp Lục nó là biên giới tự nhiên giữa Bắc Triều Tiên và Trung Quốc . 21 | Có thể thấy là lòng sông có đoạn rất hẹp vì thế một số người Bắc Triều Tiên bí mật vượt sang Trung Quốc . 22 | Nhưng rất nhiều người đã chết . 23 | Và tôi đã nhìn thấy xác họ nổi trên sông . 24 | Tôi không thể nói cụ thể về việc mình đã trốn khỏi Bắc Triều Tiên như thế nào chỉ có thể nói rằng trong những năm tháng khốn khó vì nạn đói ấy tôi được gửi sang Trung Quốc để sống với một người họ hàng xa . 25 | Lúc đó , tôi chỉ nghĩ rằng mình sẽ phải xa gia đình một thời gian ngắn . 26 | chứ không bao giờ tôi có thể tưởng tượng rằng tôi sẽ phải xa họ những 14 năm ròng . 27 | Ở Trung Quốc , cuộc sống của một cô bé bị cách ly khỏi gia đình như tôi rất khó khăn . 28 | Tôi đã không tưởng được những gì xảy đến với cuộc sống của một người tị nạn từ Bắc Triều Tiên thì sẽ như thế nào , nhưng tôi sớm nhận ra rằng nó không những rất khó khăn , mà còn vô cùng nguy hiểm , vì những người tị nạn từ Bắc Triều Tiên vào Trung Quốc đều bị coi là dân nhập cư trái phép . 29 | Tôi luôn sống trong một nỗi sợ thường trực rằng danh tính của tôi sẽ bị phát hiện , và tôi sẽ bị trả về với cuộc sống cũ ở Bắc Triều Tiên . 30 | Một ngày , cơn ác mộng đó đã thành sự thật , tôi đã bị cảnh sát Trung Quốc bắt và đưa đến đồn cảnh sát để chất vấn . 31 | Có ai đó đã báo với họ rằng tôi là người Bắc Triều Tiên , vì thế họ đã kiểm tra khả năng tiếng Trung của tôi và hỏi tôi rất nhiều câu hỏi . 32 | Tôi đã vô cùng sợ hãi , và có cảm giác như tim mình sắp nổ tung . 33 | Vì nếu như họ thấy có điều gì không tự nhiên , tôi sẽ bị tống vào tù và rồi bị trả về nước . 34 | Tôi nghĩ cuộc đời mình đến đây là chấm dứt , nhưng tôi vẫn cố gắng điều khiển những cảm xúc của mình và trả lời những câu hỏi của họ . 35 | Sau khi hỏi xong , một trong hai cảnh sát nói với người kia , Đây là một vụ chỉ điểm sai . 36 | Nó không phải là người Bắc Triều Tiên . " 37 | Và họ thả tôi ra . Đó quả là một phép màu . 38 | Một số người Bắc Triều Tiên ở Trung Quốc đã đến những đại sứ quán của nước ngoài để xin tị nạn , nhưng rất nhiều trong số đó đã bị bắt bởi cảnh sát Trung Quốc và bị trả về nước . 39 | những cô gái này đã rất may mắn . 40 | vì mặc dù đã bị bắt , nhưng cuối cùng học cũng được thả ra nhờ vào sức ép từ cộng đồng quốc tế . 41 | Nhưng những người Bắc Triều Tiên này thì không được may mắn như vậy . 42 | Hàng năm , có vô số người Bắc Triều Tiên bị bắt ở Trung Quốc và bị trả về nước , nơi mà họ bị tra tấn , bị giam cầm hoặc bị xử tử công khai . 43 | Trong khi tôi rất may mắn vì đã được thả ra thì rất nhiều đồng bào của tôi lại không được như vậy . 44 | Việc người Bắc Triều Tiên phải che dấu danh tính của mình và đấu tranh để tồn tại quả là một bi kịch . 45 | Kể cả khi đã học tiếng Trung và tìm được một công việc , thì cuộc sống của họ cũng có thể bị đảo lộn hoàn toàn chỉ trong một khoảng khắc . 46 | Đó là lý do tại sao sau 10 năm che dấu danh tính thật tôi quyết định liều mình đi đến Hàn Quốc , để bắt đầu một cuộc sống mới một lần nữa . 47 | Việc ổn định cuộc sống ở đây khó khăn hơn nhiều so với tôi tưởng tượng . 48 | Vì ở Hàn Quốc , tiếng Anh có vị trí vô cùng quan trọng , nên tôi đã bắt đầu học tiếng Anh , ngôn ngữ thứ ba của tôi . 49 | Tôi cũng nhận ra một khoảng cách rất lớn giữa người Nam và Bắc Triều Tiên . 50 | Chúng tôi đều là người Triều Tiên , nhưng đã trở nên rất khác nhau do hậu quả của 67 năm bị chia cắt . 51 | Tôi đã trải qua một cuộc khủng hoảng về nguồn gốc của mình . 52 | Tôi là người Nam Triều Tiên hay Bắc Triều Tiên ? 53 | Tôi đến từ đâu ? và Tôi là ai ? 54 | Bỗng nhiên , tôi chẳng có một đất nước nào để có thể tự hào gọi là Tổ quốc . 55 | Mặc dù để thích ứng với cuộc sống ở Hàn Quốc thì không dễ chút nào Nhưng tôi đã lập một kế hoạch . 56 | và bắt đầu học để chuẩn bị cho kì thi đại học . 57 | Tuy vậy , ngay khi tôi vừa mới làm quen với cuộc sống ở đây , thì tôi được báo một tin khủng khiếp . 58 | Chính quyền Bắc Triều Tiên đã phát hiện ra số tiền mà tôi gửi về cho gia đình , và , để trừng phạt , họ sẽ bắt gia đình tôi phải chuyển về một vùng bị cách ly ở nông thôn . 59 | Để họ có thể nhanh chóng thoát ra khỏi đó , tôi bắt đầu lập kế hoạch giúp gia đình mình trốn thoát . 60 | Người Bắc Triều Tiên đã phải vượt qua những khoảng cách dường như không tưởng để đến với tự do . 61 | Bời vì việc vượt biên từ Bắc sang Nam Triều Tiên gần như là không thể , vì thế , tôi phải bay sang Trung Quốc rồi lại đi ngược về phía biên giới Bắc Triều Tiên . 62 | Bởi vì gia đình tôi không biết tiếng Trung , nên tôi phải đi cùng mọi người qua 2000 dặm ở Trung Quốc rồi vào đến Đông Nam Á 63 | Cuộc hành trình bằng xe buýt kéo dài khoảng 1 tuần , và đã vài lần chúng tôi suýt bị bắt . 64 | Một lần , xe của chúng tôi bị chặn lại và bị khám xét bởi một cảnh sát Trung quốc . 65 | Anh ta thu chứng minh thư của tất cả mọi người , và bắt đầu tra hỏi . 66 | Bởi vì gia đình tôi không hiểu tiếng Trung , nên tôi đã nghĩ rằng họ sẽ bị bắt . 67 | Khi người cảnh sát Trung Quốc đến gần họ , tôi đã ngay lập tức đứng dậy và nói với anh ta rằng đây là những người câm điếc mà tôi đang phải đi cùng . 68 | Anh ta nhìn tôi đầy nghi ngờ , nhưng thật may mắn là anh ta tin lời tôi . 69 | Và chúng tôi tiếp tục đi cho tới biên giới Lào , nhung tôi đã phải sử dụng gần như toàn bộ số tiền mà mình có để hối lộ cho những người canh gác biên giới Lào . 70 | Tuy nhiên , sau khi chúng tôi qua được biên giới , gia đình tôi lại bị bắt vào tù vì tội vượt biên trái phép . 71 | Sau khi tôi nộp tiền phạt và đưa hối lộ , gia đình tôi được thả ra trong một tháng , rồi lại bị bắt lần nữa ở thủ đô của Lào 72 | Đó là thời điểm tuyệt vọng nhất trong cuộc đời tôi . 73 | Tôi đã làm đủ mọi cách để đưa gia đình mình đến với tự do , và chũng tôi gần như đã thành công nhưng họ lại bị bắt trong khi chúng tôi chỉ còn cách đại sứ quán Hàn Quốc một khoảng cách rất ngắn nữa thôi . 74 | Tôi đi đi về về giữa phòng xuất nhập cảnh và đồn cảnh sát , tuyệt vọng tìm cách để đưa gia đình mình thoát khỏi đó , nhưng tôi không còn đủ tiền để hối lộ hay trả tiền phạt nữa . 75 | Tôi hoàn toàn tuyệt vọng . 76 | Đúng vào lúc đó , có một người đàn ông đã hỏi tôi , " Có chuyện gì vậy ? " 77 | Tôi vô cùng ngạc nhiên khi một người hoàn toàn xa lạ lại quan tâm tới mức hỏi tôi như vậy . 78 | Bằng vốn tiếng Anh ít ỏi của mình và một quyển từ điển , tôi đã kể cho cho ông ta nghe hoàn cảnh của gia đình tôi . Không một chút do dự , người đàn ông đó đã đi tới máy ATM. trả tất cả số tiền còn thiếu cho cả gia đình tôi và hai người Bắc Triều Tiên khác để họ được ra tù 79 | Tôi đã cám ơn ông ta bằng cả trái tim mình , và tôi cũng hỏi , " Tại sao ông lại giúp đỡ tôi ? " 80 | " Tôi không giúp đỡ cô , " Ông ta trả lời . 81 | " ' Tôi đang giúp người Bắc Triều Tiên . " 82 | Tôi nhận ra rằng đó là một khoảng khắc có ý nghĩa vô cùng to lớn trong cuộc đời tôi . 83 | Lòng tốt từ người đàn ông xa lạ trở thành biểu tượng hy vọng mới cho tôi và cả những người dân Bắc Triều Tiên khi mà chúng tôi đang rất cần nó , và ông ta đã cho tôi thấy lòng tốt từ những người xa lạ và sự hỗ trợ của cộng đồng quốc tế chính là những tia hy vọng mà người Bắc Triều Tiên chúng tôi đang tìm kiếm . 84 | Cuối cùng , sau một cuộc hành trình dài tôi và gia đình đã được đoàn tụ ở Hàn Quốc , nhưng đến được với tự do mới chỉ là một nửa của cuộc đấu tranh . 85 | Rất nhiều người Bắc Triều Tiên đang bị chia cắt với gia đình của họ , và khi họ đến được một đất nước khác , họ phải bắt đầu từ đầu với rất ít hoặc gần như không có tiền bạc . 86 | Vì thế chúng tôi có thể nhận sự trợ giúp từ cộng đồng quốc tế cho giáo dục , đào tạo tiếng Anh , dạy nghề , và nhiều lĩnh vực khác . 87 | Chúng tôi cũng có thể đóng vai trò như cầu nối giữa những người đang ở trong Bắc Triều Tiên với thế giới bên ngoài , bởi vì có rất nhiều người trong chúng tôi đang giữ liên lạc với những thành viên gia đình khác ở trong nước , và chúng tôi chia sẻ với họ thông tin và tiền bạc để có thể thay đổi Bắc Triều Tiên từ phía trong . 88 | Tôi đã vô cùng may mắn khi nhận được rất nhiều sự giúp đỡ và được truyền cảm hứng trong suốt cuộc đời mình , vì vậy tôi muốn mình cũng có thể chung sức để mang đến cho đất nước tôi một cơ hội để phát triển cùng với sự hỗ trợ của quốc tế . 89 | Tôi tin tưởng rằng các bạn sẽ nhìn thấy ngày càng nhiều người Bắc Triều Tiên thành công ở mọi nơi trên thế giới , kể cả trên sân khấu của TED 90 | Cám ơn các bạn . 91 | Hôm nay tôi chỉ có một yêu cầu mà thôi . 92 | Xin đừng nói với tôi rằng tôi bình thường . 93 | Bây giờ tôi muốn giới thiệu các bạn với những người em trai của tôi 94 | Remi 22 tuổi , cao ráo và rất đẹp trai , 95 | Em không nói được , nhưng em truyền đạt niềm vui theo cách mà ngay cả một số nhà hùng biện giỏi nhất cũng không thể làm được . 96 | Remi biết tình yêu là gì . 97 | Em chia sẻ nó một cách vô điều kiện dù bất kể ra sao chăng nữa . 98 | Em ấy không tham lam . Em không phân biệt màu da . 99 | Em không quan tâm về sự khác biệt tôn giáo , và hãy hiểu rằng : Em ấy chưa từng nói dối . 100 | Khi em hát những bài hát từ thời thơ ấu của chúng tôi , cố gắng nhớ những từ mà đến tôi cũng không thể , em ấy gợi nhớ cho tôi một điều rằng : chúng ta biết ít về bộ não đến như thế nào , và cái ta chưa biết phải tuyệt vời đến thế nào . 101 | -------------------------------------------------------------------------------- /nmt/testdata/iwslt15.vocab.100.en: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Rachel 5 | : 6 | The 7 | science 8 | behind 9 | a 10 | climate 11 | headline 12 | In 13 | 4 14 | minutes 15 | , 16 | atmospheric 17 | chemist 18 | provides 19 | glimpse 20 | of 21 | the 22 | massive 23 | scientific 24 | effort 25 | bold 26 | headlines 27 | on 28 | change 29 | with 30 | her 31 | team 32 | -- 33 | one 34 | thousands 35 | who 36 | contributed 37 | taking 38 | risky 39 | flight 40 | over 41 | rainforest 42 | in 43 | pursuit 44 | data 45 | key 46 | molecule 47 | . 48 | I 49 | 'd 50 | like 51 | to 52 | talk 53 | you 54 | today 55 | about 56 | scale 57 | that 58 | goes 59 | into 60 | making 61 | see 62 | paper 63 | look 64 | this 65 | when 66 | they 67 | have 68 | do 69 | and 70 | air 71 | quality 72 | or 73 | smog 74 | They 75 | are 76 | both 77 | two 78 | branches 79 | same 80 | field 81 | Recently 82 | looked 83 | Panel 84 | Climate 85 | Change 86 | IPCC 87 | put 88 | out 89 | their 90 | report 91 | state 92 | understanding 93 | system 94 | That 95 | was 96 | written 97 | by 98 | scientists 99 | from 100 | 40 101 | -------------------------------------------------------------------------------- /nmt/testdata/iwslt15.vocab.100.vi: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Khoa 5 | học 6 | đằng 7 | sau 8 | một 9 | tiêu 10 | đề 11 | về 12 | khí 13 | hậu 14 | Trong 15 | 4 16 | phút 17 | , 18 | chuyên 19 | gia 20 | hoá 21 | quyển 22 | Rachel 23 | giới 24 | thiệu 25 | sơ 26 | lược 27 | những 28 | nỗ 29 | lực 30 | khoa 31 | miệt 32 | mài 33 | táo 34 | bạo 35 | biến 36 | đổi 37 | cùng 38 | với 39 | đoàn 40 | nghiên 41 | cứu 42 | của 43 | mình 44 | -- 45 | hàng 46 | ngàn 47 | người 48 | đã 49 | cống 50 | hiến 51 | cho 52 | dự 53 | án 54 | này 55 | chuyến 56 | bay 57 | mạo 58 | hiểm 59 | qua 60 | rừng 61 | già 62 | để 63 | tìm 64 | kiếm 65 | thông 66 | tin 67 | phân 68 | tử 69 | then 70 | chốt 71 | . 72 | Tôi 73 | muốn 74 | các 75 | bạn 76 | biết 77 | sự 78 | to 79 | lớn 80 | góp 81 | phần 82 | làm 83 | nên 84 | dòng 85 | tít 86 | thường 87 | thấy 88 | trên 89 | báo 90 | Có 91 | trông 92 | như 93 | thế 94 | khi 95 | bàn 96 | và 97 | nói 98 | chất 99 | lượng 100 | không 101 | -------------------------------------------------------------------------------- /nmt/testdata/label_ref: -------------------------------------------------------------------------------- 1 | positive 2 | positive 3 | positive 4 | negative 5 | negative -------------------------------------------------------------------------------- /nmt/testdata/pred_output: -------------------------------------------------------------------------------- 1 | positive 2 | positive 3 | negative 4 | negative 5 | positive -------------------------------------------------------------------------------- /nmt/testdata/test_embed.txt: -------------------------------------------------------------------------------- 1 | some_word 1.0 2.0 3.0 4.0 2 | some_other_word 4.0 3.0 2.0 1.0 3 | -------------------------------------------------------------------------------- /nmt/testdata/test_embed_with_header.txt: -------------------------------------------------------------------------------- 1 | 2 4 2 | some_word 1.0 2.0 3.0 4.0 3 | some_other_word 4.0 3.0 2.0 1.0 4 | -------------------------------------------------------------------------------- /nmt/testdata/test_infer_file: -------------------------------------------------------------------------------- 1 | A Republic@@ an strategy to counter the re-@@ election of Obama 2 | Republic@@ an leaders justified their policy by the need to combat electoral fraud . 3 | However , the Brenn@@ an Centre considers this a my@@ th , stating that electoral fraud is rar@@ er in the United States than the number of people killed by ligh@@ tn@@ ing . 4 | Indeed , Republic@@ an lawyers identified only 300 cases of electoral fraud in the United States in a decade . 5 | One thing is certain : these new provisions will have a negative impact on vot@@ er tur@@ n-@@ out . -------------------------------------------------------------------------------- /nmt/testdata/test_infer_vocab.src: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | test1 5 | test2 6 | -------------------------------------------------------------------------------- /nmt/testdata/test_infer_vocab.tgt: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | test3 5 | test4 6 | -------------------------------------------------------------------------------- /nmt/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/nmt/0be864257a76c151eef20ea689755f08bc1faf4e/nmt/utils/__init__.py -------------------------------------------------------------------------------- /nmt/utils/common_test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Common utility functions for tests.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf 23 | 24 | from tensorflow.python.ops import lookup_ops 25 | 26 | from ..utils import iterator_utils 27 | from ..utils import standard_hparams_utils 28 | 29 | 30 | def create_test_hparams(unit_type="lstm", 31 | encoder_type="uni", 32 | num_layers=4, 33 | attention="", 34 | attention_architecture=None, 35 | use_residual=False, 36 | inference_indices=None, 37 | num_translations_per_input=1, 38 | beam_width=0, 39 | init_op="uniform"): 40 | """Create training and inference test hparams.""" 41 | num_residual_layers = 0 42 | if use_residual: 43 | # TODO(rzhao): Put num_residual_layers computation logic into 44 | # `model_utils.py`, so we can also test it here. 45 | num_residual_layers = 2 46 | 47 | standard_hparams = standard_hparams_utils.create_standard_hparams() 48 | 49 | # Networks 50 | standard_hparams.num_units = 5 51 | standard_hparams.num_encoder_layers = num_layers 52 | standard_hparams.num_decoder_layers = num_layers 53 | standard_hparams.dropout = 0.5 54 | standard_hparams.unit_type = unit_type 55 | standard_hparams.encoder_type = encoder_type 56 | standard_hparams.residual = use_residual 57 | standard_hparams.num_residual_layers = num_residual_layers 58 | 59 | # Attention mechanisms 60 | standard_hparams.attention = attention 61 | standard_hparams.attention_architecture = attention_architecture 62 | 63 | # Train 64 | standard_hparams.init_op = init_op 65 | standard_hparams.num_train_steps = 1 66 | standard_hparams.decay_scheme = "" 67 | 68 | # Infer 69 | standard_hparams.tgt_max_len_infer = 100 70 | standard_hparams.beam_width = beam_width 71 | standard_hparams.num_translations_per_input = num_translations_per_input 72 | 73 | # Misc 74 | standard_hparams.forget_bias = 0.0 75 | standard_hparams.random_seed = 3 76 | standard_hparams.language_model = False 77 | 78 | # Vocab 79 | standard_hparams.src_vocab_size = 5 80 | standard_hparams.tgt_vocab_size = 5 81 | standard_hparams.eos = "" 82 | standard_hparams.sos = "" 83 | standard_hparams.src_vocab_file = "" 84 | standard_hparams.tgt_vocab_file = "" 85 | standard_hparams.src_embed_file = "" 86 | standard_hparams.tgt_embed_file = "" 87 | 88 | # For inference.py test 89 | standard_hparams.subword_option = "bpe" 90 | standard_hparams.src = "src" 91 | standard_hparams.tgt = "tgt" 92 | standard_hparams.src_max_len = 400 93 | standard_hparams.tgt_eos_id = 0 94 | standard_hparams.inference_indices = inference_indices 95 | return standard_hparams 96 | 97 | 98 | def create_test_iterator(hparams, mode): 99 | """Create test iterator.""" 100 | src_vocab_table = lookup_ops.index_table_from_tensor( 101 | tf.constant([hparams.eos, "a", "b", "c", "d"])) 102 | tgt_vocab_mapping = tf.constant([hparams.sos, hparams.eos, "a", "b", "c"]) 103 | tgt_vocab_table = lookup_ops.index_table_from_tensor(tgt_vocab_mapping) 104 | if mode == tf.contrib.learn.ModeKeys.INFER: 105 | reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_tensor( 106 | tgt_vocab_mapping) 107 | 108 | src_dataset = tf.data.Dataset.from_tensor_slices( 109 | tf.constant(["a a b b c", "a b b"])) 110 | 111 | if mode != tf.contrib.learn.ModeKeys.INFER: 112 | tgt_dataset = tf.data.Dataset.from_tensor_slices( 113 | tf.constant(["a b c b c", "a b c b"])) 114 | return ( 115 | iterator_utils.get_iterator( 116 | src_dataset=src_dataset, 117 | tgt_dataset=tgt_dataset, 118 | src_vocab_table=src_vocab_table, 119 | tgt_vocab_table=tgt_vocab_table, 120 | batch_size=hparams.batch_size, 121 | sos=hparams.sos, 122 | eos=hparams.eos, 123 | random_seed=hparams.random_seed, 124 | num_buckets=hparams.num_buckets), 125 | src_vocab_table, 126 | tgt_vocab_table) 127 | else: 128 | return ( 129 | iterator_utils.get_infer_iterator( 130 | src_dataset=src_dataset, 131 | src_vocab_table=src_vocab_table, 132 | eos=hparams.eos, 133 | batch_size=hparams.batch_size), 134 | src_vocab_table, 135 | tgt_vocab_table, 136 | reverse_tgt_vocab_table) 137 | -------------------------------------------------------------------------------- /nmt/utils/evaluation_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utility for evaluating various tasks, e.g., translation & summarization.""" 17 | import codecs 18 | import os 19 | import re 20 | import subprocess 21 | 22 | import tensorflow as tf 23 | 24 | from ..scripts import bleu 25 | from ..scripts import rouge 26 | 27 | 28 | __all__ = ["evaluate"] 29 | 30 | 31 | def evaluate(ref_file, trans_file, metric, subword_option=None): 32 | """Pick a metric and evaluate depending on task.""" 33 | # BLEU scores for translation task 34 | if metric.lower() == "bleu": 35 | evaluation_score = _bleu(ref_file, trans_file, 36 | subword_option=subword_option) 37 | # ROUGE scores for summarization tasks 38 | elif metric.lower() == "rouge": 39 | evaluation_score = _rouge(ref_file, trans_file, 40 | subword_option=subword_option) 41 | elif metric.lower() == "accuracy": 42 | evaluation_score = _accuracy(ref_file, trans_file) 43 | elif metric.lower() == "word_accuracy": 44 | evaluation_score = _word_accuracy(ref_file, trans_file) 45 | else: 46 | raise ValueError("Unknown metric %s" % metric) 47 | 48 | return evaluation_score 49 | 50 | 51 | def _clean(sentence, subword_option): 52 | """Clean and handle BPE or SPM outputs.""" 53 | sentence = sentence.strip() 54 | 55 | # BPE 56 | if subword_option == "bpe": 57 | sentence = re.sub("@@ ", "", sentence) 58 | 59 | # SPM 60 | elif subword_option == "spm": 61 | sentence = u"".join(sentence.split()).replace(u"\u2581", u" ").lstrip() 62 | 63 | return sentence 64 | 65 | 66 | # Follow //transconsole/localization/machine_translation/metrics/bleu_calc.py 67 | def _bleu(ref_file, trans_file, subword_option=None): 68 | """Compute BLEU scores and handling BPE.""" 69 | max_order = 4 70 | smooth = False 71 | 72 | ref_files = [ref_file] 73 | reference_text = [] 74 | for reference_filename in ref_files: 75 | with codecs.getreader("utf-8")( 76 | tf.gfile.GFile(reference_filename, "rb")) as fh: 77 | reference_text.append(fh.readlines()) 78 | 79 | per_segment_references = [] 80 | for references in zip(*reference_text): 81 | reference_list = [] 82 | for reference in references: 83 | reference = _clean(reference, subword_option) 84 | reference_list.append(reference.split(" ")) 85 | per_segment_references.append(reference_list) 86 | 87 | translations = [] 88 | with codecs.getreader("utf-8")(tf.gfile.GFile(trans_file, "rb")) as fh: 89 | for line in fh: 90 | line = _clean(line, subword_option=None) 91 | translations.append(line.split(" ")) 92 | 93 | # bleu_score, precisions, bp, ratio, translation_length, reference_length 94 | bleu_score, _, _, _, _, _ = bleu.compute_bleu( 95 | per_segment_references, translations, max_order, smooth) 96 | return 100 * bleu_score 97 | 98 | 99 | def _rouge(ref_file, summarization_file, subword_option=None): 100 | """Compute ROUGE scores and handling BPE.""" 101 | 102 | references = [] 103 | with codecs.getreader("utf-8")(tf.gfile.GFile(ref_file, "rb")) as fh: 104 | for line in fh: 105 | references.append(_clean(line, subword_option)) 106 | 107 | hypotheses = [] 108 | with codecs.getreader("utf-8")( 109 | tf.gfile.GFile(summarization_file, "rb")) as fh: 110 | for line in fh: 111 | hypotheses.append(_clean(line, subword_option=None)) 112 | 113 | rouge_score_map = rouge.rouge(hypotheses, references) 114 | return 100 * rouge_score_map["rouge_l/f_score"] 115 | 116 | 117 | def _accuracy(label_file, pred_file): 118 | """Compute accuracy, each line contains a label.""" 119 | 120 | with codecs.getreader("utf-8")(tf.gfile.GFile(label_file, "rb")) as label_fh: 121 | with codecs.getreader("utf-8")(tf.gfile.GFile(pred_file, "rb")) as pred_fh: 122 | count = 0.0 123 | match = 0.0 124 | for label in label_fh: 125 | label = label.strip() 126 | pred = pred_fh.readline().strip() 127 | if label == pred: 128 | match += 1 129 | count += 1 130 | return 100 * match / count 131 | 132 | 133 | def _word_accuracy(label_file, pred_file): 134 | """Compute accuracy on per word basis.""" 135 | 136 | with codecs.getreader("utf-8")(tf.gfile.GFile(label_file, "rb")) as label_fh: 137 | with codecs.getreader("utf-8")(tf.gfile.GFile(pred_file, "rb")) as pred_fh: 138 | total_acc, total_count = 0., 0. 139 | for sentence in label_fh: 140 | labels = sentence.strip().split(" ") 141 | preds = pred_fh.readline().strip().split(" ") 142 | match = 0.0 143 | for pos in range(min(len(labels), len(preds))): 144 | label = labels[pos] 145 | pred = preds[pos] 146 | if label == pred: 147 | match += 1 148 | total_acc += 100 * match / max(len(labels), len(preds)) 149 | total_count += 1 150 | return total_acc / total_count 151 | 152 | 153 | def _moses_bleu(multi_bleu_script, tgt_test, trans_file, subword_option=None): 154 | """Compute BLEU scores using Moses multi-bleu.perl script.""" 155 | 156 | # TODO(thangluong): perform rewrite using python 157 | # BPE 158 | if subword_option == "bpe": 159 | debpe_tgt_test = tgt_test + ".debpe" 160 | if not os.path.exists(debpe_tgt_test): 161 | # TODO(thangluong): not use shell=True, can be a security hazard 162 | subprocess.call("cp %s %s" % (tgt_test, debpe_tgt_test), shell=True) 163 | subprocess.call("sed s/@@ //g %s" % (debpe_tgt_test), 164 | shell=True) 165 | tgt_test = debpe_tgt_test 166 | elif subword_option == "spm": 167 | despm_tgt_test = tgt_test + ".despm" 168 | if not os.path.exists(despm_tgt_test): 169 | subprocess.call("cp %s %s" % (tgt_test, despm_tgt_test)) 170 | subprocess.call("sed s/ //g %s" % (despm_tgt_test)) 171 | subprocess.call(u"sed s/^\u2581/g %s" % (despm_tgt_test)) 172 | subprocess.call(u"sed s/\u2581/ /g %s" % (despm_tgt_test)) 173 | tgt_test = despm_tgt_test 174 | cmd = "%s %s < %s" % (multi_bleu_script, tgt_test, trans_file) 175 | 176 | # subprocess 177 | # TODO(thangluong): not use shell=True, can be a security hazard 178 | bleu_output = subprocess.check_output(cmd, shell=True) 179 | 180 | # extract BLEU score 181 | m = re.search("BLEU = (.+?),", bleu_output) 182 | bleu_score = float(m.group(1)) 183 | 184 | return bleu_score 185 | -------------------------------------------------------------------------------- /nmt/utils/evaluation_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for evaluation_utils.py.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf 23 | 24 | from ..utils import evaluation_utils 25 | 26 | 27 | class EvaluationUtilsTest(tf.test.TestCase): 28 | 29 | def testEvaluate(self): 30 | output = "nmt/testdata/deen_output" 31 | ref_bpe = "nmt/testdata/deen_ref_bpe" 32 | ref_spm = "nmt/testdata/deen_ref_spm" 33 | 34 | expected_bleu_score = 22.5855084573 35 | expected_rouge_score = 50.8429782599 36 | 37 | bpe_bleu_score = evaluation_utils.evaluate( 38 | ref_bpe, output, "bleu", "bpe") 39 | bpe_rouge_score = evaluation_utils.evaluate( 40 | ref_bpe, output, "rouge", "bpe") 41 | 42 | self.assertAlmostEqual(expected_bleu_score, bpe_bleu_score) 43 | self.assertAlmostEqual(expected_rouge_score, bpe_rouge_score) 44 | 45 | spm_bleu_score = evaluation_utils.evaluate( 46 | ref_spm, output, "bleu", "spm") 47 | spm_rouge_score = evaluation_utils.evaluate( 48 | ref_spm, output, "rouge", "spm") 49 | 50 | self.assertAlmostEqual(expected_rouge_score, spm_rouge_score) 51 | self.assertAlmostEqual(expected_bleu_score, spm_bleu_score) 52 | 53 | def testAccuracy(self): 54 | pred_output = "nmt/testdata/pred_output" 55 | label_ref = "nmt/testdata/label_ref" 56 | 57 | expected_accuracy_score = 60.00 58 | 59 | accuracy_score = evaluation_utils.evaluate( 60 | label_ref, pred_output, "accuracy") 61 | self.assertAlmostEqual(expected_accuracy_score, accuracy_score) 62 | 63 | def testWordAccuracy(self): 64 | pred_output = "nmt/testdata/pred_output" 65 | label_ref = "nmt/testdata/label_ref" 66 | 67 | expected_word_accuracy_score = 60.00 68 | 69 | word_accuracy_score = evaluation_utils.evaluate( 70 | label_ref, pred_output, "word_accuracy") 71 | self.assertAlmostEqual(expected_word_accuracy_score, word_accuracy_score) 72 | 73 | 74 | if __name__ == "__main__": 75 | tf.test.main() 76 | -------------------------------------------------------------------------------- /nmt/utils/iterator_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """For loading data into NMT models.""" 16 | from __future__ import print_function 17 | 18 | import collections 19 | 20 | import tensorflow as tf 21 | 22 | from ..utils import vocab_utils 23 | 24 | 25 | __all__ = ["BatchedInput", "get_iterator", "get_infer_iterator"] 26 | 27 | 28 | # NOTE(ebrevdo): When we subclass this, instances' __dict__ becomes empty. 29 | class BatchedInput( 30 | collections.namedtuple("BatchedInput", 31 | ("initializer", "source", "target_input", 32 | "target_output", "source_sequence_length", 33 | "target_sequence_length"))): 34 | pass 35 | 36 | 37 | def get_infer_iterator(src_dataset, 38 | src_vocab_table, 39 | batch_size, 40 | eos, 41 | src_max_len=None, 42 | use_char_encode=False): 43 | if use_char_encode: 44 | src_eos_id = vocab_utils.EOS_CHAR_ID 45 | else: 46 | src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32) 47 | src_dataset = src_dataset.map(lambda src: tf.string_split([src]).values) 48 | 49 | if src_max_len: 50 | src_dataset = src_dataset.map(lambda src: src[:src_max_len]) 51 | 52 | if use_char_encode: 53 | # Convert the word strings to character ids 54 | src_dataset = src_dataset.map( 55 | lambda src: tf.reshape(vocab_utils.tokens_to_bytes(src), [-1])) 56 | else: 57 | # Convert the word strings to ids 58 | src_dataset = src_dataset.map( 59 | lambda src: tf.cast(src_vocab_table.lookup(src), tf.int32)) 60 | 61 | # Add in the word counts. 62 | if use_char_encode: 63 | src_dataset = src_dataset.map( 64 | lambda src: (src, 65 | tf.to_int32( 66 | tf.size(src) / vocab_utils.DEFAULT_CHAR_MAXLEN))) 67 | else: 68 | src_dataset = src_dataset.map(lambda src: (src, tf.size(src))) 69 | 70 | def batching_func(x): 71 | return x.padded_batch( 72 | batch_size, 73 | # The entry is the source line rows; 74 | # this has unknown-length vectors. The last entry is 75 | # the source row size; this is a scalar. 76 | padded_shapes=( 77 | tf.TensorShape([None]), # src 78 | tf.TensorShape([])), # src_len 79 | # Pad the source sequences with eos tokens. 80 | # (Though notice we don't generally need to do this since 81 | # later on we will be masking out calculations past the true sequence. 82 | padding_values=( 83 | src_eos_id, # src 84 | 0)) # src_len -- unused 85 | 86 | batched_dataset = batching_func(src_dataset) 87 | batched_iter = batched_dataset.make_initializable_iterator() 88 | (src_ids, src_seq_len) = batched_iter.get_next() 89 | return BatchedInput( 90 | initializer=batched_iter.initializer, 91 | source=src_ids, 92 | target_input=None, 93 | target_output=None, 94 | source_sequence_length=src_seq_len, 95 | target_sequence_length=None) 96 | 97 | 98 | def get_iterator(src_dataset, 99 | tgt_dataset, 100 | src_vocab_table, 101 | tgt_vocab_table, 102 | batch_size, 103 | sos, 104 | eos, 105 | random_seed, 106 | num_buckets, 107 | src_max_len=None, 108 | tgt_max_len=None, 109 | num_parallel_calls=4, 110 | output_buffer_size=None, 111 | skip_count=None, 112 | num_shards=1, 113 | shard_index=0, 114 | reshuffle_each_iteration=True, 115 | use_char_encode=False): 116 | if not output_buffer_size: 117 | output_buffer_size = batch_size * 1000 118 | 119 | if use_char_encode: 120 | src_eos_id = vocab_utils.EOS_CHAR_ID 121 | else: 122 | src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32) 123 | 124 | tgt_sos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(sos)), tf.int32) 125 | tgt_eos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(eos)), tf.int32) 126 | 127 | src_tgt_dataset = tf.data.Dataset.zip((src_dataset, tgt_dataset)) 128 | 129 | src_tgt_dataset = src_tgt_dataset.shard(num_shards, shard_index) 130 | if skip_count is not None: 131 | src_tgt_dataset = src_tgt_dataset.skip(skip_count) 132 | 133 | src_tgt_dataset = src_tgt_dataset.shuffle( 134 | output_buffer_size, random_seed, reshuffle_each_iteration) 135 | 136 | src_tgt_dataset = src_tgt_dataset.map( 137 | lambda src, tgt: ( 138 | tf.string_split([src]).values, tf.string_split([tgt]).values), 139 | num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) 140 | 141 | # Filter zero length input sequences. 142 | src_tgt_dataset = src_tgt_dataset.filter( 143 | lambda src, tgt: tf.logical_and(tf.size(src) > 0, tf.size(tgt) > 0)) 144 | 145 | if src_max_len: 146 | src_tgt_dataset = src_tgt_dataset.map( 147 | lambda src, tgt: (src[:src_max_len], tgt), 148 | num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) 149 | if tgt_max_len: 150 | src_tgt_dataset = src_tgt_dataset.map( 151 | lambda src, tgt: (src, tgt[:tgt_max_len]), 152 | num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) 153 | 154 | # Convert the word strings to ids. Word strings that are not in the 155 | # vocab get the lookup table's default_value integer. 156 | if use_char_encode: 157 | src_tgt_dataset = src_tgt_dataset.map( 158 | lambda src, tgt: (tf.reshape(vocab_utils.tokens_to_bytes(src), [-1]), 159 | tf.cast(tgt_vocab_table.lookup(tgt), tf.int32)), 160 | num_parallel_calls=num_parallel_calls) 161 | else: 162 | src_tgt_dataset = src_tgt_dataset.map( 163 | lambda src, tgt: (tf.cast(src_vocab_table.lookup(src), tf.int32), 164 | tf.cast(tgt_vocab_table.lookup(tgt), tf.int32)), 165 | num_parallel_calls=num_parallel_calls) 166 | 167 | src_tgt_dataset = src_tgt_dataset.prefetch(output_buffer_size) 168 | # Create a tgt_input prefixed with and a tgt_output suffixed with . 169 | src_tgt_dataset = src_tgt_dataset.map( 170 | lambda src, tgt: (src, 171 | tf.concat(([tgt_sos_id], tgt), 0), 172 | tf.concat((tgt, [tgt_eos_id]), 0)), 173 | num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) 174 | # Add in sequence lengths. 175 | if use_char_encode: 176 | src_tgt_dataset = src_tgt_dataset.map( 177 | lambda src, tgt_in, tgt_out: ( 178 | src, tgt_in, tgt_out, 179 | tf.to_int32(tf.size(src) / vocab_utils.DEFAULT_CHAR_MAXLEN), 180 | tf.size(tgt_in)), 181 | num_parallel_calls=num_parallel_calls) 182 | else: 183 | src_tgt_dataset = src_tgt_dataset.map( 184 | lambda src, tgt_in, tgt_out: ( 185 | src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_in)), 186 | num_parallel_calls=num_parallel_calls) 187 | 188 | src_tgt_dataset = src_tgt_dataset.prefetch(output_buffer_size) 189 | 190 | # Bucket by source sequence length (buckets for lengths 0-9, 10-19, ...) 191 | def batching_func(x): 192 | return x.padded_batch( 193 | batch_size, 194 | # The first three entries are the source and target line rows; 195 | # these have unknown-length vectors. The last two entries are 196 | # the source and target row sizes; these are scalars. 197 | padded_shapes=( 198 | tf.TensorShape([None]), # src 199 | tf.TensorShape([None]), # tgt_input 200 | tf.TensorShape([None]), # tgt_output 201 | tf.TensorShape([]), # src_len 202 | tf.TensorShape([])), # tgt_len 203 | # Pad the source and target sequences with eos tokens. 204 | # (Though notice we don't generally need to do this since 205 | # later on we will be masking out calculations past the true sequence. 206 | padding_values=( 207 | src_eos_id, # src 208 | tgt_eos_id, # tgt_input 209 | tgt_eos_id, # tgt_output 210 | 0, # src_len -- unused 211 | 0)) # tgt_len -- unused 212 | 213 | if num_buckets > 1: 214 | 215 | def key_func(unused_1, unused_2, unused_3, src_len, tgt_len): 216 | # Calculate bucket_width by maximum source sequence length. 217 | # Pairs with length [0, bucket_width) go to bucket 0, length 218 | # [bucket_width, 2 * bucket_width) go to bucket 1, etc. Pairs with length 219 | # over ((num_bucket-1) * bucket_width) words all go into the last bucket. 220 | if src_max_len: 221 | bucket_width = (src_max_len + num_buckets - 1) // num_buckets 222 | else: 223 | bucket_width = 10 224 | 225 | # Bucket sentence pairs by the length of their source sentence and target 226 | # sentence. 227 | bucket_id = tf.maximum(src_len // bucket_width, tgt_len // bucket_width) 228 | return tf.to_int64(tf.minimum(num_buckets, bucket_id)) 229 | 230 | def reduce_func(unused_key, windowed_data): 231 | return batching_func(windowed_data) 232 | 233 | batched_dataset = src_tgt_dataset.apply( 234 | tf.contrib.data.group_by_window( 235 | key_func=key_func, reduce_func=reduce_func, window_size=batch_size)) 236 | 237 | else: 238 | batched_dataset = batching_func(src_tgt_dataset) 239 | batched_iter = batched_dataset.make_initializable_iterator() 240 | (src_ids, tgt_input_ids, tgt_output_ids, src_seq_len, 241 | tgt_seq_len) = (batched_iter.get_next()) 242 | return BatchedInput( 243 | initializer=batched_iter.initializer, 244 | source=src_ids, 245 | target_input=tgt_input_ids, 246 | target_output=tgt_output_ids, 247 | source_sequence_length=src_seq_len, 248 | target_sequence_length=tgt_seq_len) 249 | -------------------------------------------------------------------------------- /nmt/utils/iterator_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for iterator_utils.py""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf 23 | 24 | from tensorflow.python.ops import lookup_ops 25 | 26 | from ..utils import iterator_utils 27 | 28 | 29 | class IteratorUtilsTest(tf.test.TestCase): 30 | 31 | def testGetIterator(self): 32 | tf.set_random_seed(1) 33 | tgt_vocab_table = src_vocab_table = lookup_ops.index_table_from_tensor( 34 | tf.constant(["a", "b", "c", "eos", "sos"])) 35 | src_dataset = tf.data.Dataset.from_tensor_slices( 36 | tf.constant(["f e a g", "c c a", "d", "c a"])) 37 | tgt_dataset = tf.data.Dataset.from_tensor_slices( 38 | tf.constant(["c c", "a b", "", "b c"])) 39 | hparams = tf.contrib.training.HParams( 40 | random_seed=3, 41 | num_buckets=5, 42 | eos="eos", 43 | sos="sos") 44 | batch_size = 2 45 | src_max_len = 3 46 | iterator = iterator_utils.get_iterator( 47 | src_dataset=src_dataset, 48 | tgt_dataset=tgt_dataset, 49 | src_vocab_table=src_vocab_table, 50 | tgt_vocab_table=tgt_vocab_table, 51 | batch_size=batch_size, 52 | sos=hparams.sos, 53 | eos=hparams.eos, 54 | random_seed=hparams.random_seed, 55 | num_buckets=hparams.num_buckets, 56 | src_max_len=src_max_len, 57 | reshuffle_each_iteration=False) 58 | table_initializer = tf.tables_initializer() 59 | source = iterator.source 60 | target_input = iterator.target_input 61 | target_output = iterator.target_output 62 | src_seq_len = iterator.source_sequence_length 63 | tgt_seq_len = iterator.target_sequence_length 64 | self.assertEqual([None, None], source.shape.as_list()) 65 | self.assertEqual([None, None], target_input.shape.as_list()) 66 | self.assertEqual([None, None], target_output.shape.as_list()) 67 | self.assertEqual([None], src_seq_len.shape.as_list()) 68 | self.assertEqual([None], tgt_seq_len.shape.as_list()) 69 | with self.test_session() as sess: 70 | sess.run(table_initializer) 71 | sess.run(iterator.initializer) 72 | 73 | (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = ( 74 | sess.run((source, src_seq_len, target_input, target_output, 75 | tgt_seq_len))) 76 | self.assertAllEqual( 77 | [[2, 0, 3], # c a eos -- eos is padding 78 | [-1, -1, 0]], # "f" == unknown, "e" == unknown, a 79 | source_v) 80 | self.assertAllEqual([2, 3], src_len_v) 81 | self.assertAllEqual( 82 | [[4, 1, 2], # sos b c 83 | [4, 2, 2]], # sos c c 84 | target_input_v) 85 | self.assertAllEqual( 86 | [[1, 2, 3], # b c eos 87 | [2, 2, 3]], # c c eos 88 | target_output_v) 89 | self.assertAllEqual([3, 3], tgt_len_v) 90 | 91 | (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = ( 92 | sess.run((source, src_seq_len, target_input, target_output, 93 | tgt_seq_len))) 94 | self.assertAllEqual( 95 | [[2, 2, 0]], # c c a 96 | source_v) 97 | self.assertAllEqual([3], src_len_v) 98 | self.assertAllEqual( 99 | [[4, 0, 1]], # sos a b 100 | target_input_v) 101 | self.assertAllEqual( 102 | [[0, 1, 3]], # a b eos 103 | target_output_v) 104 | self.assertAllEqual([3], tgt_len_v) 105 | 106 | with self.assertRaisesOpError("End of sequence"): 107 | sess.run(source) 108 | 109 | def testGetIteratorWithShard(self): 110 | tf.set_random_seed(1) 111 | tgt_vocab_table = src_vocab_table = lookup_ops.index_table_from_tensor( 112 | tf.constant(["a", "b", "c", "eos", "sos"])) 113 | src_dataset = tf.data.Dataset.from_tensor_slices( 114 | tf.constant(["c c a", "f e a g", "d", "c a"])) 115 | tgt_dataset = tf.data.Dataset.from_tensor_slices( 116 | tf.constant(["a b", "c c", "", "b c"])) 117 | hparams = tf.contrib.training.HParams( 118 | random_seed=3, 119 | num_buckets=5, 120 | eos="eos", 121 | sos="sos") 122 | batch_size = 2 123 | src_max_len = 3 124 | iterator = iterator_utils.get_iterator( 125 | src_dataset=src_dataset, 126 | tgt_dataset=tgt_dataset, 127 | src_vocab_table=src_vocab_table, 128 | tgt_vocab_table=tgt_vocab_table, 129 | batch_size=batch_size, 130 | sos=hparams.sos, 131 | eos=hparams.eos, 132 | random_seed=hparams.random_seed, 133 | num_buckets=hparams.num_buckets, 134 | src_max_len=src_max_len, 135 | num_shards=2, 136 | shard_index=1, 137 | reshuffle_each_iteration=False) 138 | table_initializer = tf.tables_initializer() 139 | source = iterator.source 140 | target_input = iterator.target_input 141 | target_output = iterator.target_output 142 | src_seq_len = iterator.source_sequence_length 143 | tgt_seq_len = iterator.target_sequence_length 144 | self.assertEqual([None, None], source.shape.as_list()) 145 | self.assertEqual([None, None], target_input.shape.as_list()) 146 | self.assertEqual([None, None], target_output.shape.as_list()) 147 | self.assertEqual([None], src_seq_len.shape.as_list()) 148 | self.assertEqual([None], tgt_seq_len.shape.as_list()) 149 | with self.test_session() as sess: 150 | sess.run(table_initializer) 151 | sess.run(iterator.initializer) 152 | 153 | (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = ( 154 | sess.run((source, src_seq_len, target_input, target_output, 155 | tgt_seq_len))) 156 | self.assertAllEqual( 157 | [[2, 0, 3], # c a eos -- eos is padding 158 | [-1, -1, 0]], # "f" == unknown, "e" == unknown, a 159 | source_v) 160 | self.assertAllEqual([2, 3], src_len_v) 161 | self.assertAllEqual( 162 | [[4, 1, 2], # sos b c 163 | [4, 2, 2]], # sos c c 164 | target_input_v) 165 | self.assertAllEqual( 166 | [[1, 2, 3], # b c eos 167 | [2, 2, 3]], # c c eos 168 | target_output_v) 169 | self.assertAllEqual([3, 3], tgt_len_v) 170 | 171 | with self.assertRaisesOpError("End of sequence"): 172 | sess.run(source) 173 | 174 | def testGetIteratorWithSkipCount(self): 175 | tf.set_random_seed(1) 176 | tgt_vocab_table = src_vocab_table = lookup_ops.index_table_from_tensor( 177 | tf.constant(["a", "b", "c", "eos", "sos"])) 178 | src_dataset = tf.data.Dataset.from_tensor_slices( 179 | tf.constant(["c a", "c c a", "d", "f e a g"])) 180 | tgt_dataset = tf.data.Dataset.from_tensor_slices( 181 | tf.constant(["b c", "a b", "", "c c"])) 182 | hparams = tf.contrib.training.HParams( 183 | random_seed=3, 184 | num_buckets=5, 185 | eos="eos", 186 | sos="sos") 187 | batch_size = 2 188 | src_max_len = 3 189 | skip_count = tf.placeholder(shape=(), dtype=tf.int64) 190 | iterator = iterator_utils.get_iterator( 191 | src_dataset=src_dataset, 192 | tgt_dataset=tgt_dataset, 193 | src_vocab_table=src_vocab_table, 194 | tgt_vocab_table=tgt_vocab_table, 195 | batch_size=batch_size, 196 | sos=hparams.sos, 197 | eos=hparams.eos, 198 | random_seed=hparams.random_seed, 199 | num_buckets=hparams.num_buckets, 200 | src_max_len=src_max_len, 201 | skip_count=skip_count, 202 | reshuffle_each_iteration=False) 203 | table_initializer = tf.tables_initializer() 204 | source = iterator.source 205 | target_input = iterator.target_input 206 | target_output = iterator.target_output 207 | src_seq_len = iterator.source_sequence_length 208 | tgt_seq_len = iterator.target_sequence_length 209 | self.assertEqual([None, None], source.shape.as_list()) 210 | self.assertEqual([None, None], target_input.shape.as_list()) 211 | self.assertEqual([None, None], target_output.shape.as_list()) 212 | self.assertEqual([None], src_seq_len.shape.as_list()) 213 | self.assertEqual([None], tgt_seq_len.shape.as_list()) 214 | with self.test_session() as sess: 215 | sess.run(table_initializer) 216 | sess.run(iterator.initializer, feed_dict={skip_count: 3}) 217 | 218 | (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = ( 219 | sess.run((source, src_seq_len, target_input, target_output, 220 | tgt_seq_len))) 221 | self.assertAllEqual( 222 | [[-1, -1, 0]], # "f" == unknown, "e" == unknown, a 223 | source_v) 224 | self.assertAllEqual([3], src_len_v) 225 | self.assertAllEqual( 226 | [[4, 2, 2]], # sos c c 227 | target_input_v) 228 | self.assertAllEqual( 229 | [[2, 2, 3]], # c c eos 230 | target_output_v) 231 | self.assertAllEqual([3], tgt_len_v) 232 | 233 | with self.assertRaisesOpError("End of sequence"): 234 | sess.run(source) 235 | 236 | # Re-init iterator with skip_count=0. 237 | sess.run(iterator.initializer, feed_dict={skip_count: 0}) 238 | 239 | (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = ( 240 | sess.run((source, src_seq_len, target_input, target_output, 241 | tgt_seq_len))) 242 | self.assertAllEqual( 243 | [[-1, -1, 0], # "f" == unknown, "e" == unknown, a 244 | [2, 0, 3]], # c a eos -- eos is padding 245 | source_v) 246 | self.assertAllEqual([3, 2], src_len_v) 247 | self.assertAllEqual( 248 | [[4, 2, 2], # sos c c 249 | [4, 1, 2]], # sos b c 250 | target_input_v) 251 | self.assertAllEqual( 252 | [[2, 2, 3], # c c eos 253 | [1, 2, 3]], # b c eos 254 | target_output_v) 255 | self.assertAllEqual([3, 3], tgt_len_v) 256 | 257 | (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = ( 258 | sess.run((source, src_seq_len, target_input, target_output, 259 | tgt_seq_len))) 260 | self.assertAllEqual( 261 | [[2, 2, 0]], # c c a 262 | source_v) 263 | self.assertAllEqual([3], src_len_v) 264 | self.assertAllEqual( 265 | [[4, 0, 1]], # sos a b 266 | target_input_v) 267 | self.assertAllEqual( 268 | [[0, 1, 3]], # a b eos 269 | target_output_v) 270 | self.assertAllEqual([3], tgt_len_v) 271 | 272 | with self.assertRaisesOpError("End of sequence"): 273 | sess.run(source) 274 | 275 | 276 | def testGetInferIterator(self): 277 | src_vocab_table = lookup_ops.index_table_from_tensor( 278 | tf.constant(["a", "b", "c", "eos", "sos"])) 279 | src_dataset = tf.data.Dataset.from_tensor_slices( 280 | tf.constant(["c c a", "c a", "d", "f e a g"])) 281 | hparams = tf.contrib.training.HParams( 282 | random_seed=3, 283 | eos="eos", 284 | sos="sos") 285 | batch_size = 2 286 | src_max_len = 3 287 | iterator = iterator_utils.get_infer_iterator( 288 | src_dataset=src_dataset, 289 | src_vocab_table=src_vocab_table, 290 | batch_size=batch_size, 291 | eos=hparams.eos, 292 | src_max_len=src_max_len) 293 | table_initializer = tf.tables_initializer() 294 | source = iterator.source 295 | seq_len = iterator.source_sequence_length 296 | self.assertEqual([None, None], source.shape.as_list()) 297 | self.assertEqual([None], seq_len.shape.as_list()) 298 | with self.test_session() as sess: 299 | sess.run(table_initializer) 300 | sess.run(iterator.initializer) 301 | 302 | (source_v, seq_len_v) = sess.run((source, seq_len)) 303 | self.assertAllEqual( 304 | [[2, 2, 0], # c c a 305 | [2, 0, 3]], # c a eos 306 | source_v) 307 | self.assertAllEqual([3, 2], seq_len_v) 308 | 309 | (source_v, seq_len_v) = sess.run((source, seq_len)) 310 | self.assertAllEqual( 311 | [[-1, 3, 3], # "d" == unknown, eos eos 312 | [-1, -1, 0]], # "f" == unknown, "e" == unknown, a 313 | source_v) 314 | self.assertAllEqual([1, 3], seq_len_v) 315 | 316 | with self.assertRaisesOpError("End of sequence"): 317 | sess.run((source, seq_len)) 318 | 319 | 320 | if __name__ == "__main__": 321 | tf.test.main() 322 | -------------------------------------------------------------------------------- /nmt/utils/misc_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Generally useful utility functions.""" 17 | from __future__ import print_function 18 | 19 | import codecs 20 | import collections 21 | import json 22 | import math 23 | import os 24 | import sys 25 | import time 26 | from distutils import version 27 | 28 | import numpy as np 29 | import six 30 | import tensorflow as tf 31 | 32 | 33 | def check_tensorflow_version(): 34 | # LINT.IfChange 35 | min_tf_version = "1.12.0" 36 | # LINT.ThenChange(/nmt/copy.bara.sky) 37 | if (version.LooseVersion(tf.__version__) < 38 | version.LooseVersion(min_tf_version)): 39 | raise EnvironmentError("Tensorflow version must >= %s" % min_tf_version) 40 | 41 | 42 | def safe_exp(value): 43 | """Exponentiation with catching of overflow error.""" 44 | try: 45 | ans = math.exp(value) 46 | except OverflowError: 47 | ans = float("inf") 48 | return ans 49 | 50 | 51 | def print_time(s, start_time): 52 | """Take a start time, print elapsed duration, and return a new time.""" 53 | print("%s, time %ds, %s." % (s, (time.time() - start_time), time.ctime())) 54 | sys.stdout.flush() 55 | return time.time() 56 | 57 | 58 | def print_out(s, f=None, new_line=True): 59 | """Similar to print but with support to flush and output to a file.""" 60 | if isinstance(s, bytes): 61 | s = s.decode("utf-8") 62 | 63 | if f: 64 | f.write(s.encode("utf-8")) 65 | if new_line: 66 | f.write(b"\n") 67 | 68 | # stdout 69 | if six.PY2: 70 | sys.stdout.write(s.encode("utf-8")) 71 | else: 72 | sys.stdout.buffer.write(s.encode("utf-8")) 73 | 74 | if new_line: 75 | sys.stdout.write("\n") 76 | sys.stdout.flush() 77 | 78 | 79 | def print_hparams(hparams, skip_patterns=None, header=None): 80 | """Print hparams, can skip keys based on pattern.""" 81 | if header: print_out("%s" % header) 82 | values = hparams.values() 83 | for key in sorted(values.keys()): 84 | if not skip_patterns or all( 85 | [skip_pattern not in key for skip_pattern in skip_patterns]): 86 | print_out(" %s=%s" % (key, str(values[key]))) 87 | 88 | 89 | def load_hparams(model_dir): 90 | """Load hparams from an existing model directory.""" 91 | hparams_file = os.path.join(model_dir, "hparams") 92 | if tf.gfile.Exists(hparams_file): 93 | print_out("# Loading hparams from %s" % hparams_file) 94 | with codecs.getreader("utf-8")(tf.gfile.GFile(hparams_file, "rb")) as f: 95 | try: 96 | hparams_values = json.load(f) 97 | hparams = tf.contrib.training.HParams(**hparams_values) 98 | except ValueError: 99 | print_out(" can't load hparams file") 100 | return None 101 | return hparams 102 | else: 103 | return None 104 | 105 | 106 | def maybe_parse_standard_hparams(hparams, hparams_path): 107 | """Override hparams values with existing standard hparams config.""" 108 | if hparams_path and tf.gfile.Exists(hparams_path): 109 | print_out("# Loading standard hparams from %s" % hparams_path) 110 | with codecs.getreader("utf-8")(tf.gfile.GFile(hparams_path, "rb")) as f: 111 | hparams.parse_json(f.read()) 112 | return hparams 113 | 114 | 115 | def save_hparams(out_dir, hparams): 116 | """Save hparams.""" 117 | hparams_file = os.path.join(out_dir, "hparams") 118 | print_out(" saving hparams to %s" % hparams_file) 119 | with codecs.getwriter("utf-8")(tf.gfile.GFile(hparams_file, "wb")) as f: 120 | f.write(hparams.to_json(indent=4, sort_keys=True)) 121 | 122 | 123 | def debug_tensor(s, msg=None, summarize=10): 124 | """Print the shape and value of a tensor at test time. Return a new tensor.""" 125 | if not msg: 126 | msg = s.name 127 | return tf.Print(s, [tf.shape(s), s], msg + " ", summarize=summarize) 128 | 129 | 130 | def add_summary(summary_writer, global_step, tag, value): 131 | """Add a new summary to the current summary_writer. 132 | Useful to log things that are not part of the training graph, e.g., tag=BLEU. 133 | """ 134 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) 135 | summary_writer.add_summary(summary, global_step) 136 | 137 | 138 | def get_config_proto(log_device_placement=False, allow_soft_placement=True, 139 | num_intra_threads=0, num_inter_threads=0): 140 | # GPU options: 141 | # https://www.tensorflow.org/versions/r0.10/how_tos/using_gpu/index.html 142 | config_proto = tf.ConfigProto( 143 | log_device_placement=log_device_placement, 144 | allow_soft_placement=allow_soft_placement) 145 | config_proto.gpu_options.allow_growth = True 146 | 147 | # CPU threads options 148 | if num_intra_threads: 149 | config_proto.intra_op_parallelism_threads = num_intra_threads 150 | if num_inter_threads: 151 | config_proto.inter_op_parallelism_threads = num_inter_threads 152 | 153 | return config_proto 154 | 155 | 156 | def format_text(words): 157 | """Convert a sequence words into sentence.""" 158 | if (not hasattr(words, "__len__") and # for numpy array 159 | not isinstance(words, collections.Iterable)): 160 | words = [words] 161 | return b" ".join(words) 162 | 163 | 164 | def format_bpe_text(symbols, delimiter=b"@@"): 165 | """Convert a sequence of bpe words into sentence.""" 166 | words = [] 167 | word = b"" 168 | if isinstance(symbols, str): 169 | symbols = symbols.encode() 170 | delimiter_len = len(delimiter) 171 | for symbol in symbols: 172 | if len(symbol) >= delimiter_len and symbol[-delimiter_len:] == delimiter: 173 | word += symbol[:-delimiter_len] 174 | else: # end of a word 175 | word += symbol 176 | words.append(word) 177 | word = b"" 178 | return b" ".join(words) 179 | 180 | 181 | def format_spm_text(symbols): 182 | """Decode a text in SPM (https://github.com/google/sentencepiece) format.""" 183 | return u"".join(format_text(symbols).decode("utf-8").split()).replace( 184 | u"\u2581", u" ").strip().encode("utf-8") 185 | -------------------------------------------------------------------------------- /nmt/utils/misc_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for vocab_utils.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf 23 | 24 | from ..utils import misc_utils 25 | 26 | 27 | class MiscUtilsTest(tf.test.TestCase): 28 | 29 | def testFormatBpeText(self): 30 | bpe_line = ( 31 | b"En@@ ough to make already reluc@@ tant men hesitate to take screening" 32 | b" tests ." 33 | ) 34 | expected_result = ( 35 | b"Enough to make already reluctant men hesitate to take screening tests" 36 | b" ." 37 | ) 38 | self.assertEqual(expected_result, 39 | misc_utils.format_bpe_text(bpe_line.split(b" "))) 40 | 41 | def testFormatSPMText(self): 42 | spm_line = u"\u2581This \u2581is \u2581a \u2581 te st .".encode("utf-8") 43 | expected_result = b"This is a test." 44 | self.assertEqual(expected_result, 45 | misc_utils.format_spm_text(spm_line.split(b" "))) 46 | 47 | 48 | if __name__ == "__main__": 49 | tf.test.main() 50 | -------------------------------------------------------------------------------- /nmt/utils/nmt_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utility functions specifically for NMT.""" 17 | from __future__ import print_function 18 | 19 | import codecs 20 | import time 21 | import numpy as np 22 | import tensorflow as tf 23 | 24 | from ..utils import evaluation_utils 25 | from ..utils import misc_utils as utils 26 | 27 | __all__ = ["decode_and_evaluate", "get_translation"] 28 | 29 | 30 | def decode_and_evaluate(name, 31 | model, 32 | sess, 33 | trans_file, 34 | ref_file, 35 | metrics, 36 | subword_option, 37 | beam_width, 38 | tgt_eos, 39 | num_translations_per_input=1, 40 | decode=True, 41 | infer_mode="greedy"): 42 | """Decode a test set and compute a score according to the evaluation task.""" 43 | # Decode 44 | if decode: 45 | utils.print_out(" decoding to output %s" % trans_file) 46 | 47 | start_time = time.time() 48 | num_sentences = 0 49 | with codecs.getwriter("utf-8")( 50 | tf.gfile.GFile(trans_file, mode="wb")) as trans_f: 51 | trans_f.write("") # Write empty string to ensure file is created. 52 | 53 | if infer_mode == "greedy": 54 | num_translations_per_input = 1 55 | elif infer_mode == "beam_search": 56 | num_translations_per_input = min(num_translations_per_input, beam_width) 57 | 58 | while True: 59 | try: 60 | nmt_outputs, _ = model.decode(sess) 61 | if infer_mode != "beam_search": 62 | nmt_outputs = np.expand_dims(nmt_outputs, 0) 63 | 64 | batch_size = nmt_outputs.shape[1] 65 | num_sentences += batch_size 66 | 67 | for sent_id in range(batch_size): 68 | for beam_id in range(num_translations_per_input): 69 | translation = get_translation( 70 | nmt_outputs[beam_id], 71 | sent_id, 72 | tgt_eos=tgt_eos, 73 | subword_option=subword_option) 74 | trans_f.write((translation + b"\n").decode("utf-8")) 75 | except tf.errors.OutOfRangeError: 76 | utils.print_time( 77 | " done, num sentences %d, num translations per input %d" % 78 | (num_sentences, num_translations_per_input), start_time) 79 | break 80 | 81 | # Evaluation 82 | evaluation_scores = {} 83 | if ref_file and tf.gfile.Exists(trans_file): 84 | for metric in metrics: 85 | score = evaluation_utils.evaluate( 86 | ref_file, 87 | trans_file, 88 | metric, 89 | subword_option=subword_option) 90 | evaluation_scores[metric] = score 91 | utils.print_out(" %s %s: %.1f" % (metric, name, score)) 92 | 93 | return evaluation_scores 94 | 95 | 96 | def get_translation(nmt_outputs, sent_id, tgt_eos, subword_option): 97 | """Given batch decoding outputs, select a sentence and turn to text.""" 98 | if tgt_eos: tgt_eos = tgt_eos.encode("utf-8") 99 | # Select a sentence 100 | output = nmt_outputs[sent_id, :].tolist() 101 | 102 | # If there is an eos symbol in outputs, cut them at that point. 103 | if tgt_eos and tgt_eos in output: 104 | output = output[:output.index(tgt_eos)] 105 | 106 | if subword_option == "bpe": # BPE 107 | translation = utils.format_bpe_text(output) 108 | elif subword_option == "spm": # SPM 109 | translation = utils.format_spm_text(output) 110 | else: 111 | translation = utils.format_text(output) 112 | 113 | return translation 114 | -------------------------------------------------------------------------------- /nmt/utils/standard_hparams_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """standard hparams utils.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf 23 | 24 | 25 | def create_standard_hparams(): 26 | return tf.contrib.training.HParams( 27 | # Data 28 | src="", 29 | tgt="", 30 | train_prefix="", 31 | dev_prefix="", 32 | test_prefix="", 33 | vocab_prefix="", 34 | embed_prefix="", 35 | out_dir="", 36 | 37 | # Networks 38 | num_units=512, 39 | num_encoder_layers=2, 40 | num_decoder_layers=2, 41 | dropout=0.2, 42 | unit_type="lstm", 43 | encoder_type="bi", 44 | residual=False, 45 | time_major=True, 46 | num_embeddings_partitions=0, 47 | num_enc_emb_partitions=0, 48 | num_dec_emb_partitions=0, 49 | 50 | # Attention mechanisms 51 | attention="scaled_luong", 52 | attention_architecture="standard", 53 | output_attention=True, 54 | pass_hidden_state=True, 55 | 56 | # Train 57 | optimizer="sgd", 58 | batch_size=128, 59 | init_op="uniform", 60 | init_weight=0.1, 61 | max_gradient_norm=5.0, 62 | learning_rate=1.0, 63 | warmup_steps=0, 64 | warmup_scheme="t2t", 65 | decay_scheme="luong234", 66 | colocate_gradients_with_ops=True, 67 | num_train_steps=12000, 68 | num_sampled_softmax=0, 69 | 70 | # Data constraints 71 | num_buckets=5, 72 | max_train=0, 73 | src_max_len=50, 74 | tgt_max_len=50, 75 | src_max_len_infer=0, 76 | tgt_max_len_infer=0, 77 | 78 | # Data format 79 | sos="", 80 | eos="", 81 | subword_option="", 82 | use_char_encode=False, 83 | check_special_token=True, 84 | 85 | # Misc 86 | forget_bias=1.0, 87 | num_gpus=1, 88 | epoch_step=0, # record where we were within an epoch. 89 | steps_per_stats=100, 90 | steps_per_external_eval=0, 91 | share_vocab=False, 92 | metrics=["bleu"], 93 | log_device_placement=False, 94 | random_seed=None, 95 | # only enable beam search during inference when beam_width > 0. 96 | beam_width=0, 97 | length_penalty_weight=0.0, 98 | coverage_penalty_weight=0.0, 99 | override_loaded_hparams=True, 100 | num_keep_ckpts=5, 101 | avg_ckpts=False, 102 | 103 | # For inference 104 | inference_indices=None, 105 | infer_batch_size=32, 106 | sampling_temperature=0.0, 107 | num_translations_per_input=1, 108 | infer_mode="greedy", 109 | 110 | # Language model 111 | language_model=False, 112 | ) 113 | -------------------------------------------------------------------------------- /nmt/utils/vocab_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utility to handle vocabularies.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import codecs 23 | import os 24 | import tensorflow as tf 25 | 26 | from tensorflow.python.ops import lookup_ops 27 | 28 | from ..utils import misc_utils as utils 29 | 30 | # word level special token 31 | UNK = "" 32 | SOS = "" 33 | EOS = "" 34 | UNK_ID = 0 35 | 36 | # char ids 0-255 come from utf-8 encoding bytes 37 | # assign 256-300 to special chars 38 | BOS_CHAR_ID = 256 # 39 | EOS_CHAR_ID = 257 # 40 | BOW_CHAR_ID = 258 # 41 | EOW_CHAR_ID = 259 # 42 | PAD_CHAR_ID = 260 # 43 | 44 | DEFAULT_CHAR_MAXLEN = 50 # max number of chars for each word. 45 | 46 | 47 | def _string_to_bytes(text, max_length): 48 | """Given string and length, convert to byte seq of at most max_length. 49 | 50 | This process mimics docqa/elmo's preprocessing: 51 | https://github.com/allenai/document-qa/blob/master/docqa/elmo/data.py 52 | 53 | Note that we make use of BOS_CHAR_ID and EOS_CHAR_ID in iterator_utils.py & 54 | our usage differs from docqa/elmo. 55 | 56 | Args: 57 | text: tf.string tensor of shape [] 58 | max_length: max number of chars for each word. 59 | 60 | Returns: 61 | A tf.int32 tensor of the byte encoded text. 62 | """ 63 | byte_ids = tf.to_int32(tf.decode_raw(text, tf.uint8)) 64 | byte_ids = byte_ids[:max_length - 2] 65 | padding = tf.fill([max_length - tf.shape(byte_ids)[0] - 2], PAD_CHAR_ID) 66 | byte_ids = tf.concat( 67 | [[BOW_CHAR_ID], byte_ids, [EOW_CHAR_ID], padding], axis=0) 68 | tf.logging.info(byte_ids) 69 | 70 | byte_ids = tf.reshape(byte_ids, [max_length]) 71 | tf.logging.info(byte_ids.get_shape().as_list()) 72 | return byte_ids + 1 73 | 74 | 75 | def tokens_to_bytes(tokens): 76 | """Given a sequence of strings, map to sequence of bytes. 77 | 78 | Args: 79 | tokens: A tf.string tensor 80 | 81 | Returns: 82 | A tensor of shape words.shape + [bytes_per_word] containing byte versions 83 | of each word. 84 | """ 85 | bytes_per_word = DEFAULT_CHAR_MAXLEN 86 | with tf.device("/cpu:0"): 87 | tf.assert_rank(tokens, 1) 88 | shape = tf.shape(tokens) 89 | tf.logging.info(tokens) 90 | tokens_flat = tf.reshape(tokens, [-1]) 91 | as_bytes_flat = tf.map_fn( 92 | fn=lambda x: _string_to_bytes(x, max_length=bytes_per_word), 93 | elems=tokens_flat, 94 | dtype=tf.int32, 95 | back_prop=False) 96 | tf.logging.info(as_bytes_flat) 97 | as_bytes = tf.reshape(as_bytes_flat, [shape[0], bytes_per_word]) 98 | return as_bytes 99 | 100 | 101 | def load_vocab(vocab_file): 102 | vocab = [] 103 | with codecs.getreader("utf-8")(tf.gfile.GFile(vocab_file, "rb")) as f: 104 | vocab_size = 0 105 | for word in f: 106 | vocab_size += 1 107 | vocab.append(word.strip()) 108 | return vocab, vocab_size 109 | 110 | 111 | def check_vocab(vocab_file, out_dir, check_special_token=True, sos=None, 112 | eos=None, unk=None): 113 | """Check if vocab_file doesn't exist, create from corpus_file.""" 114 | if tf.gfile.Exists(vocab_file): 115 | utils.print_out("# Vocab file %s exists" % vocab_file) 116 | vocab, vocab_size = load_vocab(vocab_file) 117 | if check_special_token: 118 | # Verify if the vocab starts with unk, sos, eos 119 | # If not, prepend those tokens & generate a new vocab file 120 | if not unk: unk = UNK 121 | if not sos: sos = SOS 122 | if not eos: eos = EOS 123 | assert len(vocab) >= 3 124 | if vocab[0] != unk or vocab[1] != sos or vocab[2] != eos: 125 | utils.print_out("The first 3 vocab words [%s, %s, %s]" 126 | " are not [%s, %s, %s]" % 127 | (vocab[0], vocab[1], vocab[2], unk, sos, eos)) 128 | vocab = [unk, sos, eos] + vocab 129 | vocab_size += 3 130 | new_vocab_file = os.path.join(out_dir, os.path.basename(vocab_file)) 131 | with codecs.getwriter("utf-8")( 132 | tf.gfile.GFile(new_vocab_file, "wb")) as f: 133 | for word in vocab: 134 | f.write("%s\n" % word) 135 | vocab_file = new_vocab_file 136 | else: 137 | raise ValueError("vocab_file '%s' does not exist." % vocab_file) 138 | 139 | vocab_size = len(vocab) 140 | return vocab_size, vocab_file 141 | 142 | 143 | def create_vocab_tables(src_vocab_file, tgt_vocab_file, share_vocab): 144 | """Creates vocab tables for src_vocab_file and tgt_vocab_file.""" 145 | src_vocab_table = lookup_ops.index_table_from_file( 146 | src_vocab_file, default_value=UNK_ID) 147 | if share_vocab: 148 | tgt_vocab_table = src_vocab_table 149 | else: 150 | tgt_vocab_table = lookup_ops.index_table_from_file( 151 | tgt_vocab_file, default_value=UNK_ID) 152 | return src_vocab_table, tgt_vocab_table 153 | 154 | 155 | def load_embed_txt(embed_file): 156 | """Load embed_file into a python dictionary. 157 | 158 | Note: the embed_file should be a Glove/word2vec formatted txt file. Assuming 159 | Here is an exampe assuming embed_size=5: 160 | 161 | the -0.071549 0.093459 0.023738 -0.090339 0.056123 162 | to 0.57346 0.5417 -0.23477 -0.3624 0.4037 163 | and 0.20327 0.47348 0.050877 0.002103 0.060547 164 | 165 | For word2vec format, the first line will be: . 166 | 167 | Args: 168 | embed_file: file path to the embedding file. 169 | Returns: 170 | a dictionary that maps word to vector, and the size of embedding dimensions. 171 | """ 172 | emb_dict = dict() 173 | emb_size = None 174 | 175 | is_first_line = True 176 | with codecs.getreader("utf-8")(tf.gfile.GFile(embed_file, "rb")) as f: 177 | for line in f: 178 | tokens = line.rstrip().split(" ") 179 | if is_first_line: 180 | is_first_line = False 181 | if len(tokens) == 2: # header line 182 | emb_size = int(tokens[1]) 183 | continue 184 | word = tokens[0] 185 | vec = list(map(float, tokens[1:])) 186 | emb_dict[word] = vec 187 | if emb_size: 188 | if emb_size != len(vec): 189 | utils.print_out( 190 | "Ignoring %s since embeding size is inconsistent." % word) 191 | del emb_dict[word] 192 | else: 193 | emb_size = len(vec) 194 | return emb_dict, emb_size 195 | -------------------------------------------------------------------------------- /nmt/utils/vocab_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for vocab_utils.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import codecs 23 | import os 24 | import tensorflow as tf 25 | 26 | from ..utils import vocab_utils 27 | 28 | 29 | class VocabUtilsTest(tf.test.TestCase): 30 | 31 | def testCheckVocab(self): 32 | # Create a vocab file 33 | vocab_dir = os.path.join(tf.test.get_temp_dir(), "vocab_dir") 34 | os.makedirs(vocab_dir) 35 | vocab_file = os.path.join(vocab_dir, "vocab_file") 36 | vocab = ["a", "b", "c"] 37 | with codecs.getwriter("utf-8")(tf.gfile.GFile(vocab_file, "wb")) as f: 38 | for word in vocab: 39 | f.write("%s\n" % word) 40 | 41 | # Call vocab_utils 42 | out_dir = os.path.join(tf.test.get_temp_dir(), "out_dir") 43 | os.makedirs(out_dir) 44 | vocab_size, new_vocab_file = vocab_utils.check_vocab( 45 | vocab_file, out_dir) 46 | 47 | # Assert: we expect the code to add , , and 48 | # create a new vocab file 49 | self.assertEqual(len(vocab) + 3, vocab_size) 50 | self.assertEqual(os.path.join(out_dir, "vocab_file"), new_vocab_file) 51 | new_vocab, _ = vocab_utils.load_vocab(new_vocab_file) 52 | self.assertEqual( 53 | [vocab_utils.UNK, vocab_utils.SOS, vocab_utils.EOS] + vocab, new_vocab) 54 | 55 | 56 | if __name__ == "__main__": 57 | tf.test.main() 58 | --------------------------------------------------------------------------------