├── nmt
│   ├── __init__.py
│   ├── scripts
│   │   ├── __init__.py
│   │   ├── download_iwslt15.sh
│   │   ├── bleu.py
│   │   ├── wmt16_en_de.sh
│   │   └── rouge.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── misc_utils_test.py
│   │   ├── vocab_utils_test.py
│   │   ├── evaluation_utils_test.py
│   │   ├── standard_hparams_utils.py
│   │   ├── nmt_utils.py
│   │   ├── vocab_utils.py
│   │   ├── common_test_utils.py
│   │   ├── misc_utils.py
│   │   ├── evaluation_utils.py
│   │   ├── iterator_utils.py
│   │   └── iterator_utils_test.py
│   ├── .gitignore
│   ├── standard_hparams
│   │   ├── iwslt15.json
│   │   ├── wmt16.json
│   │   ├── mine.json
│   │   ├── wmt16_gnmt_4_layer.json
│   │   └── wmt16_gnmt_8_layer.json
│   ├── OT.py
│   ├── nmt_test.py
│   ├── attention_model.py
│   ├── inference_test.py
│   ├── inference.py
│   ├── gnmt_model.py
│   └── model_helper.py
├── texar
│   ├── configs
│   │   ├── __init__.py
│   │   ├── config_model.py
│   │   ├── config_giga.py
│   │   └── config_iwslt14.py
│   ├── requirements.txt
│   ├── utils
│   │   ├── raml_samples_generation
│   │   │   ├── README.md
│   │   │   ├── gen_samples_giga.sh
│   │   │   ├── gen_samples_iwslt14.sh
│   │   │   ├── util.py
│   │   │   ├── vocab.py
│   │   │   └── process_samples.py
│   │   └── prepare_data.py
│   ├── README.md
│   ├── OT.py
│   ├── baseline_seq2seq_attn_main.py
│   └── baseline_seq2seq_attn_ot.py
├── .gitattributes
├── image
│   ├── alg.JPG
│   ├── draw2.jpg
│   └── model.JPG
├── LICENSE
├── .gitignore
└── README.md

/nmt/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/nmt/scripts/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/nmt/utils/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/texar/configs/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/texar/requirements.txt:
--------------------------------------------------------------------------------
rouge==0.2.1
--------------------------------------------------------------------------------
/nmt/.gitignore:
--------------------------------------------------------------------------------
bazel-bin
bazel-genfiles
bazel-out
bazel-testlogs
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
# Auto detect text files and perform LF normalization
* text=auto
--------------------------------------------------------------------------------
/image/alg.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiqunChen0606/OT-Seq2Seq/HEAD/image/alg.JPG
--------------------------------------------------------------------------------
/image/draw2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiqunChen0606/OT-Seq2Seq/HEAD/image/draw2.jpg
--------------------------------------------------------------------------------
/image/model.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiqunChen0606/OT-Seq2Seq/HEAD/image/model.JPG
--------------------------------------------------------------------------------
/texar/utils/raml_samples_generation/README.md:
--------------------------------------------------------------------------------
## Augmented Data Generation for RAML Algorithm

The code here is mainly copied from [pcyin's github](https://github.com/pcyin/pytorch_nmt), with slight changes to support ```rouge``` as the reward. Note that we have also provided generated samples in the datasets that you can download.

You may tune hyperparameters in ```gen_samples_giga.sh``` or ```gen_samples_iwslt14.sh``` and use commands like ```bash gen_samples_giga.sh``` to begin generation.
--------------------------------------------------------------------------------
/texar/utils/raml_samples_generation/gen_samples_giga.sh:
--------------------------------------------------------------------------------
#!/bin/sh

train_src="../../data/giga/train.article"
train_tgt="../../data/giga/train.title"

python vocab.py \
    --src_vocab_size 30424 \
    --tgt_vocab_size 23738 \
    --train_src ${train_src} \
    --train_tgt ${train_tgt} \
    --include_singleton \
    --output giga_vocab.bin

python process_samples.py \
    --mode sample_ngram \
    --vocab giga_vocab.bin \
    --src ${train_src} \
    --tgt ${train_tgt} \
    --sample_size 10 \
    --reward rouge \
    --output samples_giga.txt
--------------------------------------------------------------------------------
/texar/utils/raml_samples_generation/gen_samples_iwslt14.sh:
--------------------------------------------------------------------------------
#!/bin/sh

train_src="../../data/iwslt14/train.de"
train_tgt="../../data/iwslt14/train.en"

python vocab.py \
    --src_vocab_size 32007 \
    --tgt_vocab_size 22820 \
    --train_src ${train_src} \
    --train_tgt ${train_tgt} \
    --include_singleton \
    --output iwslt14_vocab.bin

python process_samples.py \
    --mode sample_ngram \
    --vocab iwslt14_vocab.bin \
    --src ${train_src} \
    --tgt ${train_tgt} \
    --sample_size 10 \
    --reward bleu \
    --output samples_iwslt14.txt
--------------------------------------------------------------------------------
/nmt/standard_hparams/iwslt15.json:
--------------------------------------------------------------------------------
{
  "attention": "scaled_luong",
  "attention_architecture": "standard",
  "batch_size": 128,
  "colocate_gradients_with_ops": true,
  "dropout": 0.2,
  "encoder_type": "bi",
  "eos": "</s>",
  "forget_bias": 1.0,
  "infer_batch_size": 32,
  "init_weight": 0.1,
  "learning_rate": 1.0,
  "max_gradient_norm": 5.0,
  "metrics": ["bleu"],
  "num_buckets": 5,
  "num_layers": 2,
  "num_train_steps": 12000,
  "decay_scheme": "luong234",
  "num_units": 512,
  "optimizer": "sgd",
  "residual": false,
  "share_vocab": false,
  "subword_option": "",
  "sos": "<s>",
  "src_max_len": 50,
  "src_max_len_infer": null,
  "steps_per_external_eval": null,
  "steps_per_stats": 100,
  "tgt_max_len": 50,
  "tgt_max_len_infer": null,
  "time_major": true,
  "unit_type": "lstm",
  "beam_width": 10
}
--------------------------------------------------------------------------------
/nmt/standard_hparams/wmt16.json:
--------------------------------------------------------------------------------
{
  "attention": "normed_bahdanau",
  "attention_architecture": "standard",
  "batch_size": 128,
  "colocate_gradients_with_ops": true,
  "dropout": 0.2,
  "encoder_type": "bi",
  "eos": "</s>",
  "forget_bias": 1.0,
  "infer_batch_size": 32,
  "init_weight": 0.1,
  "learning_rate": 1.0,
  "max_gradient_norm": 5.0,
  "metrics": ["bleu"],
  "num_buckets": 5,
  "num_layers": 4,
  "num_train_steps": 340000,
  "decay_scheme": "luong10",
  "num_units": 1024,
  "optimizer": "sgd",
  "residual": false,
  "share_vocab": false,
  "subword_option": "bpe",
  "sos": "<s>",
  "src_max_len": 50,
  "src_max_len_infer": null,
  "steps_per_external_eval": null,
  "steps_per_stats": 100,
  "tgt_max_len": 50,
  "tgt_max_len_infer": null,
  "time_major": true,
  "unit_type": "lstm",
  "beam_width": 10
}
--------------------------------------------------------------------------------
/texar/configs/config_model.py:
--------------------------------------------------------------------------------
num_units = 256
beam_width = 5
decoder_layers = 1
dropout = 0.2

embedder = {
    'dim': num_units
}
encoder = {
    'rnn_cell_fw': {
        'kwargs': {
            'num_units': num_units
        },
        'dropout': {
            'input_keep_prob': 1. - dropout
        }
    }
}
decoder = {
    'rnn_cell': {
        'kwargs': {
            'num_units': num_units
        },
        'dropout': {
            'input_keep_prob': 1. - dropout
        },
        'num_layers': decoder_layers
    },
    'attention': {
        'kwargs': {
            'num_units': num_units,
        },
        'attention_layer_size': num_units
    }
}
opt = {
    'optimizer': {
        'type': 'AdamOptimizer',
        'kwargs': {
            'learning_rate': 0.001,
        },
    },
}
--------------------------------------------------------------------------------
/nmt/standard_hparams/mine.json:
--------------------------------------------------------------------------------
{
  "attention": "scaled_luong",
  "attention_architecture": "standard",
  "batch_size": 128,
  "colocate_gradients_with_ops": true,
  "dropout": 0.2,
  "encoder_type": "bi",
  "eos": "</s>",
  "forget_bias": 1.0,
  "infer_batch_size": 32,
  "init_weight": 0.1,
  "learning_rate": 1.0,
  "max_gradient_norm": 5.0,
  "metrics": ["bleu"],
  "num_buckets": 5,
  "num_layers": 2,
  "num_train_steps": 340000,
  "decay_scheme": "luong10",
  "num_units": 1024,
  "optimizer": "sgd",
  "residual": true,
  "share_vocab": false,
  "subword_option": "bpe",
  "sos": "<s>",
  "src_max_len": 50,
  "src_max_len_infer": null,
  "steps_per_external_eval": null,
  "steps_per_stats": 100,
  "tgt_max_len": 50,
  "tgt_max_len_infer": null,
  "time_major": true,
  "unit_type": "lstm",
  "beam_width": 10,
  "length_penalty_weight": 1.0
}
--------------------------------------------------------------------------------
/nmt/standard_hparams/wmt16_gnmt_4_layer.json:
--------------------------------------------------------------------------------
{
  "attention": "normed_bahdanau",
  "attention_architecture": "gnmt_v2",
  "batch_size": 128,
  "colocate_gradients_with_ops": true,
  "dropout": 0.2,
  "encoder_type": "gnmt",
  "eos": "</s>",
  "forget_bias": 1.0,
  "infer_batch_size": 32,
  "init_weight": 0.1,
  "learning_rate": 1.0,
  "max_gradient_norm": 5.0,
  "metrics": ["bleu"],
  "num_buckets": 5,
  "num_layers": 4,
  "num_train_steps": 340000,
  "decay_scheme": "luong10",
  "num_units": 1024,
  "optimizer": "sgd",
  "residual": true,
  "share_vocab": false,
  "subword_option": "bpe",
  "sos": "<s>",
  "src_max_len": 50,
  "src_max_len_infer": null,
  "steps_per_external_eval": null,
  "steps_per_stats": 100,
  "tgt_max_len": 50,
  "tgt_max_len_infer": null,
  "time_major": true,
  "unit_type": "lstm",
  "beam_width": 10,
  "length_penalty_weight": 1.0
}
--------------------------------------------------------------------------------
/nmt/standard_hparams/wmt16_gnmt_8_layer.json:
--------------------------------------------------------------------------------
{
  "attention": "normed_bahdanau",
  "attention_architecture": "gnmt_v2",
  "batch_size": 128,
  "colocate_gradients_with_ops": true,
  "dropout": 0.2,
  "encoder_type": "gnmt",
  "eos": "</s>",
  "forget_bias": 1.0,
  "infer_batch_size": 32,
  "init_weight": 0.1,
  "learning_rate": 1.0,
  "max_gradient_norm": 5.0,
  "metrics": ["bleu"],
  "num_buckets": 5,
  "num_layers": 8,
  "num_train_steps": 340000,
  "decay_scheme": "luong10",
  "num_units": 1024,
  "optimizer": "sgd",
  "residual": true,
  "share_vocab": false,
  "subword_option": "bpe",
  "sos": "<s>",
  "source_reverse": false,
  "src_max_len": 50,
  "src_max_len_infer": null,
  "steps_per_external_eval": null,
  "steps_per_stats": 100,
  "tgt_max_len": 50,
  "tgt_max_len_infer": null,
  "time_major": true,
  "unit_type": "lstm",
  "beam_width": 10,
  "length_penalty_weight": 1.0
}
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 LiqunChen

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/nmt/scripts/download_iwslt15.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Download small-scale IWSLT15 Vietnamese-to-English translation data for NMT
# model training.
#
# Usage:
#   ./download_iwslt15.sh path-to-output-dir
#
# If output directory is not specified, "./iwslt15" will be used as the default
# output directory.
OUT_DIR="${1:-iwslt15}"
SITE_PREFIX="https://nlp.stanford.edu/projects/nmt/data"

mkdir -v -p $OUT_DIR

# Download the IWSLT15 small dataset from the Stanford website.
echo "Download training dataset train.en and train.vi."
curl -o "$OUT_DIR/train.en" "$SITE_PREFIX/iwslt15.en-vi/train.en"
curl -o "$OUT_DIR/train.vi" "$SITE_PREFIX/iwslt15.en-vi/train.vi"

echo "Download dev dataset tst2012.en and tst2012.vi."
curl -o "$OUT_DIR/tst2012.en" "$SITE_PREFIX/iwslt15.en-vi/tst2012.en"
curl -o "$OUT_DIR/tst2012.vi" "$SITE_PREFIX/iwslt15.en-vi/tst2012.vi"

echo "Download test dataset tst2013.en and tst2013.vi."
curl -o "$OUT_DIR/tst2013.en" "$SITE_PREFIX/iwslt15.en-vi/tst2013.en"
curl -o "$OUT_DIR/tst2013.vi" "$SITE_PREFIX/iwslt15.en-vi/tst2013.vi"

echo "Download vocab file vocab.en and vocab.vi."
curl -o "$OUT_DIR/vocab.en" "$SITE_PREFIX/iwslt15.en-vi/vocab.en"
curl -o "$OUT_DIR/vocab.vi" "$SITE_PREFIX/iwslt15.en-vi/vocab.vi"
--------------------------------------------------------------------------------
/texar/configs/config_giga.py:
--------------------------------------------------------------------------------
num_epochs = 20
observe_steps = 500

eval_metric = 'rouge'

batch_size = 64
source_vocab_file = './data/giga/vocab.article'
target_vocab_file = './data/giga/vocab.title'

train = {
    'batch_size': batch_size,
    'allow_smaller_final_batch': False,
    'source_dataset': {
        "files": 'data/giga/train.article',
        'vocab_file': source_vocab_file
    },
    'target_dataset': {
        'files': 'data/giga/train.title',
        'vocab_file': target_vocab_file
    }
}
val = {
    'batch_size': batch_size,
    'shuffle': False,
    'allow_smaller_final_batch': True,
    'source_dataset': {
        "files": 'data/giga/valid.article',
        'vocab_file': source_vocab_file,
    },
    'target_dataset': {
        'files': 'data/giga/valid.title',
        'vocab_file': target_vocab_file,
    }
}
test = {
    'batch_size': batch_size,
    'shuffle': False,
    'allow_smaller_final_batch': True,
    'source_dataset': {
        "files": 'data/giga/test.article',
        'vocab_file': source_vocab_file,
    },
    'target_dataset': {
        'files': 'data/giga/test.title',
        'vocab_file': target_vocab_file,
    }
}
--------------------------------------------------------------------------------
/nmt/OT.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf


def cost_matrix(x, y):
    """Returns the pairwise cost matrix between x and y, using cosine distance."""
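    # Shapes (as consumed by IPOT below): x is [n, d], y is [m, d]; the
    # returned matrix C is [n, m] with C[i, j] = 1 - cos(x_i, y_j), so each
    # entry lies in [0, 2].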
    # NOTE: cosine distance is chosen here.

    x = tf.nn.l2_normalize(x, 1, epsilon=1e-12)
    y = tf.nn.l2_normalize(y, 1, epsilon=1e-12)
    tmp1 = tf.matmul(x, y, transpose_b=True)
    cos_dis = 1 - tmp1

    # (Unused) L1-distance alternative, kept from an earlier version:
    x_col = tf.expand_dims(x, 1)
    y_lin = tf.expand_dims(y, 0)
    res = tf.reduce_sum(tf.abs(x_col - y_lin), 2)

    return cos_dis

def IPOT(C, n, m, beta=0.5):

    # sigma = tf.scalar_mul(1 / n, tf.ones([n, 1]))
    sigma = tf.ones([m, 1]) / tf.cast(m, tf.float32)
    T = tf.ones([n, m])
    A = tf.exp(-C / beta)
    for t in range(50):
        Q = tf.multiply(A, T)
        for k in range(1):
            delta = 1 / (tf.cast(n, tf.float32) * tf.matmul(Q, sigma))
            sigma = 1 / (
                tf.cast(m, tf.float32) * tf.matmul(Q, delta, transpose_a=True))
        # pdb.set_trace()
        tmp = tf.matmul(tf.diag(tf.squeeze(delta)), Q)
        T = tf.matmul(tmp, tf.diag(tf.squeeze(sigma)))
    return T


def IPOT_distance(C, n, m):
    T = IPOT(C, n, m)
    distance = tf.trace(tf.matmul(C, T, transpose_a=True))
    return distance
--------------------------------------------------------------------------------
/texar/configs/config_iwslt14.py:
--------------------------------------------------------------------------------
num_epochs = 50  # the best epoch occurs within 10 epochs in most cases
observe_steps = 500

eval_metric = 'bleu'

batch_size = 64
source_vocab_file = './data/iwslt14/vocab.de'
target_vocab_file = './data/iwslt14/vocab.en'

train = {
    'batch_size': batch_size,
    'shuffle': True,
    'allow_smaller_final_batch': False,
    'source_dataset': {
        "files": 'data/iwslt14/train.de',
        'vocab_file': source_vocab_file,
        'max_seq_length': 50
    },
    'target_dataset': {
        'files': 'data/iwslt14/train.en',
        'vocab_file': target_vocab_file,
        'max_seq_length': 50
    }
}
val = {
    'batch_size': batch_size,
    'shuffle': False,
    'allow_smaller_final_batch': True,
    'source_dataset': {
        "files": 'data/iwslt14/valid.de',
        'vocab_file': source_vocab_file,
    },
    'target_dataset': {
        'files': 'data/iwslt14/valid.en',
        'vocab_file': target_vocab_file,
    }
}
test = {
    'batch_size': batch_size,
    'shuffle': False,
    'allow_smaller_final_batch': True,
    'source_dataset': {
        "files": 'data/iwslt14/test.de',
        'vocab_file': source_vocab_file,
    },
    'target_dataset': {
        'files': 'data/iwslt14/test.en',
        'vocab_file': target_vocab_file,
    }
}
--------------------------------------------------------------------------------
/nmt/utils/misc_utils_test.py:
--------------------------------------------------------------------------------
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for misc_utils."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from ..utils import misc_utils


class MiscUtilsTest(tf.test.TestCase):

  def testFormatBpeText(self):
    bpe_line = (
        b"En@@ ough to make already reluc@@ tant men hesitate to take screening"
        b" tests ."
    )
    expected_result = (
        b"Enough to make already reluctant men hesitate to take screening tests"
        b" ."
    )
    self.assertEqual(expected_result,
                     misc_utils.format_bpe_text(bpe_line.split(b" ")))

  def testFormatSPMText(self):
    spm_line = u"\u2581This \u2581is \u2581a \u2581 te st .".encode("utf-8")
    expected_result = "This is a test."
    self.assertEqual(expected_result,
                     misc_utils.format_spm_text(spm_line.split(b" ")))


if __name__ == "__main__":
  tf.test.main()
--------------------------------------------------------------------------------
/texar/utils/prepare_data.py:
--------------------------------------------------------------------------------
# Copyright 2018 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Downloads data.
"""
import tensorflow as tf
import texar as tx

# pylint: disable=invalid-name

flags = tf.flags

flags.DEFINE_string("data", "iwslt14", "Data to download [iwslt14|toy_copy]")

FLAGS = flags.FLAGS


def prepare_data():
    """Downloads data.
    """
    if FLAGS.data == 'giga':
        tx.data.maybe_download(
            urls='https://drive.google.com/file/d/'
                 '12RZs7QFwjj6dfuYNQ_0Ah-ccH1xFDMD5/view?usp=sharing',
            path='./',
            filenames='giga.zip',
            extract=True)
    elif FLAGS.data == 'iwslt14':
        tx.data.maybe_download(
            urls='https://drive.google.com/file/d/'
                 '1y4mUWXRS2KstgHopCS9koZ42ENOh6Yb9/view?usp=sharing',
            path='./',
            filenames='iwslt14.zip',
            extract=True)
    else:
        raise ValueError('Unknown data: {}'.format(FLAGS.data))


def main():
    """Entrypoint.
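
    Usage (commands taken from texar/README.md):
        python utils/prepare_data.py --data iwslt14
        python utils/prepare_data.py --data giga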
    """
    prepare_data()


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/texar/utils/raml_samples_generation/util.py:
--------------------------------------------------------------------------------
from collections import defaultdict
import numpy as np

def read_corpus(file_path, source):
    data = []
    for line in open(file_path):
        sent = line.strip().split(' ')
        # only append <s> and </s> to the target sentence
        if source == 'tgt':
            sent = ['<s>'] + sent + ['</s>']
        data.append(sent)

    return data


def batch_slice(data, batch_size, sort=True):
    batch_num = int(np.ceil(len(data) / float(batch_size)))
    for i in range(batch_num):
        cur_batch_size = batch_size if i < batch_num - 1 else len(data) - batch_size * i
        src_sents = [data[i * batch_size + b][0] for b in range(cur_batch_size)]
        tgt_sents = [data[i * batch_size + b][1] for b in range(cur_batch_size)]

        if sort:
            src_ids = sorted(range(cur_batch_size), key=lambda src_id: len(src_sents[src_id]), reverse=True)
            src_sents = [src_sents[src_id] for src_id in src_ids]
            tgt_sents = [tgt_sents[src_id] for src_id in src_ids]

        yield src_sents, tgt_sents


def data_iter(data, batch_size, shuffle=True):
    """
    randomly permute data, then sort by source length, and partition into batches
    ensure that the length of source sentences in each batch is decreasing
    """

    buckets = defaultdict(list)
    for pair in data:
        src_sent = pair[0]
        buckets[len(src_sent)].append(pair)

    batched_data = []
    for src_len in buckets:
        tuples = buckets[src_len]
        if shuffle: np.random.shuffle(tuples)
        batched_data.extend(list(batch_slice(tuples, batch_size)))

    if shuffle:
        np.random.shuffle(batched_data)
    for batch in batched_data:
        yield batch
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# OT-Seq2Seq
This is the repository for the ICLR 2019 paper [IMPROVING SEQUENCE-TO-SEQUENCE LEARNING
VIA OPTIMAL TRANSPORT](https://arxiv.org/pdf/1901.06283.pdf).

## Usage ##
Folder [nmt](./nmt) is built upon [GoogleNMT](https://github.com/tensorflow/nmt).
Please follow the instructions in that repo for dataset downloading and code testing.

Folder [texar](./texar) is built upon [Texar](https://github.com/asyml/texar).
For details about the summarization and translation tasks, please follow this [link](./texar).

## Brief introduction ##
![Model intuition](./image/draw2.jpg)
We present a novel Seq2Seq learning scheme that leverages optimal transport (OT) to construct a sequence-level loss. Specifically, the OT objective aims to find an optimal matching of similar words/phrases between two sequences, providing a way to promote their semantic similarity (Kusner et al., 2015). Compared with RL and adversarial schemes, our approach has: (i) semantic invariance, allowing better preservation of sequence-level semantic information; and (ii) improved robustness, since neither the REINFORCE gradient nor a mini-max game is involved. The OT loss allows end-to-end supervised training and acts as an effective sequence-level regularization of the MLE loss.
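As a minimal, self-contained illustration of the OT machinery (the toy values below are ours, not part of the training pipeline), the NumPy IPOT solver `IPOT_np` in [texar/OT.py](./texar/OT.py) can be run on a small cost matrix; the OT distance is the inner product between the cost matrix `C` and the returned transport plan `T`:

```python
import numpy as np
from OT import IPOT_np  # texar/OT.py

# Toy cost matrix between 3 reference and 4 generated word embeddings
# (random values, purely for illustration).
C = np.random.rand(3, 4)
T = IPOT_np(C)                   # approximate optimal transport plan, shape [3, 4]
ot_distance = np.trace(C.T @ T)  # <C, T>, the proxy OT distance
print(ot_distance)
```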
OT can be easily applied to any Seq2Seq learning framework; the framework figure is shown here:
![Model framework](./image/model.JPG)

The training algorithm can then be represented as:
![Model algorithm](./image/alg.JPG)


### Reference
If you are interested in our paper and want to further improve the model, please cite our paper with the following BibTex entry:
```
@article{chen2019improving,
  title={Improving Sequence-to-Sequence Learning via Optimal Transport},
  author={Chen, Liqun and Zhang, Yizhe and Zhang, Ruiyi and Tao, Chenyang and Gan, Zhe and Zhang, Haichao and Li, Bai and Shen, Dinghan and Chen, Changyou and Carin, Lawrence},
  journal={ICLR},
  year={2019}
}
```
--------------------------------------------------------------------------------
/nmt/utils/vocab_utils_test.py:
--------------------------------------------------------------------------------
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for vocab_utils."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import codecs
import os
import tensorflow as tf

from ..utils import vocab_utils


class VocabUtilsTest(tf.test.TestCase):

  def testCheckVocab(self):
    # Create a vocab file
    vocab_dir = os.path.join(tf.test.get_temp_dir(), "vocab_dir")
    os.makedirs(vocab_dir)
    vocab_file = os.path.join(vocab_dir, "vocab_file")
    vocab = ["a", "b", "c"]
    with codecs.getwriter("utf-8")(tf.gfile.GFile(vocab_file, "wb")) as f:
      for word in vocab:
        f.write("%s\n" % word)

    # Call vocab_utils
    out_dir = os.path.join(tf.test.get_temp_dir(), "out_dir")
    os.makedirs(out_dir)
    vocab_size, new_vocab_file = vocab_utils.check_vocab(
        vocab_file, out_dir)

    # Assert: we expect the code to add <unk>, <s>, </s> and
    # create a new vocab file
    self.assertEqual(len(vocab) + 3, vocab_size)
    self.assertEqual(os.path.join(out_dir, "vocab_file"), new_vocab_file)
    new_vocab, _ = vocab_utils.load_vocab(new_vocab_file)
    self.assertEqual(
        [vocab_utils.UNK, vocab_utils.SOS, vocab_utils.EOS] + vocab, new_vocab)


if __name__ == "__main__":
  tf.test.main()
--------------------------------------------------------------------------------
/nmt/utils/evaluation_utils_test.py:
--------------------------------------------------------------------------------
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for evaluation_utils.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from ..utils import evaluation_utils


class EvaluationUtilsTest(tf.test.TestCase):

  def testEvaluate(self):
    output = "nmt/testdata/deen_output"
    ref_bpe = "nmt/testdata/deen_ref_bpe"
    ref_spm = "nmt/testdata/deen_ref_spm"

    expected_bleu_score = 22.5855084573
    expected_rouge_score = 50.8429782599

    bpe_bleu_score = evaluation_utils.evaluate(
        ref_bpe, output, "bleu", "bpe")
    bpe_rouge_score = evaluation_utils.evaluate(
        ref_bpe, output, "rouge", "bpe")

    self.assertAlmostEqual(expected_bleu_score, bpe_bleu_score)
    self.assertAlmostEqual(expected_rouge_score, bpe_rouge_score)

    spm_bleu_score = evaluation_utils.evaluate(
        ref_spm, output, "bleu", "spm")
    spm_rouge_score = evaluation_utils.evaluate(
        ref_spm, output, "rouge", "spm")

    self.assertAlmostEqual(expected_rouge_score, spm_rouge_score)
    self.assertAlmostEqual(expected_bleu_score, spm_bleu_score)

  def testAccuracy(self):
    pred_output = "nmt/testdata/pred_output"
    label_ref = "nmt/testdata/label_ref"

    expected_accuracy_score = 60.00

    accuracy_score = evaluation_utils.evaluate(
        label_ref, pred_output, "accuracy")
    self.assertAlmostEqual(expected_accuracy_score, accuracy_score)

  def testWordAccuracy(self):
    pred_output = "nmt/testdata/pred_output"
    label_ref = "nmt/testdata/label_ref"

    expected_word_accuracy_score = 60.00

    word_accuracy_score = evaluation_utils.evaluate(
        label_ref, pred_output, "word_accuracy")
    self.assertAlmostEqual(expected_word_accuracy_score, word_accuracy_score)


if __name__ == "__main__":
  tf.test.main()
--------------------------------------------------------------------------------
/texar/README.md:
--------------------------------------------------------------------------------
This code is largely borrowed from [Texar](https://texar.readthedocs.io/en/latest/).
Please install the requirements first to run the code.

# Sequence Generation #
This example provides implementations of some classic and advanced training algorithms that tackle exposure bias. The base model is an attentional seq2seq.

* **Maximum Likelihood (MLE)**: attentional seq2seq model with maximum likelihood training.
* **Maximum Likelihood (MLE) + Optimal transport (OT)**: described in [OT-seq2seq](https://arxiv.org/pdf/1901.06283.pdf); we use the sampling approach (n-gram replacement) of [(Ma et al., 2017)](https://arxiv.org/abs/1705.07136). A minimal sketch of how the OT term enters the training loss is given below.
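The following is a minimal sketch (under TensorFlow 1.x, with illustrative placeholder tensors; it is not the exact training code, for which see e.g. `baseline_seq2seq_attn_ot.py`) of how the OT term from [OT.py](./OT.py) can be combined with the MLE loss:

```python
import tensorflow as tf
from OT import cost_matrix, IPOT_distance

# Illustrative stand-ins: in the real model these are sequence embeddings
# produced by the embedder/decoder.
emb_ref = tf.random_normal([20, 256])    # [len_ref, dim], reference sequence
emb_gen = tf.random_normal([18, 256])    # [len_gen, dim], generated sequence
mle_loss = tf.constant(3.2)              # placeholder for the usual MLE loss
gamma = 0.1                              # hypothetical OT weight

C = cost_matrix(emb_ref, emb_gen)        # cosine-distance cost matrix, [20, 18]
ot_loss = IPOT_distance(C, 20, 18)       # <C, T> under the IPOT transport plan
total_loss = mle_loss + gamma * ot_loss  # OT acts as a sequence-level regularizer
```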

## Usage ##

### Dataset ###

Two example datasets are provided:

* iwslt14: The benchmark [IWSLT2014](https://sites.google.com/site/iwsltevaluation2014/home) (de-en) machine translation dataset, following [(Ranzato et al., 2015)](https://arxiv.org/pdf/1511.06732.pdf) for data pre-processing.
* gigaword: The benchmark [GIGAWORD](https://catalog.ldc.upenn.edu/LDC2003T05) text summarization dataset. We sampled 200K out of the 3.8M pre-processed training examples provided by [(Rush et al., 2015)](https://www.aclweb.org/anthology/D/D15/D15-1044.pdf) for the sake of training efficiency. We used the refined validation and test sets provided by [(Zhou et al., 2017)](https://arxiv.org/pdf/1704.07073.pdf).

Download the data with the following commands:

```
python utils/prepare_data.py --data iwslt14
python utils/prepare_data.py --data giga
```

### Train the models ###

#### Baseline Attentional Seq2seq with OT

```
python baseline_seq2seq_attn_main.py \
    --config_model configs.config_model \
    --config_data configs.config_iwslt14
```

Here:
* `--config_model` specifies the model config. Note: do not include the `.py` suffix.
* `--config_data` specifies the data config.

[configs.config_model.py](./configs/config_model.py) specifies a single-layer seq2seq model with Luong attention and a bi-directional RNN encoder. Hyperparameters taking default values can be omitted from the config file.


For demonstration purposes, [configs.config_model_full.py](./configs/config_model_full.py) gives all possible hyperparameters for the model. The two config files will lead to the same model.

## Results ##

### Machine Translation
| Model | BLEU Score |
| -----------| -------|
| MLE | 26.44 ± 0.18 |
| Scheduled Sampling | 26.76 ± 0.17 |
| RAML | 27.22 ± 0.14 |
| Interpolation | 27.82 ± 0.11 |
| MLE + OT | 27.79 ± 0.12 |

### Text Summarization
| Model | Rouge-1 | Rouge-2 | Rouge-L |
| -----------| -------|-------|-------|
| MLE | 36.11 ± 0.21 | 16.39 ± 0.16 | 32.32 ± 0.19 |
| RAML | 36.30 ± 0.24 | 16.69 ± 0.20 | 32.49 ± 0.17 |
| Interpolation | 36.72 ± 0.29 | 16.99 ± 0.17 | 32.95 ± 0.33 |
| MLE + OT | 36.82 ± 0.25 | 17.35 ± 0.10 | 33.35 ± 0.14 |

--------------------------------------------------------------------------------
/texar/OT.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf
from functools import partial
import pdb

def cost_matrix(x, y):
    """Returns the pairwise cost matrix between x and y, using cosine distance."""
    # NOTE: cosine distance and Euclidean distance
    # x_col = x.unsqueeze(1)
    # y_lin = y.unsqueeze(0)
    # c = torch.sum((torch.abs(x_col - y_lin)) ** p, 2)
    # return c
    x = tf.nn.l2_normalize(x, 1, epsilon=1e-12)
    y = tf.nn.l2_normalize(y, 1, epsilon=1e-12)
    tmp1 = tf.matmul(x, y, transpose_b=True)
    cos_dis = 1 - tmp1

    # (Unused) L1-distance alternative, kept from an earlier version:
    x_col = tf.expand_dims(x, 1)
    y_lin = tf.expand_dims(y, 0)
    res = tf.reduce_sum(tf.abs(x_col - y_lin), 2)

    return cos_dis


def IPOT(C, n, m, beta=0.5):

    # sigma = tf.scalar_mul(1 / n, tf.ones([n, 1]))

    sigma = tf.ones([m, 1]) / tf.cast(m, tf.float32)
    T = tf.ones([n, m])
    A = tf.exp(-C / beta)
    for t in range(50):
        Q = tf.multiply(A, T)
        for k in range(1):
            delta = 1 / (tf.cast(n, tf.float32) * tf.matmul(Q, sigma))
            sigma = 1 / (
                tf.cast(m, tf.float32) * tf.matmul(Q, delta, transpose_a=True))
        # pdb.set_trace()
        tmp = tf.matmul(tf.diag(tf.squeeze(delta)), Q)
        T = tf.matmul(tmp, tf.diag(tf.squeeze(sigma)))
    return T


def IPOT_np(C, beta=0.5):

    n, m = C.shape[0], C.shape[1]
    sigma = np.ones([m, 1]) / m
    T = np.ones([n, m])
    A = np.exp(-C / beta)
    for t in range(20):
        Q = np.multiply(A, T)
        for k in range(1):
            delta = 1 / (n * (Q @ sigma))
            sigma = 1 / (m * (Q.T @ delta))
        # pdb.set_trace()
        tmp = np.diag(np.squeeze(delta)) @ Q
        T = tmp @ np.diag(np.squeeze(sigma))
    return T

def IPOT_distance(C, n, m):
    T = IPOT(C, n, m)
    distance = tf.trace(tf.matmul(C, T, transpose_a=True))
    return distance


def shape_list(x):
    """Return list of dims, statically where possible."""
    x = tf.convert_to_tensor(x)
    # If unknown rank, return dynamic shape
    if x.get_shape().dims is None:
        return tf.shape(x)
    static = x.get_shape().as_list()
    shape = tf.shape(x)
    ret = []
    for i in range(len(static)):
        dim = static[i]
        if dim is None:
            dim = shape[i]
        ret.append(dim)
    return ret

def IPOT_distance2(C, beta=1, t_steps=10, k_steps=1):
    b, n, m = shape_list(C)
    sigma = tf.ones([b, m, 1]) / tf.cast(m, tf.float32)  # [b, m, 1]
    T = tf.ones([b, n, m])
    A = tf.exp(-C / beta)  # [b, n, m]
    for t in range(t_steps):
        Q = A * T  # [b, n, m]
        for k in range(k_steps):
            delta = 1 / (tf.cast(n, tf.float32) * tf.matmul(Q, sigma))  # [b, n, 1]
            sigma = 1 / (tf.cast(m, tf.float32) * tf.matmul(Q, delta, transpose_a=True))  # [b, m, 1]
        T = delta * Q * tf.transpose(sigma, [0, 2, 1])  # [b, n, m]
    distance = tf.trace(tf.matmul(C, T, transpose_a=True))
    return distance
--------------------------------------------------------------------------------
/nmt/utils/standard_hparams_utils.py:
--------------------------------------------------------------------------------
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Standard hparams utils."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


def create_standard_hparams():
  return tf.contrib.training.HParams(
      # Data
      src="",
      tgt="",
      train_prefix="",
      dev_prefix="",
      test_prefix="",
      vocab_prefix="",
      embed_prefix="",
      out_dir="",

      # Networks
      num_units=512,
      num_layers=2,
      num_encoder_layers=2,
      num_decoder_layers=2,
      dropout=0.2,
      unit_type="lstm",
      encoder_type="bi",
      residual=False,
      time_major=True,
      num_embeddings_partitions=0,

      # Attention mechanisms
      attention="scaled_luong",
      attention_architecture="standard",
      output_attention=True,
      pass_hidden_state=True,

      # Train
      optimizer="sgd",
      batch_size=128,
      init_op="uniform",
      init_weight=0.1,
      max_gradient_norm=5.0,
      learning_rate=1.0,
      warmup_steps=0,
      warmup_scheme="t2t",
      decay_scheme="luong234",
      colocate_gradients_with_ops=True,
      num_train_steps=12000,

      # Data constraints
      num_buckets=5,
      max_train=0,
      src_max_len=50,
      tgt_max_len=50,
      src_max_len_infer=0,
      tgt_max_len_infer=0,

      # Data format
      sos="<s>",
      eos="</s>",
      subword_option="",
      check_special_token=True,

      # Misc
      forget_bias=1.0,
      num_gpus=1,
      epoch_step=0,  # record where we were within an epoch.
      steps_per_stats=100,
      steps_per_external_eval=0,
      share_vocab=False,
      metrics=["bleu"],
      log_device_placement=False,
      random_seed=None,
      # only enable beam search during inference when beam_width > 0.
      beam_width=0,
      length_penalty_weight=0.0,
      override_loaded_hparams=True,
      num_keep_ckpts=5,
      avg_ckpts=False,

      # For inference
      inference_indices=None,
      infer_batch_size=32,
      sampling_temperature=0.0,
      num_translations_per_input=1,
  )
--------------------------------------------------------------------------------
/nmt/nmt_test.py:
--------------------------------------------------------------------------------
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for nmt.py, train.py and inference.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os

import tensorflow as tf

from . import inference
from . import nmt
from . import train


def _update_flags(flags, test_name):
  """Update flags for basic training."""
  flags.num_train_steps = 100
  flags.steps_per_stats = 5
  flags.src = "en"
  flags.tgt = "vi"
  flags.train_prefix = ("nmt/testdata/"
                        "iwslt15.tst2013.100")
  flags.vocab_prefix = ("nmt/testdata/"
                        "iwslt15.vocab.100")
  flags.dev_prefix = ("nmt/testdata/"
                      "iwslt15.tst2013.100")
  flags.test_prefix = ("nmt/testdata/"
                       "iwslt15.tst2013.100")
  flags.out_dir = os.path.join(tf.test.get_temp_dir(), test_name)


class NMTTest(tf.test.TestCase):

  def testTrain(self):
    """Test the training loop is functional with basic hparams."""
    nmt_parser = argparse.ArgumentParser()
    nmt.add_arguments(nmt_parser)
    FLAGS, unparsed = nmt_parser.parse_known_args()

    _update_flags(FLAGS, "nmt_train_test")

    default_hparams = nmt.create_hparams(FLAGS)

    train_fn = train.train
    nmt.run_main(FLAGS, default_hparams, train_fn, None)


  def testTrainWithAvgCkpts(self):
    """Test the training loop is functional with basic hparams."""
    nmt_parser = argparse.ArgumentParser()
    nmt.add_arguments(nmt_parser)
    FLAGS, unparsed = nmt_parser.parse_known_args()

    _update_flags(FLAGS, "nmt_train_test_avg_ckpts")
    FLAGS.avg_ckpts = True

    default_hparams = nmt.create_hparams(FLAGS)

    train_fn = train.train
    nmt.run_main(FLAGS, default_hparams, train_fn, None)


  def testInference(self):
    """Test inference is functional with basic hparams."""
    nmt_parser = argparse.ArgumentParser()
    nmt.add_arguments(nmt_parser)
    FLAGS, unparsed = nmt_parser.parse_known_args()

    _update_flags(FLAGS, "nmt_train_infer")

    # Train one step so we have a checkpoint.
    FLAGS.num_train_steps = 1
    default_hparams = nmt.create_hparams(FLAGS)
    train_fn = train.train
    nmt.run_main(FLAGS, default_hparams, train_fn, None)

    # Update FLAGS for inference.
    FLAGS.inference_input_file = ("nmt/testdata/"
                                  "iwslt15.tst2013.100.en")
    FLAGS.inference_output_file = os.path.join(FLAGS.out_dir, "output")
    FLAGS.inference_ref_file = ("nmt/testdata/"
                                "iwslt15.tst2013.100.vi")

    default_hparams = nmt.create_hparams(FLAGS)

    inference_fn = inference.inference
    nmt.run_main(FLAGS, default_hparams, None, inference_fn)


if __name__ == "__main__":
  tf.test.main()
--------------------------------------------------------------------------------
/nmt/utils/nmt_utils.py:
--------------------------------------------------------------------------------
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utility functions specifically for NMT."""
from __future__ import print_function

import codecs
import time
import numpy as np
import tensorflow as tf

from ..utils import evaluation_utils
from ..utils import misc_utils as utils

__all__ = ["decode_and_evaluate", "get_translation"]


def decode_and_evaluate(name,
                        model,
                        sess,
                        trans_file,
                        ref_file,
                        metrics,
                        subword_option,
                        beam_width,
                        tgt_eos,
                        num_translations_per_input=1,
                        decode=True):
  """Decode a test set and compute a score according to the evaluation task."""
  # Decode
  if decode:
    utils.print_out("  decoding to output %s." % trans_file)

    start_time = time.time()
    num_sentences = 0
    with codecs.getwriter("utf-8")(
        tf.gfile.GFile(trans_file, mode="wb")) as trans_f:
      trans_f.write("")  # Write empty string to ensure file is created.

      num_translations_per_input = max(
          min(num_translations_per_input, beam_width), 1)
      while True:
        try:
          nmt_outputs, _ = model.decode(sess)
          if beam_width == 0:
            nmt_outputs = np.expand_dims(nmt_outputs, 0)

          batch_size = nmt_outputs.shape[1]
          num_sentences += batch_size

          for sent_id in range(batch_size):
            for beam_id in range(num_translations_per_input):
              translation = get_translation(
                  nmt_outputs[beam_id],
                  sent_id,
                  tgt_eos=tgt_eos,
                  subword_option=subword_option)
              trans_f.write((translation + b"\n").decode("utf-8"))
        except tf.errors.OutOfRangeError:
          utils.print_time(
              "  done, num sentences %d, num translations per input %d" %
              (num_sentences, num_translations_per_input), start_time)
          break

  # Evaluation
  evaluation_scores = {}
  if ref_file and tf.gfile.Exists(trans_file):
    for metric in metrics:
      score = evaluation_utils.evaluate(
          ref_file,
          trans_file,
          metric,
          subword_option=subword_option)
      evaluation_scores[metric] = score
      utils.print_out("  %s %s: %.1f" % (metric, name, score))

  return evaluation_scores


def get_translation(nmt_outputs, sent_id, tgt_eos, subword_option):
  """Given batch decoding outputs, select a sentence and turn to text."""
  if tgt_eos: tgt_eos = tgt_eos.encode("utf-8")
  # Select a sentence
  output = nmt_outputs[sent_id, :].tolist()

  # If there is an eos symbol in the output, cut it at that point.
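  # e.g. with tgt_eos=b"</s>" (illustrative values):
  #   [b"hello", b"world", b"</s>", b"pad"] is cut to [b"hello", b"world"].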
99 | if tgt_eos and tgt_eos in output: 100 | output = output[:output.index(tgt_eos)] 101 | 102 | if subword_option == "bpe": # BPE 103 | translation = utils.format_bpe_text(output) 104 | elif subword_option == "spm": # SPM 105 | translation = utils.format_spm_text(output) 106 | else: 107 | translation = utils.format_text(output) 108 | 109 | return translation 110 | -------------------------------------------------------------------------------- /texar/utils/raml_samples_generation/vocab.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | from collections import Counter 4 | from itertools import chain 5 | 6 | import torch 7 | 8 | from util import read_corpus 9 | 10 | 11 | class VocabEntry(object): 12 | def __init__(self): 13 | self.word2id = dict() 14 | self.unk_id = 3 15 | self.word2id[''] = 0 16 | self.word2id[''] = 1 17 | self.word2id[''] = 2 18 | self.word2id[''] = 3 19 | 20 | self.id2word = {v: k for k, v in self.word2id.iteritems()} 21 | 22 | def __getitem__(self, word): 23 | return self.word2id.get(word, self.unk_id) 24 | 25 | def __contains__(self, word): 26 | return word in self.word2id 27 | 28 | def __setitem__(self, key, value): 29 | raise ValueError('vocabulary is readonly') 30 | 31 | def __len__(self): 32 | return len(self.word2id) 33 | 34 | def __repr__(self): 35 | return 'Vocabulary[size=%d]' % len(self) 36 | 37 | def id2word(self, wid): 38 | return self.id2word[wid] 39 | 40 | def add(self, word): 41 | if word not in self: 42 | wid = self.word2id[word] = len(self) 43 | self.id2word[wid] = word 44 | return wid 45 | else: 46 | return self[word] 47 | 48 | @staticmethod 49 | def from_corpus(corpus, size, remove_singleton=True): 50 | vocab_entry = VocabEntry() 51 | 52 | word_freq = Counter(chain(*corpus)) 53 | non_singletons = [w for w in word_freq if word_freq[w] > 1] 54 | print('number of word types: %d, number of word types w/ frequency > 1: %d' % (len(word_freq), 55 | len(non_singletons))) 56 | 57 | top_k_words = sorted(word_freq.keys(), reverse=True, key=word_freq.get)[:size] 58 | 59 | for word in top_k_words: 60 | if len(vocab_entry) < size: 61 | if not (word_freq[word] == 1 and remove_singleton): 62 | vocab_entry.add(word) 63 | 64 | return vocab_entry 65 | 66 | 67 | class Vocab(object): 68 | def __init__(self, src_sents, tgt_sents, src_vocab_size, tgt_vocab_size, remove_singleton=True): 69 | assert len(src_sents) == len(tgt_sents) 70 | 71 | print('initialize source vocabulary ..') 72 | self.src = VocabEntry.from_corpus(src_sents, src_vocab_size, remove_singleton=remove_singleton) 73 | 74 | print('initialize target vocabulary ..') 75 | self.tgt = VocabEntry.from_corpus(tgt_sents, tgt_vocab_size, remove_singleton=remove_singleton) 76 | 77 | def __repr__(self): 78 | return 'Vocab(source %d words, target %d words)' % (len(self.src), len(self.tgt)) 79 | 80 | 81 | if __name__ == '__main__': 82 | parser = argparse.ArgumentParser() 83 | parser.add_argument('--src_vocab_size', default=50000, type=int, help='source vocabulary size') 84 | parser.add_argument('--tgt_vocab_size', default=50000, type=int, help='target vocabulary size') 85 | parser.add_argument('--include_singleton', action='store_true', default=False, help='whether to include singleton' 86 | 'in the vocabulary (default=False)') 87 | 88 | parser.add_argument('--train_src', type=str, required=True, help='file of source sentences') 89 | parser.add_argument('--train_tgt', type=str, required=True, help='file of target 
sentences') 90 | 91 | parser.add_argument('--output', default='vocab.bin', type=str, help='output vocabulary file') 92 | 93 | args = parser.parse_args() 94 | 95 | print('read in source sentences: %s' % args.train_src) 96 | print('read in target sentences: %s' % args.train_tgt) 97 | 98 | src_sents = read_corpus(args.train_src, source='src') 99 | tgt_sents = read_corpus(args.train_tgt, source='tgt') 100 | 101 | vocab = Vocab(src_sents, tgt_sents, args.src_vocab_size, args.tgt_vocab_size, remove_singleton=not args.include_singleton) 102 | print('generated vocabulary, source %d words, target %d words' % (len(vocab.src), len(vocab.tgt))) 103 | 104 | torch.save(vocab, args.output) 105 | print('vocabulary saved to %s' % args.output) -------------------------------------------------------------------------------- /nmt/utils/vocab_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utility to handle vocabularies.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import codecs 23 | import os 24 | import tensorflow as tf 25 | 26 | from tensorflow.python.ops import lookup_ops 27 | 28 | from ..utils import misc_utils as utils 29 | 30 | 31 | UNK = "" 32 | SOS = "" 33 | EOS = "" 34 | UNK_ID = 0 35 | 36 | 37 | def load_vocab(vocab_file): 38 | vocab = [] 39 | with codecs.getreader("utf-8")(tf.gfile.GFile(vocab_file, "rb")) as f: 40 | vocab_size = 0 41 | for word in f: 42 | vocab_size += 1 43 | vocab.append(word.strip()) 44 | return vocab, vocab_size 45 | 46 | 47 | def check_vocab(vocab_file, out_dir, check_special_token=True, sos=None, 48 | eos=None, unk=None): 49 | """Check if vocab_file doesn't exist, create from corpus_file.""" 50 | if tf.gfile.Exists(vocab_file): 51 | utils.print_out("# Vocab file %s exists" % vocab_file) 52 | vocab, vocab_size = load_vocab(vocab_file) 53 | if check_special_token: 54 | # Verify if the vocab starts with unk, sos, eos 55 | # If not, prepend those tokens & generate a new vocab file 56 | if not unk: unk = UNK 57 | if not sos: sos = SOS 58 | if not eos: eos = EOS 59 | assert len(vocab) >= 3 60 | if vocab[0] != unk or vocab[1] != sos or vocab[2] != eos: 61 | utils.print_out("The first 3 vocab words [%s, %s, %s]" 62 | " are not [%s, %s, %s]" % 63 | (vocab[0], vocab[1], vocab[2], unk, sos, eos)) 64 | vocab = [unk, sos, eos] + vocab 65 | vocab_size += 3 66 | new_vocab_file = os.path.join(out_dir, os.path.basename(vocab_file)) 67 | with codecs.getwriter("utf-8")( 68 | tf.gfile.GFile(new_vocab_file, "wb")) as f: 69 | for word in vocab: 70 | f.write("%s\n" % word) 71 | vocab_file = new_vocab_file 72 | else: 73 | raise ValueError("vocab_file '%s' does not exist." 
% vocab_file)
74 | 
75 |   vocab_size = len(vocab)
76 |   return vocab_size, vocab_file
77 | 
78 | 
79 | def create_vocab_tables(src_vocab_file, tgt_vocab_file, share_vocab):
80 |   """Creates vocab tables for src_vocab_file and tgt_vocab_file."""
81 |   src_vocab_table = lookup_ops.index_table_from_file(
82 |       src_vocab_file, default_value=UNK_ID)
83 |   if share_vocab:
84 |     tgt_vocab_table = src_vocab_table
85 |   else:
86 |     tgt_vocab_table = lookup_ops.index_table_from_file(
87 |         tgt_vocab_file, default_value=UNK_ID)
88 |   return src_vocab_table, tgt_vocab_table
89 | 
90 | 
91 | def load_embed_txt(embed_file):
92 |   """Load embed_file into a python dictionary.
93 | 
94 |   Note: the embed_file should be a GloVe-formatted txt file. Assuming
95 |   embed_size=5, for example:
96 | 
97 |   the -0.071549 0.093459 0.023738 -0.090339 0.056123
98 |   to 0.57346 0.5417 -0.23477 -0.3624 0.4037
99 |   and 0.20327 0.47348 0.050877 0.002103 0.060547
100 | 
101 |   Args:
102 |     embed_file: file path to the embedding file.
103 |   Returns:
104 |     a dictionary that maps word to vector, and the size of embedding dimensions.
105 |   """
106 |   emb_dict = dict()
107 |   emb_size = None
108 |   with codecs.getreader("utf-8")(tf.gfile.GFile(embed_file, 'rb')) as f:
109 |     for line in f:
110 |       tokens = line.strip().split(" ")
111 |       word = tokens[0]
112 |       vec = list(map(float, tokens[1:]))
113 |       emb_dict[word] = vec
114 |       if emb_size:
115 |         assert emb_size == len(vec), "All embedding sizes should be the same."
116 |       else:
117 |         emb_size = len(vec)
118 |   return emb_dict, emb_size
119 | 
-------------------------------------------------------------------------------- /nmt/scripts/bleu.py: --------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """Python implementation of BLEU and smooth-BLEU.
17 | 
18 | This module provides a Python implementation of BLEU and smooth-BLEU.
19 | Smooth BLEU is computed following the method outlined in the paper:
20 | Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic
21 | evaluation metrics for machine translation. COLING 2004.
22 | """
23 | 
24 | import collections
25 | import math
26 | 
27 | 
28 | def _get_ngrams(segment, max_order):
29 |   """Extracts all n-grams up to a given maximum order from an input segment.
30 | 
31 |   Args:
32 |     segment: text segment from which n-grams will be extracted.
33 |     max_order: maximum length in tokens of the n-grams returned by this
34 |         method.
35 | 
36 |   Returns:
37 |     The Counter containing all n-grams up to max_order in segment
38 |     with a count of how many times each n-gram occurred.
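    Example (illustrative): for segment = "to be or not to be".split() and
    max_order = 2, the counter maps ('to',) -> 2, ('be',) -> 2 and
    ('to', 'be') -> 2, among others.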
39 | """ 40 | ngram_counts = collections.Counter() 41 | for order in range(1, max_order + 1): 42 | for i in range(0, len(segment) - order + 1): 43 | ngram = tuple(segment[i:i+order]) 44 | ngram_counts[ngram] += 1 45 | return ngram_counts 46 | 47 | 48 | def compute_bleu(reference_corpus, translation_corpus, max_order=4, 49 | smooth=False): 50 | """Computes BLEU score of translated segments against one or more references. 51 | 52 | Args: 53 | reference_corpus: list of lists of references for each translation. Each 54 | reference should be tokenized into a list of tokens. 55 | translation_corpus: list of translations to score. Each translation 56 | should be tokenized into a list of tokens. 57 | max_order: Maximum n-gram order to use when computing BLEU score. 58 | smooth: Whether or not to apply Lin et al. 2004 smoothing. 59 | 60 | Returns: 61 | 3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram 62 | precisions and brevity penalty. 63 | """ 64 | matches_by_order = [0] * max_order 65 | possible_matches_by_order = [0] * max_order 66 | reference_length = 0 67 | translation_length = 0 68 | for (references, translation) in zip(reference_corpus, 69 | translation_corpus): 70 | reference_length += min(len(r) for r in references) 71 | translation_length += len(translation) 72 | 73 | merged_ref_ngram_counts = collections.Counter() 74 | for reference in references: 75 | merged_ref_ngram_counts |= _get_ngrams(reference, max_order) 76 | translation_ngram_counts = _get_ngrams(translation, max_order) 77 | overlap = translation_ngram_counts & merged_ref_ngram_counts 78 | for ngram in overlap: 79 | matches_by_order[len(ngram)-1] += overlap[ngram] 80 | for order in range(1, max_order+1): 81 | possible_matches = len(translation) - order + 1 82 | if possible_matches > 0: 83 | possible_matches_by_order[order-1] += possible_matches 84 | 85 | precisions = [0] * max_order 86 | for i in range(0, max_order): 87 | if smooth: 88 | precisions[i] = ((matches_by_order[i] + 1.) / 89 | (possible_matches_by_order[i] + 1.)) 90 | else: 91 | if possible_matches_by_order[i] > 0: 92 | precisions[i] = (float(matches_by_order[i]) / 93 | possible_matches_by_order[i]) 94 | else: 95 | precisions[i] = 0.0 96 | 97 | if min(precisions) > 0: 98 | p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions) 99 | geo_mean = math.exp(p_log_sum) 100 | else: 101 | geo_mean = 0 102 | 103 | ratio = float(translation_length) / reference_length 104 | 105 | if ratio > 1.0: 106 | bp = 1. 107 | else: 108 | bp = math.exp(1 - 1. / ratio) 109 | 110 | bleu = geo_mean * bp 111 | 112 | return (bleu, precisions, bp, ratio, translation_length, reference_length) 113 | -------------------------------------------------------------------------------- /nmt/utils/common_test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Common utility functions for tests.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf 23 | 24 | from tensorflow.python.ops import lookup_ops 25 | 26 | from ..utils import iterator_utils 27 | from ..utils import standard_hparams_utils 28 | 29 | 30 | def create_test_hparams(unit_type="lstm", 31 | encoder_type="uni", 32 | num_layers=4, 33 | attention="", 34 | attention_architecture=None, 35 | use_residual=False, 36 | inference_indices=None, 37 | num_translations_per_input=1, 38 | beam_width=0, 39 | init_op="uniform"): 40 | """Create training and inference test hparams.""" 41 | num_residual_layers = 0 42 | if use_residual: 43 | # TODO(rzhao): Put num_residual_layers computation logic into 44 | # `model_utils.py`, so we can also test it here. 45 | num_residual_layers = 2 46 | 47 | standard_hparams = standard_hparams_utils.create_standard_hparams() 48 | 49 | # Networks 50 | standard_hparams.num_units = 5 51 | standard_hparams.num_encoder_layers = num_layers 52 | standard_hparams.num_decoder_layers = num_layers 53 | standard_hparams.dropout = 0.5 54 | standard_hparams.unit_type = unit_type 55 | standard_hparams.encoder_type = encoder_type 56 | standard_hparams.residual = use_residual 57 | standard_hparams.num_residual_layers = num_residual_layers 58 | 59 | # Attention mechanisms 60 | standard_hparams.attention = attention 61 | standard_hparams.attention_architecture = attention_architecture 62 | 63 | # Train 64 | standard_hparams.init_op = init_op 65 | standard_hparams.num_train_steps = 1 66 | standard_hparams.decay_scheme = "" 67 | 68 | # Infer 69 | standard_hparams.tgt_max_len_infer = 100 70 | standard_hparams.beam_width = beam_width 71 | standard_hparams.num_translations_per_input = num_translations_per_input 72 | 73 | # Misc 74 | standard_hparams.forget_bias = 0.0 75 | standard_hparams.random_seed = 3 76 | 77 | # Vocab 78 | standard_hparams.src_vocab_size = 5 79 | standard_hparams.tgt_vocab_size = 5 80 | standard_hparams.eos = "eos" 81 | standard_hparams.sos = "sos" 82 | standard_hparams.src_vocab_file = "" 83 | standard_hparams.tgt_vocab_file = "" 84 | standard_hparams.src_embed_file = "" 85 | standard_hparams.tgt_embed_file = "" 86 | 87 | # For inference.py test 88 | standard_hparams.subword_option = "bpe" 89 | standard_hparams.src = "src" 90 | standard_hparams.tgt = "tgt" 91 | standard_hparams.src_max_len = 400 92 | standard_hparams.tgt_eos_id = 0 93 | standard_hparams.inference_indices = inference_indices 94 | return standard_hparams 95 | 96 | 97 | def create_test_iterator(hparams, mode): 98 | """Create test iterator.""" 99 | src_vocab_table = lookup_ops.index_table_from_tensor( 100 | tf.constant([hparams.eos, "a", "b", "c", "d"])) 101 | tgt_vocab_mapping = tf.constant([hparams.sos, hparams.eos, "a", "b", "c"]) 102 | tgt_vocab_table = lookup_ops.index_table_from_tensor(tgt_vocab_mapping) 103 | if mode == tf.contrib.learn.ModeKeys.INFER: 104 | reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_tensor( 105 | tgt_vocab_mapping) 106 | 107 | src_dataset = tf.data.Dataset.from_tensor_slices( 108 | tf.constant(["a a b b c", "a b b"])) 109 | 110 | if mode != tf.contrib.learn.ModeKeys.INFER: 111 | tgt_dataset = tf.data.Dataset.from_tensor_slices( 112 | tf.constant(["a b c b c", "a b c b"])) 113 | return ( 114 | iterator_utils.get_iterator( 115 | src_dataset=src_dataset, 116 
| tgt_dataset=tgt_dataset, 117 | src_vocab_table=src_vocab_table, 118 | tgt_vocab_table=tgt_vocab_table, 119 | batch_size=hparams.batch_size, 120 | sos=hparams.sos, 121 | eos=hparams.eos, 122 | random_seed=hparams.random_seed, 123 | num_buckets=hparams.num_buckets), 124 | src_vocab_table, 125 | tgt_vocab_table) 126 | else: 127 | return ( 128 | iterator_utils.get_infer_iterator( 129 | src_dataset=src_dataset, 130 | src_vocab_table=src_vocab_table, 131 | eos=hparams.eos, 132 | batch_size=hparams.batch_size), 133 | src_vocab_table, 134 | tgt_vocab_table, 135 | reverse_tgt_vocab_table) 136 | -------------------------------------------------------------------------------- /nmt/utils/misc_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Generally useful utility functions.""" 17 | from __future__ import print_function 18 | 19 | import codecs 20 | import collections 21 | import json 22 | import math 23 | import os 24 | import sys 25 | import time 26 | 27 | import numpy as np 28 | import tensorflow as tf 29 | 30 | 31 | def check_tensorflow_version(): 32 | min_tf_version = "1.4.0-dev20171024" 33 | if tf.__version__ < min_tf_version: 34 | raise EnvironmentError("Tensorflow version must >= %s" % min_tf_version) 35 | 36 | 37 | def safe_exp(value): 38 | """Exponentiation with catching of overflow error.""" 39 | try: 40 | ans = math.exp(value) 41 | except OverflowError: 42 | ans = float("inf") 43 | return ans 44 | 45 | 46 | def print_time(s, start_time): 47 | """Take a start time, print elapsed duration, and return a new time.""" 48 | print("%s, time %ds, %s." 
% (s, (time.time() - start_time), time.ctime())) 49 | sys.stdout.flush() 50 | return time.time() 51 | 52 | 53 | def print_out(s, f=None, new_line=True): 54 | """Similar to print but with support to flush and output to a file.""" 55 | if isinstance(s, bytes): 56 | s = s.decode("utf-8") 57 | 58 | if f: 59 | f.write(s.encode("utf-8")) 60 | if new_line: 61 | f.write(b"\n") 62 | 63 | # stdout 64 | out_s = s.encode("utf-8") 65 | if not isinstance(out_s, str): 66 | out_s = out_s.decode("utf-8") 67 | print(out_s, end="", file=sys.stdout) 68 | 69 | if new_line: 70 | sys.stdout.write("\n") 71 | sys.stdout.flush() 72 | 73 | 74 | def print_hparams(hparams, skip_patterns=None, header=None): 75 | """Print hparams, can skip keys based on pattern.""" 76 | if header: print_out("%s" % header) 77 | values = hparams.values() 78 | for key in sorted(values.keys()): 79 | if not skip_patterns or all( 80 | [skip_pattern not in key for skip_pattern in skip_patterns]): 81 | print_out(" %s=%s" % (key, str(values[key]))) 82 | 83 | 84 | def load_hparams(model_dir): 85 | """Load hparams from an existing model directory.""" 86 | hparams_file = os.path.join(model_dir, "hparams") 87 | if tf.gfile.Exists(hparams_file): 88 | print_out("# Loading hparams from %s" % hparams_file) 89 | with codecs.getreader("utf-8")(tf.gfile.GFile(hparams_file, "rb")) as f: 90 | try: 91 | hparams_values = json.load(f) 92 | hparams = tf.contrib.training.HParams(**hparams_values) 93 | except ValueError: 94 | print_out(" can't load hparams file") 95 | return None 96 | return hparams 97 | else: 98 | return None 99 | 100 | 101 | def maybe_parse_standard_hparams(hparams, hparams_path): 102 | """Override hparams values with existing standard hparams config.""" 103 | if not hparams_path: 104 | return hparams 105 | 106 | if tf.gfile.Exists(hparams_path): 107 | print_out("# Loading standard hparams from %s" % hparams_path) 108 | with tf.gfile.GFile(hparams_path, "r") as f: 109 | hparams.parse_json(f.read()) 110 | 111 | return hparams 112 | 113 | 114 | def save_hparams(out_dir, hparams): 115 | """Save hparams.""" 116 | hparams_file = os.path.join(out_dir, "hparams") 117 | print_out(" saving hparams to %s" % hparams_file) 118 | with codecs.getwriter("utf-8")(tf.gfile.GFile(hparams_file, "wb")) as f: 119 | f.write(hparams.to_json()) 120 | 121 | 122 | def debug_tensor(s, msg=None, summarize=10): 123 | """Print the shape and value of a tensor at test time. Return a new tensor.""" 124 | if not msg: 125 | msg = s.name 126 | return tf.Print(s, [tf.shape(s), s], msg + " ", summarize=summarize) 127 | 128 | 129 | def add_summary(summary_writer, global_step, tag, value): 130 | """Add a new summary to the current summary_writer. 131 | Useful to log things that are not part of the training graph, e.g., tag=BLEU. 
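  Example (illustrative): add_summary(summary_writer, 1000, "BLEU", 25.5)
  writes a scalar summary with tag "BLEU" at global step 1000.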
132 | """ 133 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) 134 | summary_writer.add_summary(summary, global_step) 135 | 136 | 137 | def get_config_proto(log_device_placement=False, allow_soft_placement=True, 138 | num_intra_threads=0, num_inter_threads=0): 139 | # GPU options: 140 | # https://www.tensorflow.org/versions/r0.10/how_tos/using_gpu/index.html 141 | config_proto = tf.ConfigProto( 142 | log_device_placement=log_device_placement, 143 | allow_soft_placement=allow_soft_placement) 144 | config_proto.gpu_options.allow_growth = True 145 | 146 | # CPU threads options 147 | if num_intra_threads: 148 | config_proto.intra_op_parallelism_threads = num_intra_threads 149 | if num_inter_threads: 150 | config_proto.inter_op_parallelism_threads = num_inter_threads 151 | 152 | return config_proto 153 | 154 | 155 | def format_text(words): 156 | """Convert a sequence words into sentence.""" 157 | if (not hasattr(words, "__len__") and # for numpy array 158 | not isinstance(words, collections.Iterable)): 159 | words = [words] 160 | return b" ".join(words) 161 | 162 | 163 | def format_bpe_text(symbols, delimiter=b"@@"): 164 | """Convert a sequence of bpe words into sentence.""" 165 | words = [] 166 | word = b"" 167 | if isinstance(symbols, str): 168 | symbols = symbols.encode() 169 | delimiter_len = len(delimiter) 170 | for symbol in symbols: 171 | if len(symbol) >= delimiter_len and symbol[-delimiter_len:] == delimiter: 172 | word += symbol[:-delimiter_len] 173 | else: # end of a word 174 | word += symbol 175 | words.append(word) 176 | word = b"" 177 | return b" ".join(words) 178 | 179 | 180 | def format_spm_text(symbols): 181 | """Decode a text in SPM (https://github.com/google/sentencepiece) format.""" 182 | return u"".join(format_text(symbols).decode("utf-8").split()).replace( 183 | u"\u2581", u" ").strip().encode("utf-8") 184 | -------------------------------------------------------------------------------- /nmt/utils/evaluation_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Utility for evaluating various tasks, e.g., translation & summarization.""" 17 | import codecs 18 | import os 19 | import re 20 | import subprocess 21 | 22 | import tensorflow as tf 23 | 24 | from ..scripts import bleu 25 | from ..scripts import rouge 26 | 27 | 28 | __all__ = ["evaluate"] 29 | 30 | 31 | def evaluate(ref_file, trans_file, metric, subword_option=None): 32 | """Pick a metric and evaluate depending on task.""" 33 | # BLEU scores for translation task 34 | if metric.lower() == "bleu": 35 | evaluation_score = _bleu(ref_file, trans_file, 36 | subword_option=subword_option) 37 | # ROUGE scores for summarization tasks 38 | elif metric.lower() == "rouge": 39 | evaluation_score = _rouge(ref_file, trans_file, 40 | subword_option=subword_option) 41 | elif metric.lower() == "accuracy": 42 | evaluation_score = _accuracy(ref_file, trans_file) 43 | elif metric.lower() == "word_accuracy": 44 | evaluation_score = _word_accuracy(ref_file, trans_file) 45 | else: 46 | raise ValueError("Unknown metric %s" % metric) 47 | 48 | return evaluation_score 49 | 50 | 51 | def _clean(sentence, subword_option): 52 | """Clean and handle BPE or SPM outputs.""" 53 | sentence = sentence.strip() 54 | 55 | # BPE 56 | if subword_option == "bpe": 57 | sentence = re.sub("@@ ", "", sentence) 58 | 59 | # SPM 60 | elif subword_option == "spm": 61 | sentence = u"".join(sentence.split()).replace(u"\u2581", u" ").lstrip() 62 | 63 | return sentence 64 | 65 | 66 | # Follow //transconsole/localization/machine_translation/metrics/bleu_calc.py 67 | def _bleu(ref_file, trans_file, subword_option=None): 68 | """Compute BLEU scores and handling BPE.""" 69 | max_order = 4 70 | smooth = False 71 | 72 | ref_files = [ref_file] 73 | reference_text = [] 74 | for reference_filename in ref_files: 75 | with codecs.getreader("utf-8")( 76 | tf.gfile.GFile(reference_filename, "rb")) as fh: 77 | reference_text.append(fh.readlines()) 78 | 79 | per_segment_references = [] 80 | for references in zip(*reference_text): 81 | reference_list = [] 82 | for reference in references: 83 | reference = _clean(reference, subword_option) 84 | reference_list.append(reference.split(" ")) 85 | per_segment_references.append(reference_list) 86 | 87 | translations = [] 88 | with codecs.getreader("utf-8")(tf.gfile.GFile(trans_file, "rb")) as fh: 89 | for line in fh: 90 | line = _clean(line, subword_option=None) 91 | translations.append(line.split(" ")) 92 | 93 | # bleu_score, precisions, bp, ratio, translation_length, reference_length 94 | bleu_score, _, _, _, _, _ = bleu.compute_bleu( 95 | per_segment_references, translations, max_order, smooth) 96 | return 100 * bleu_score 97 | 98 | 99 | def _rouge(ref_file, summarization_file, subword_option=None): 100 | """Compute ROUGE scores and handling BPE.""" 101 | 102 | references = [] 103 | with codecs.getreader("utf-8")(tf.gfile.GFile(ref_file, "rb")) as fh: 104 | for line in fh: 105 | references.append(_clean(line, subword_option)) 106 | 107 | hypotheses = [] 108 | with codecs.getreader("utf-8")( 109 | tf.gfile.GFile(summarization_file, "rb")) as fh: 110 | for line in fh: 111 | hypotheses.append(_clean(line, subword_option=None)) 112 | 113 | rouge_score_map = rouge.rouge(hypotheses, references) 114 | return 100 * rouge_score_map["rouge_l/f_score"] 115 | 116 | 117 | def _accuracy(label_file, pred_file): 118 | """Compute accuracy, each line contains a label.""" 119 | 120 | with 
codecs.getreader("utf-8")(tf.gfile.GFile(label_file, "rb")) as label_fh: 121 | with codecs.getreader("utf-8")(tf.gfile.GFile(pred_file, "rb")) as pred_fh: 122 | count = 0.0 123 | match = 0.0 124 | for label in label_fh: 125 | label = label.strip() 126 | pred = pred_fh.readline().strip() 127 | if label == pred: 128 | match += 1 129 | count += 1 130 | return 100 * match / count 131 | 132 | 133 | def _word_accuracy(label_file, pred_file): 134 | """Compute accuracy on per word basis.""" 135 | 136 | with codecs.getreader("utf-8")(tf.gfile.GFile(label_file, "r")) as label_fh: 137 | with codecs.getreader("utf-8")(tf.gfile.GFile(pred_file, "r")) as pred_fh: 138 | total_acc, total_count = 0., 0. 139 | for sentence in label_fh: 140 | labels = sentence.strip().split(" ") 141 | preds = pred_fh.readline().strip().split(" ") 142 | match = 0.0 143 | for pos in range(min(len(labels), len(preds))): 144 | label = labels[pos] 145 | pred = preds[pos] 146 | if label == pred: 147 | match += 1 148 | total_acc += 100 * match / max(len(labels), len(preds)) 149 | total_count += 1 150 | return total_acc / total_count 151 | 152 | 153 | def _moses_bleu(multi_bleu_script, tgt_test, trans_file, subword_option=None): 154 | """Compute BLEU scores using Moses multi-bleu.perl script.""" 155 | 156 | # TODO(thangluong): perform rewrite using python 157 | # BPE 158 | if subword_option == "bpe": 159 | debpe_tgt_test = tgt_test + ".debpe" 160 | if not os.path.exists(debpe_tgt_test): 161 | # TODO(thangluong): not use shell=True, can be a security hazard 162 | subprocess.call("cp %s %s" % (tgt_test, debpe_tgt_test), shell=True) 163 | subprocess.call("sed s/@@ //g %s" % (debpe_tgt_test), 164 | shell=True) 165 | tgt_test = debpe_tgt_test 166 | elif subword_option == "spm": 167 | despm_tgt_test = tgt_test + ".despm" 168 | if not os.path.exists(despm_tgt_test): 169 | subprocess.call("cp %s %s" % (tgt_test, despm_tgt_test)) 170 | subprocess.call("sed s/ //g %s" % (despm_tgt_test)) 171 | subprocess.call(u"sed s/^\u2581/g %s" % (despm_tgt_test)) 172 | subprocess.call(u"sed s/\u2581/ /g %s" % (despm_tgt_test)) 173 | tgt_test = despm_tgt_test 174 | cmd = "%s %s < %s" % (multi_bleu_script, tgt_test, trans_file) 175 | 176 | # subprocess 177 | # TODO(thangluong): not use shell=True, can be a security hazard 178 | bleu_output = subprocess.check_output(cmd, shell=True) 179 | 180 | # extract BLEU score 181 | m = re.search("BLEU = (.+?),", bleu_output) 182 | bleu_score = float(m.group(1)) 183 | 184 | return bleu_score 185 | -------------------------------------------------------------------------------- /nmt/scripts/wmt16_en_de.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright 2017 Google Inc. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | set -e 18 | 19 | BASE_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )/.." && pwd )" 20 | 21 | OUTPUT_DIR="${1:-wmt16_de_en}" 22 | echo "Writing to ${OUTPUT_DIR}. 
To change this, set the OUTPUT_DIR environment variable." 23 | 24 | OUTPUT_DIR_DATA="${OUTPUT_DIR}/data" 25 | mkdir -p $OUTPUT_DIR_DATA 26 | 27 | echo "Downloading Europarl v7. This may take a while..." 28 | curl -o ${OUTPUT_DIR_DATA}/europarl-v7-de-en.tgz \ 29 | http://www.statmt.org/europarl/v7/de-en.tgz 30 | 31 | echo "Downloading Common Crawl corpus. This may take a while..." 32 | curl -o ${OUTPUT_DIR_DATA}/common-crawl.tgz \ 33 | http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz 34 | 35 | echo "Downloading News Commentary v11. This may take a while..." 36 | curl -o ${OUTPUT_DIR_DATA}/nc-v11.tgz \ 37 | http://data.statmt.org/wmt16/translation-task/training-parallel-nc-v11.tgz 38 | 39 | echo "Downloading dev/test sets" 40 | curl -o ${OUTPUT_DIR_DATA}/dev.tgz \ 41 | http://data.statmt.org/wmt16/translation-task/dev.tgz 42 | curl -o ${OUTPUT_DIR_DATA}/test.tgz \ 43 | http://data.statmt.org/wmt16/translation-task/test.tgz 44 | 45 | # Extract everything 46 | echo "Extracting all files..." 47 | mkdir -p "${OUTPUT_DIR_DATA}/europarl-v7-de-en" 48 | tar -xvzf "${OUTPUT_DIR_DATA}/europarl-v7-de-en.tgz" -C "${OUTPUT_DIR_DATA}/europarl-v7-de-en" 49 | mkdir -p "${OUTPUT_DIR_DATA}/common-crawl" 50 | tar -xvzf "${OUTPUT_DIR_DATA}/common-crawl.tgz" -C "${OUTPUT_DIR_DATA}/common-crawl" 51 | mkdir -p "${OUTPUT_DIR_DATA}/nc-v11" 52 | tar -xvzf "${OUTPUT_DIR_DATA}/nc-v11.tgz" -C "${OUTPUT_DIR_DATA}/nc-v11" 53 | mkdir -p "${OUTPUT_DIR_DATA}/dev" 54 | tar -xvzf "${OUTPUT_DIR_DATA}/dev.tgz" -C "${OUTPUT_DIR_DATA}/dev" 55 | mkdir -p "${OUTPUT_DIR_DATA}/test" 56 | tar -xvzf "${OUTPUT_DIR_DATA}/test.tgz" -C "${OUTPUT_DIR_DATA}/test" 57 | 58 | # Concatenate Training data 59 | cat "${OUTPUT_DIR_DATA}/europarl-v7-de-en/europarl-v7.de-en.en" \ 60 | "${OUTPUT_DIR_DATA}/common-crawl/commoncrawl.de-en.en" \ 61 | "${OUTPUT_DIR_DATA}/nc-v11/training-parallel-nc-v11/news-commentary-v11.de-en.en" \ 62 | > "${OUTPUT_DIR}/train.en" 63 | wc -l "${OUTPUT_DIR}/train.en" 64 | 65 | cat "${OUTPUT_DIR_DATA}/europarl-v7-de-en/europarl-v7.de-en.de" \ 66 | "${OUTPUT_DIR_DATA}/common-crawl/commoncrawl.de-en.de" \ 67 | "${OUTPUT_DIR_DATA}/nc-v11/training-parallel-nc-v11/news-commentary-v11.de-en.de" \ 68 | > "${OUTPUT_DIR}/train.de" 69 | wc -l "${OUTPUT_DIR}/train.de" 70 | 71 | # Clone Moses 72 | if [ ! 
-d "${OUTPUT_DIR}/mosesdecoder" ]; then 73 | echo "Cloning moses for data processing" 74 | git clone https://github.com/moses-smt/mosesdecoder.git "${OUTPUT_DIR}/mosesdecoder" 75 | fi 76 | 77 | # Convert SGM files 78 | # Convert newstest2014 data into raw text format 79 | ${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \ 80 | < ${OUTPUT_DIR_DATA}/dev/dev/newstest2014-deen-src.de.sgm \ 81 | > ${OUTPUT_DIR_DATA}/dev/dev/newstest2014.de 82 | ${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \ 83 | < ${OUTPUT_DIR_DATA}/dev/dev/newstest2014-deen-ref.en.sgm \ 84 | > ${OUTPUT_DIR_DATA}/dev/dev/newstest2014.en 85 | 86 | # Convert newstest2015 data into raw text format 87 | ${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \ 88 | < ${OUTPUT_DIR_DATA}/dev/dev/newstest2015-deen-src.de.sgm \ 89 | > ${OUTPUT_DIR_DATA}/dev/dev/newstest2015.de 90 | ${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \ 91 | < ${OUTPUT_DIR_DATA}/dev/dev/newstest2015-deen-ref.en.sgm \ 92 | > ${OUTPUT_DIR_DATA}/dev/dev/newstest2015.en 93 | 94 | # Convert newstest2016 data into raw text format 95 | ${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \ 96 | < ${OUTPUT_DIR_DATA}/test/test/newstest2016-deen-src.de.sgm \ 97 | > ${OUTPUT_DIR_DATA}/test/test/newstest2016.de 98 | ${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \ 99 | < ${OUTPUT_DIR_DATA}/test/test/newstest2016-deen-ref.en.sgm \ 100 | > ${OUTPUT_DIR_DATA}/test/test/newstest2016.en 101 | 102 | # Copy dev/test data to output dir 103 | cp ${OUTPUT_DIR_DATA}/dev/dev/newstest20*.de ${OUTPUT_DIR} 104 | cp ${OUTPUT_DIR_DATA}/dev/dev/newstest20*.en ${OUTPUT_DIR} 105 | cp ${OUTPUT_DIR_DATA}/test/test/newstest20*.de ${OUTPUT_DIR} 106 | cp ${OUTPUT_DIR_DATA}/test/test/newstest20*.en ${OUTPUT_DIR} 107 | 108 | # Tokenize data 109 | for f in ${OUTPUT_DIR}/*.de; do 110 | echo "Tokenizing $f..." 111 | ${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/tokenizer.perl -q -l de -threads 8 < $f > ${f%.*}.tok.de 112 | done 113 | 114 | for f in ${OUTPUT_DIR}/*.en; do 115 | echo "Tokenizing $f..." 116 | ${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/tokenizer.perl -q -l en -threads 8 < $f > ${f%.*}.tok.en 117 | done 118 | 119 | # Clean train corpora 120 | for f in ${OUTPUT_DIR}/train.tok.en; do 121 | fbase=${f%.*} 122 | echo "Cleaning ${fbase}..." 123 | ${OUTPUT_DIR}/mosesdecoder/scripts/training/clean-corpus-n.perl $fbase de en "${fbase}.clean" 1 80 124 | done 125 | 126 | # Generate Subword Units (BPE) 127 | # Clone Subword NMT 128 | if [ ! -d "${OUTPUT_DIR}/subword-nmt" ]; then 129 | git clone https://github.com/rsennrich/subword-nmt.git "${OUTPUT_DIR}/subword-nmt" 130 | fi 131 | 132 | # Learn Shared BPE 133 | for merge_ops in 32000; do 134 | echo "Learning BPE with merge_ops=${merge_ops}. This may take a while..." 135 | cat "${OUTPUT_DIR}/train.tok.clean.de" "${OUTPUT_DIR}/train.tok.clean.en" | \ 136 | ${OUTPUT_DIR}/subword-nmt/learn_bpe.py -s $merge_ops > "${OUTPUT_DIR}/bpe.${merge_ops}" 137 | 138 | echo "Apply BPE with merge_ops=${merge_ops} to tokenized files..." 
139 | for lang in en de; do 140 | for f in ${OUTPUT_DIR}/*.tok.${lang} ${OUTPUT_DIR}/*.tok.clean.${lang}; do 141 | outfile="${f%.*}.bpe.${merge_ops}.${lang}" 142 | ${OUTPUT_DIR}/subword-nmt/apply_bpe.py -c "${OUTPUT_DIR}/bpe.${merge_ops}" < $f > "${outfile}" 143 | echo ${outfile} 144 | done 145 | done 146 | 147 | # Create vocabulary file for BPE 148 | echo -e "\n\n" > "${OUTPUT_DIR}/vocab.bpe.${merge_ops}" 149 | cat "${OUTPUT_DIR}/train.tok.clean.bpe.${merge_ops}.en" "${OUTPUT_DIR}/train.tok.clean.bpe.${merge_ops}.de" | \ 150 | ${OUTPUT_DIR}/subword-nmt/get_vocab.py | cut -f1 -d ' ' >> "${OUTPUT_DIR}/vocab.bpe.${merge_ops}" 151 | 152 | done 153 | 154 | # Duplicate vocab file with language suffix 155 | cp "${OUTPUT_DIR}/vocab.bpe.32000" "${OUTPUT_DIR}/vocab.bpe.32000.en" 156 | cp "${OUTPUT_DIR}/vocab.bpe.32000" "${OUTPUT_DIR}/vocab.bpe.32000.de" 157 | 158 | echo "All done." 159 | -------------------------------------------------------------------------------- /nmt/attention_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Attention-based sequence-to-sequence model with dynamic RNN support.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | 22 | from . import model 23 | from . import model_helper 24 | 25 | __all__ = ["AttentionModel"] 26 | 27 | 28 | class AttentionModel(model.Model): 29 | """Sequence-to-sequence dynamic model with attention. 30 | 31 | This class implements a multi-layer recurrent neural network as encoder, 32 | and an attention-based decoder. This is the same as the model described in 33 | (Luong et al., EMNLP'2015) paper: https://arxiv.org/pdf/1508.04025v5.pdf. 34 | This class also allows to use GRU cells in addition to LSTM cells with 35 | support for dropout. 
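  Valid hparams.attention values, handled by create_attention_mechanism()
  below, are "luong", "scaled_luong", "bahdanau" and "normed_bahdanau".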
36 | """ 37 | 38 | def __init__(self, 39 | hparams, 40 | mode, 41 | iterator, 42 | source_vocab_table, 43 | target_vocab_table, 44 | reverse_target_vocab_table=None, 45 | scope=None, 46 | extra_args=None): 47 | # Set attention_mechanism_fn 48 | if extra_args and extra_args.attention_mechanism_fn: 49 | self.attention_mechanism_fn = extra_args.attention_mechanism_fn 50 | else: 51 | self.attention_mechanism_fn = create_attention_mechanism 52 | 53 | super(AttentionModel, self).__init__( 54 | hparams=hparams, 55 | mode=mode, 56 | iterator=iterator, 57 | source_vocab_table=source_vocab_table, 58 | target_vocab_table=target_vocab_table, 59 | reverse_target_vocab_table=reverse_target_vocab_table, 60 | scope=scope, 61 | extra_args=extra_args) 62 | 63 | if self.mode == tf.contrib.learn.ModeKeys.INFER: 64 | self.infer_summary = self._get_infer_summary(hparams) 65 | 66 | def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, 67 | source_sequence_length): 68 | """Build a RNN cell with attention mechanism that can be used by decoder.""" 69 | attention_option = hparams.attention 70 | attention_architecture = hparams.attention_architecture 71 | 72 | if attention_architecture != "standard": 73 | raise ValueError( 74 | "Unknown attention architecture %s" % attention_architecture) 75 | 76 | num_units = hparams.num_units 77 | num_layers = self.num_decoder_layers 78 | num_residual_layers = self.num_decoder_residual_layers 79 | beam_width = hparams.beam_width 80 | 81 | dtype = tf.float32 82 | 83 | # Ensure memory is batch-major 84 | if self.time_major: 85 | memory = tf.transpose(encoder_outputs, [1, 0, 2]) 86 | else: 87 | memory = encoder_outputs 88 | 89 | if self.mode == tf.contrib.learn.ModeKeys.INFER and beam_width > 0: 90 | memory = tf.contrib.seq2seq.tile_batch( 91 | memory, multiplier=beam_width) 92 | source_sequence_length = tf.contrib.seq2seq.tile_batch( 93 | source_sequence_length, multiplier=beam_width) 94 | encoder_state = tf.contrib.seq2seq.tile_batch( 95 | encoder_state, multiplier=beam_width) 96 | batch_size = self.batch_size * beam_width 97 | else: 98 | batch_size = self.batch_size 99 | 100 | attention_mechanism = self.attention_mechanism_fn( 101 | attention_option, num_units, memory, source_sequence_length, self.mode) 102 | 103 | cell = model_helper.create_rnn_cell( 104 | unit_type=hparams.unit_type, 105 | num_units=num_units, 106 | num_layers=num_layers, 107 | num_residual_layers=num_residual_layers, 108 | forget_bias=hparams.forget_bias, 109 | dropout=hparams.dropout, 110 | num_gpus=self.num_gpus, 111 | mode=self.mode, 112 | single_cell_fn=self.single_cell_fn) 113 | 114 | # Only generate alignment in greedy INFER mode. 115 | alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER and 116 | beam_width == 0) 117 | cell = tf.contrib.seq2seq.AttentionWrapper( 118 | cell, 119 | attention_mechanism, 120 | attention_layer_size=num_units, 121 | alignment_history=alignment_history, 122 | output_attention=hparams.output_attention, 123 | name="attention") 124 | 125 | # TODO(thangluong): do we need num_layers, num_gpus? 
126 | cell = tf.contrib.rnn.DeviceWrapper(cell, 127 | model_helper.get_device_str( 128 | num_layers - 1, self.num_gpus)) 129 | 130 | if hparams.pass_hidden_state: 131 | decoder_initial_state = cell.zero_state(batch_size, dtype).clone( 132 | cell_state=encoder_state) 133 | else: 134 | decoder_initial_state = cell.zero_state(batch_size, dtype) 135 | 136 | return cell, decoder_initial_state 137 | 138 | def _get_infer_summary(self, hparams): 139 | if hparams.beam_width > 0: 140 | return tf.no_op() 141 | return _create_attention_images_summary(self.final_context_state) 142 | 143 | 144 | def create_attention_mechanism(attention_option, num_units, memory, 145 | source_sequence_length, mode): 146 | """Create attention mechanism based on the attention_option.""" 147 | del mode # unused 148 | 149 | # Mechanism 150 | if attention_option == "luong": 151 | attention_mechanism = tf.contrib.seq2seq.LuongAttention( 152 | num_units, memory, memory_sequence_length=source_sequence_length) 153 | elif attention_option == "scaled_luong": 154 | attention_mechanism = tf.contrib.seq2seq.LuongAttention( 155 | num_units, 156 | memory, 157 | memory_sequence_length=source_sequence_length, 158 | scale=True) 159 | elif attention_option == "bahdanau": 160 | attention_mechanism = tf.contrib.seq2seq.BahdanauAttention( 161 | num_units, memory, memory_sequence_length=source_sequence_length) 162 | elif attention_option == "normed_bahdanau": 163 | attention_mechanism = tf.contrib.seq2seq.BahdanauAttention( 164 | num_units, 165 | memory, 166 | memory_sequence_length=source_sequence_length, 167 | normalize=True) 168 | else: 169 | raise ValueError("Unknown attention option %s" % attention_option) 170 | 171 | return attention_mechanism 172 | 173 | 174 | def _create_attention_images_summary(final_context_state): 175 | """create attention image and attention summary.""" 176 | attention_images = (final_context_state.alignment_history.stack()) 177 | # Reshape to (batch, src_seq_len, tgt_seq_len,1) 178 | attention_images = tf.expand_dims( 179 | tf.transpose(attention_images, [1, 2, 0]), -1) 180 | # Scale to range [0, 255] 181 | attention_images *= 255 182 | attention_summary = tf.summary.image("attention_images", attention_images) 183 | return attention_summary 184 | -------------------------------------------------------------------------------- /nmt/inference_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for model inference.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | import numpy as np 24 | import tensorflow as tf 25 | 26 | from . import attention_model 27 | from . import model_helper 28 | from . import model as nmt_model 29 | from . 
import gnmt_model 30 | from . import inference 31 | from .utils import common_test_utils 32 | 33 | float32 = np.float32 34 | int32 = np.int32 35 | array = np.array 36 | 37 | 38 | class InferenceTest(tf.test.TestCase): 39 | 40 | def _createTestInferCheckpoint(self, hparams, out_dir): 41 | if not hparams.attention: 42 | model_creator = nmt_model.Model 43 | elif hparams.attention_architecture == "standard": 44 | model_creator = attention_model.AttentionModel 45 | elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]: 46 | model_creator = gnmt_model.GNMTModel 47 | else: 48 | raise ValueError("Unknown model architecture") 49 | 50 | infer_model = model_helper.create_infer_model(model_creator, hparams) 51 | with self.test_session(graph=infer_model.graph) as sess: 52 | loaded_model, global_step = model_helper.create_or_load_model( 53 | infer_model.model, out_dir, sess, "infer_name") 54 | ckpt = loaded_model.saver.save( 55 | sess, os.path.join(out_dir, "translate.ckpt"), 56 | global_step=global_step) 57 | return ckpt 58 | 59 | def testBasicModel(self): 60 | hparams = common_test_utils.create_test_hparams( 61 | encoder_type="uni", 62 | num_layers=1, 63 | attention="", 64 | attention_architecture="", 65 | use_residual=False,) 66 | vocab_prefix = "nmt/testdata/test_infer_vocab" 67 | hparams.src_vocab_file = vocab_prefix + "." + hparams.src 68 | hparams.tgt_vocab_file = vocab_prefix + "." + hparams.tgt 69 | 70 | infer_file = "nmt/testdata/test_infer_file" 71 | out_dir = os.path.join(tf.test.get_temp_dir(), "basic_infer") 72 | hparams.out_dir = out_dir 73 | os.makedirs(out_dir) 74 | output_infer = os.path.join(out_dir, "output_infer") 75 | ckpt = self._createTestInferCheckpoint(hparams, out_dir) 76 | inference.inference(ckpt, infer_file, output_infer, hparams) 77 | with open(output_infer) as f: 78 | self.assertEqual(5, len(list(f))) 79 | 80 | def testBasicModelWithMultipleTranslations(self): 81 | hparams = common_test_utils.create_test_hparams( 82 | encoder_type="uni", 83 | num_layers=1, 84 | attention="", 85 | attention_architecture="", 86 | use_residual=False, 87 | num_translations_per_input=2, 88 | beam_width=2, 89 | ) 90 | vocab_prefix = "nmt/testdata/test_infer_vocab" 91 | hparams.src_vocab_file = vocab_prefix + "." + hparams.src 92 | hparams.tgt_vocab_file = vocab_prefix + "." + hparams.tgt 93 | 94 | infer_file = "nmt/testdata/test_infer_file" 95 | out_dir = os.path.join(tf.test.get_temp_dir(), "multi_basic_infer") 96 | hparams.out_dir = out_dir 97 | os.makedirs(out_dir) 98 | output_infer = os.path.join(out_dir, "output_infer") 99 | ckpt = self._createTestInferCheckpoint(hparams, out_dir) 100 | inference.inference(ckpt, infer_file, output_infer, hparams) 101 | with open(output_infer) as f: 102 | self.assertEqual(10, len(list(f))) 103 | 104 | def testAttentionModel(self): 105 | hparams = common_test_utils.create_test_hparams( 106 | encoder_type="uni", 107 | num_layers=1, 108 | attention="scaled_luong", 109 | attention_architecture="standard", 110 | use_residual=False,) 111 | vocab_prefix = "nmt/testdata/test_infer_vocab" 112 | hparams.src_vocab_file = vocab_prefix + "." + hparams.src 113 | hparams.tgt_vocab_file = vocab_prefix + "." 
+ hparams.tgt 114 | 115 | infer_file = "nmt/testdata/test_infer_file" 116 | out_dir = os.path.join(tf.test.get_temp_dir(), "attention_infer") 117 | hparams.out_dir = out_dir 118 | os.makedirs(out_dir) 119 | output_infer = os.path.join(out_dir, "output_infer") 120 | ckpt = self._createTestInferCheckpoint(hparams, out_dir) 121 | inference.inference(ckpt, infer_file, output_infer, hparams) 122 | with open(output_infer) as f: 123 | self.assertEqual(5, len(list(f))) 124 | 125 | def testMultiWorkers(self): 126 | hparams = common_test_utils.create_test_hparams( 127 | encoder_type="uni", 128 | num_layers=2, 129 | attention="scaled_luong", 130 | attention_architecture="standard", 131 | use_residual=False,) 132 | vocab_prefix = "nmt/testdata/test_infer_vocab" 133 | hparams.src_vocab_file = vocab_prefix + "." + hparams.src 134 | hparams.tgt_vocab_file = vocab_prefix + "." + hparams.tgt 135 | 136 | infer_file = "nmt/testdata/test_infer_file" 137 | out_dir = os.path.join(tf.test.get_temp_dir(), "multi_worker_infer") 138 | hparams.out_dir = out_dir 139 | os.makedirs(out_dir) 140 | output_infer = os.path.join(out_dir, "output_infer") 141 | 142 | num_workers = 3 143 | 144 | # There are 5 examples, make batch_size=3 makes job0 has 3 examples, job1 145 | # has 2 examples, and job2 has 0 example. This helps testing some edge 146 | # cases. 147 | hparams.batch_size = 3 148 | 149 | ckpt = self._createTestInferCheckpoint(hparams, out_dir) 150 | inference.inference( 151 | ckpt, infer_file, output_infer, hparams, num_workers, jobid=1) 152 | 153 | inference.inference( 154 | ckpt, infer_file, output_infer, hparams, num_workers, jobid=2) 155 | 156 | # Note: Need to start job 0 at the end; otherwise, it will block the testing 157 | # thread. 158 | inference.inference( 159 | ckpt, infer_file, output_infer, hparams, num_workers, jobid=0) 160 | 161 | with open(output_infer) as f: 162 | self.assertEqual(5, len(list(f))) 163 | 164 | def testBasicModelWithInferIndices(self): 165 | hparams = common_test_utils.create_test_hparams( 166 | encoder_type="uni", 167 | num_layers=1, 168 | attention="", 169 | attention_architecture="", 170 | use_residual=False, 171 | inference_indices=[0]) 172 | vocab_prefix = "nmt/testdata/test_infer_vocab" 173 | hparams.src_vocab_file = vocab_prefix + "." + hparams.src 174 | hparams.tgt_vocab_file = vocab_prefix + "." + hparams.tgt 175 | 176 | infer_file = "nmt/testdata/test_infer_file" 177 | out_dir = os.path.join(tf.test.get_temp_dir(), "basic_infer_with_indices") 178 | hparams.out_dir = out_dir 179 | os.makedirs(out_dir) 180 | output_infer = os.path.join(out_dir, "output_infer") 181 | ckpt = self._createTestInferCheckpoint(hparams, out_dir) 182 | inference.inference(ckpt, infer_file, output_infer, hparams) 183 | with open(output_infer) as f: 184 | self.assertEqual(1, len(list(f))) 185 | 186 | def testAttentionModelWithInferIndices(self): 187 | hparams = common_test_utils.create_test_hparams( 188 | encoder_type="uni", 189 | num_layers=1, 190 | attention="scaled_luong", 191 | attention_architecture="standard", 192 | use_residual=False, 193 | inference_indices=[1, 2]) 194 | # TODO(rzhao): Make infer indices support batch_size > 1. 195 | hparams.infer_batch_size = 1 196 | vocab_prefix = "nmt/testdata/test_infer_vocab" 197 | hparams.src_vocab_file = vocab_prefix + "." + hparams.src 198 | hparams.tgt_vocab_file = vocab_prefix + "." 
+ hparams.tgt 199 | 200 | infer_file = "nmt/testdata/test_infer_file" 201 | out_dir = os.path.join(tf.test.get_temp_dir(), 202 | "attention_infer_with_indices") 203 | hparams.out_dir = out_dir 204 | os.makedirs(out_dir) 205 | output_infer = os.path.join(out_dir, "output_infer") 206 | ckpt = self._createTestInferCheckpoint(hparams, out_dir) 207 | inference.inference(ckpt, infer_file, output_infer, hparams) 208 | with open(output_infer) as f: 209 | self.assertEqual(2, len(list(f))) 210 | self.assertTrue(os.path.exists(output_infer+str(1)+".png")) 211 | self.assertTrue(os.path.exists(output_infer+str(2)+".png")) 212 | 213 | 214 | if __name__ == "__main__": 215 | tf.test.main() 216 | -------------------------------------------------------------------------------- /nmt/utils/iterator_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """For loading data into NMT models.""" 16 | from __future__ import print_function 17 | 18 | import collections 19 | 20 | import tensorflow as tf 21 | 22 | __all__ = ["BatchedInput", "get_iterator", "get_infer_iterator"] 23 | 24 | 25 | # NOTE(ebrevdo): When we subclass this, instances' __dict__ becomes empty. 26 | class BatchedInput( 27 | collections.namedtuple("BatchedInput", 28 | ("initializer", "source", "target_input", 29 | "target_output", "source_sequence_length", 30 | "target_sequence_length"))): 31 | pass 32 | 33 | 34 | def get_infer_iterator(src_dataset, 35 | src_vocab_table, 36 | batch_size, 37 | eos, 38 | src_max_len=None): 39 | src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32) 40 | src_dataset = src_dataset.map(lambda src: tf.string_split([src]).values) 41 | 42 | if src_max_len: 43 | src_dataset = src_dataset.map(lambda src: src[:src_max_len]) 44 | # Convert the word strings to ids 45 | src_dataset = src_dataset.map( 46 | lambda src: tf.cast(src_vocab_table.lookup(src), tf.int32)) 47 | # Add in the word counts. 48 | src_dataset = src_dataset.map(lambda src: (src, tf.size(src))) 49 | 50 | def batching_func(x): 51 | return x.padded_batch( 52 | batch_size, 53 | # The entry is the source line rows; 54 | # this has unknown-length vectors. The last entry is 55 | # the source row size; this is a scalar. 56 | padded_shapes=( 57 | tf.TensorShape([None]), # src 58 | tf.TensorShape([])), # src_len 59 | # Pad the source sequences with eos tokens. 60 | # (Though notice we don't generally need to do this since 61 | # later on we will be masking out calculations past the true sequence. 
62 |       padding_values=(
63 |           src_eos_id,  # src
64 |           0))  # src_len -- unused
65 | 
66 |   batched_dataset = batching_func(src_dataset)
67 |   batched_iter = batched_dataset.make_initializable_iterator()
68 |   (src_ids, src_seq_len) = batched_iter.get_next()
69 |   return BatchedInput(
70 |       initializer=batched_iter.initializer,
71 |       source=src_ids,
72 |       target_input=None,
73 |       target_output=None,
74 |       source_sequence_length=src_seq_len,
75 |       target_sequence_length=None)
76 | 
77 | 
78 | def get_iterator(src_dataset,
79 |                  tgt_dataset,
80 |                  src_vocab_table,
81 |                  tgt_vocab_table,
82 |                  batch_size,
83 |                  sos,
84 |                  eos,
85 |                  random_seed,
86 |                  num_buckets,
87 |                  src_max_len=None,
88 |                  tgt_max_len=None,
89 |                  num_parallel_calls=4,
90 |                  output_buffer_size=None,
91 |                  skip_count=None,
92 |                  num_shards=1,
93 |                  shard_index=0,
94 |                  reshuffle_each_iteration=True):
95 |   if not output_buffer_size:
96 |     output_buffer_size = batch_size * 1000
97 |   src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32)
98 |   tgt_sos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(sos)), tf.int32)
99 |   tgt_eos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(eos)), tf.int32)
100 | 
101 |   src_tgt_dataset = tf.data.Dataset.zip((src_dataset, tgt_dataset))
102 | 
103 |   src_tgt_dataset = src_tgt_dataset.shard(num_shards, shard_index)
104 |   if skip_count is not None:
105 |     src_tgt_dataset = src_tgt_dataset.skip(skip_count)
106 | 
107 |   src_tgt_dataset = src_tgt_dataset.shuffle(
108 |       output_buffer_size, random_seed, reshuffle_each_iteration)
109 | 
110 |   src_tgt_dataset = src_tgt_dataset.map(
111 |       lambda src, tgt: (
112 |           tf.string_split([src]).values, tf.string_split([tgt]).values),
113 |       num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)
114 | 
115 |   # Filter zero length input sequences.
116 |   src_tgt_dataset = src_tgt_dataset.filter(
117 |       lambda src, tgt: tf.logical_and(tf.size(src) > 0, tf.size(tgt) > 0))
118 | 
119 |   if src_max_len:
120 |     src_tgt_dataset = src_tgt_dataset.map(
121 |         lambda src, tgt: (src[:src_max_len], tgt),
122 |         num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)
123 |   if tgt_max_len:
124 |     src_tgt_dataset = src_tgt_dataset.map(
125 |         lambda src, tgt: (src, tgt[:tgt_max_len]),
126 |         num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)
127 |   # Convert the word strings to ids. Word strings that are not in the
128 |   # vocab get the lookup table's default_value integer.
129 |   src_tgt_dataset = src_tgt_dataset.map(
130 |       lambda src, tgt: (tf.cast(src_vocab_table.lookup(src), tf.int32),
131 |                         tf.cast(tgt_vocab_table.lookup(tgt), tf.int32)),
132 |       num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)
133 |   # Create a tgt_input prefixed with <sos> and a tgt_output suffixed with <eos>.
134 |   src_tgt_dataset = src_tgt_dataset.map(
135 |       lambda src, tgt: (src,
136 |                         tf.concat(([tgt_sos_id], tgt), 0),
137 |                         tf.concat((tgt, [tgt_eos_id]), 0)),
138 |       num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)
139 |   # Add in sequence lengths.
140 |   src_tgt_dataset = src_tgt_dataset.map(
141 |       lambda src, tgt_in, tgt_out: (
142 |           src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_in)),
143 |       num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)
144 | 
145 |   # Bucket by source sequence length (buckets for lengths 0-9, 10-19, ...)
146 |   def batching_func(x):
147 |     return x.padded_batch(
148 |         batch_size,
149 |         # The first three entries are the source and target line rows;
150 |         # these have unknown-length vectors.
The last two entries are 151 | # the source and target row sizes; these are scalars. 152 | padded_shapes=( 153 | tf.TensorShape([None]), # src 154 | tf.TensorShape([None]), # tgt_input 155 | tf.TensorShape([None]), # tgt_output 156 | tf.TensorShape([]), # src_len 157 | tf.TensorShape([])), # tgt_len 158 | # Pad the source and target sequences with eos tokens. 159 | # (Though notice we don't generally need to do this since 160 | # later on we will be masking out calculations past the true sequence. 161 | padding_values=( 162 | src_eos_id, # src 163 | tgt_eos_id, # tgt_input 164 | tgt_eos_id, # tgt_output 165 | 0, # src_len -- unused 166 | 0)) # tgt_len -- unused 167 | 168 | if num_buckets > 1: 169 | 170 | def key_func(unused_1, unused_2, unused_3, src_len, tgt_len): 171 | # Calculate bucket_width by maximum source sequence length. 172 | # Pairs with length [0, bucket_width) go to bucket 0, length 173 | # [bucket_width, 2 * bucket_width) go to bucket 1, etc. Pairs with length 174 | # over ((num_bucket-1) * bucket_width) words all go into the last bucket. 175 | if src_max_len: 176 | bucket_width = (src_max_len + num_buckets - 1) // num_buckets 177 | else: 178 | bucket_width = 10 179 | 180 | # Bucket sentence pairs by the length of their source sentence and target 181 | # sentence. 182 | bucket_id = tf.maximum(src_len // bucket_width, tgt_len // bucket_width) 183 | return tf.to_int64(tf.minimum(num_buckets, bucket_id)) 184 | 185 | def reduce_func(unused_key, windowed_data): 186 | return batching_func(windowed_data) 187 | 188 | batched_dataset = src_tgt_dataset.apply( 189 | tf.contrib.data.group_by_window( 190 | key_func=key_func, reduce_func=reduce_func, window_size=batch_size)) 191 | 192 | else: 193 | batched_dataset = batching_func(src_tgt_dataset) 194 | batched_iter = batched_dataset.make_initializable_iterator() 195 | (src_ids, tgt_input_ids, tgt_output_ids, src_seq_len, 196 | tgt_seq_len) = (batched_iter.get_next()) 197 | return BatchedInput( 198 | initializer=batched_iter.initializer, 199 | source=src_ids, 200 | target_input=tgt_input_ids, 201 | target_output=tgt_output_ids, 202 | source_sequence_length=src_seq_len, 203 | target_sequence_length=tgt_seq_len) 204 | -------------------------------------------------------------------------------- /nmt/inference.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """To perform inference on test set given a trained model.""" 17 | from __future__ import print_function 18 | 19 | import codecs 20 | import time 21 | 22 | import tensorflow as tf 23 | 24 | from . import attention_model 25 | from . import gnmt_model 26 | from . import model as nmt_model 27 | from . 
import model_helper 28 | from .utils import misc_utils as utils 29 | from .utils import nmt_utils 30 | 31 | __all__ = ["load_data", "inference", 32 | "single_worker_inference", "multi_worker_inference"] 33 | 34 | 35 | def _decode_inference_indices(model, sess, output_infer, 36 | output_infer_summary_prefix, 37 | inference_indices, 38 | tgt_eos, 39 | subword_option): 40 | """Decoding only a specific set of sentences.""" 41 | utils.print_out(" decoding to output %s , num sents %d." % 42 | (output_infer, len(inference_indices))) 43 | start_time = time.time() 44 | with codecs.getwriter("utf-8")( 45 | tf.gfile.GFile(output_infer, mode="wb")) as trans_f: 46 | trans_f.write("") # Write empty string to ensure file is created. 47 | for decode_id in inference_indices: 48 | nmt_outputs, infer_summary = model.decode(sess) 49 | 50 | # get text translation 51 | assert nmt_outputs.shape[0] == 1 52 | translation = nmt_utils.get_translation( 53 | nmt_outputs, 54 | sent_id=0, 55 | tgt_eos=tgt_eos, 56 | subword_option=subword_option) 57 | 58 | if infer_summary is not None: # Attention models 59 | image_file = output_infer_summary_prefix + str(decode_id) + ".png" 60 | utils.print_out(" save attention image to %s*" % image_file) 61 | image_summ = tf.Summary() 62 | image_summ.ParseFromString(infer_summary) 63 | with tf.gfile.GFile(image_file, mode="w") as img_f: 64 | img_f.write(image_summ.value[0].image.encoded_image_string) 65 | 66 | trans_f.write("%s\n" % translation) 67 | utils.print_out(translation + b"\n") 68 | utils.print_time(" done", start_time) 69 | 70 | 71 | def load_data(inference_input_file, hparams=None): 72 | """Load inference data.""" 73 | with codecs.getreader("utf-8")( 74 | tf.gfile.GFile(inference_input_file, mode="rb")) as f: 75 | inference_data = f.read().splitlines() 76 | 77 | if hparams and hparams.inference_indices: 78 | inference_data = [inference_data[i] for i in hparams.inference_indices] 79 | 80 | return inference_data 81 | 82 | 83 | def inference(ckpt, 84 | inference_input_file, 85 | inference_output_file, 86 | hparams, 87 | num_workers=1, 88 | jobid=0, 89 | scope=None): 90 | """Perform translation.""" 91 | if hparams.inference_indices: 92 | assert num_workers == 1 93 | 94 | if not hparams.attention: 95 | model_creator = nmt_model.Model 96 | elif hparams.attention_architecture == "standard": 97 | model_creator = attention_model.AttentionModel 98 | elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]: 99 | model_creator = gnmt_model.GNMTModel 100 | else: 101 | raise ValueError("Unknown model architecture") 102 | infer_model = model_helper.create_infer_model(model_creator, hparams, scope) 103 | 104 | if num_workers == 1: 105 | single_worker_inference( 106 | infer_model, 107 | ckpt, 108 | inference_input_file, 109 | inference_output_file, 110 | hparams) 111 | else: 112 | multi_worker_inference( 113 | infer_model, 114 | ckpt, 115 | inference_input_file, 116 | inference_output_file, 117 | hparams, 118 | num_workers=num_workers, 119 | jobid=jobid) 120 | 121 | 122 | def single_worker_inference(infer_model, 123 | ckpt, 124 | inference_input_file, 125 | inference_output_file, 126 | hparams): 127 | """Inference with a single worker.""" 128 | output_infer = inference_output_file 129 | 130 | # Read data 131 | infer_data = load_data(inference_input_file, hparams) 132 | 133 | with tf.Session( 134 | graph=infer_model.graph, config=utils.get_config_proto()) as sess: 135 | loaded_infer_model = model_helper.load_model( 136 | infer_model.model, ckpt, sess, "infer") 137 | sess.run( 138 | 
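        # Feed raw source sentences and the inference batch size into the
        # iterator's placeholders before decoding starts.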
infer_model.iterator.initializer, 139 | feed_dict={ 140 | infer_model.src_placeholder: infer_data, 141 | infer_model.batch_size_placeholder: hparams.infer_batch_size 142 | }) 143 | # Decode 144 | utils.print_out("# Start decoding") 145 | if hparams.inference_indices: 146 | _decode_inference_indices( 147 | loaded_infer_model, 148 | sess, 149 | output_infer=output_infer, 150 | output_infer_summary_prefix=output_infer, 151 | inference_indices=hparams.inference_indices, 152 | tgt_eos=hparams.eos, 153 | subword_option=hparams.subword_option) 154 | else: 155 | nmt_utils.decode_and_evaluate( 156 | "infer", 157 | loaded_infer_model, 158 | sess, 159 | output_infer, 160 | ref_file=None, 161 | metrics=hparams.metrics, 162 | subword_option=hparams.subword_option, 163 | beam_width=hparams.beam_width, 164 | tgt_eos=hparams.eos, 165 | num_translations_per_input=hparams.num_translations_per_input) 166 | 167 | 168 | def multi_worker_inference(infer_model, 169 | ckpt, 170 | inference_input_file, 171 | inference_output_file, 172 | hparams, 173 | num_workers, 174 | jobid): 175 | """Inference using multiple workers.""" 176 | assert num_workers > 1 177 | 178 | final_output_infer = inference_output_file 179 | output_infer = "%s_%d" % (inference_output_file, jobid) 180 | output_infer_done = "%s_done_%d" % (inference_output_file, jobid) 181 | 182 | # Read data 183 | infer_data = load_data(inference_input_file, hparams) 184 | 185 | # Split the data among the workers 186 | total_load = len(infer_data) 187 | load_per_worker = int((total_load - 1) / num_workers) + 1 188 | start_position = jobid * load_per_worker 189 | end_position = min(start_position + load_per_worker, total_load) 190 | infer_data = infer_data[start_position:end_position] 191 | 192 | with tf.Session( 193 | graph=infer_model.graph, config=utils.get_config_proto()) as sess: 194 | loaded_infer_model = model_helper.load_model( 195 | infer_model.model, ckpt, sess, "infer") 196 | sess.run(infer_model.iterator.initializer, 197 | { 198 | infer_model.src_placeholder: infer_data, 199 | infer_model.batch_size_placeholder: hparams.infer_batch_size 200 | }) 201 | # Decode 202 | utils.print_out("# Start decoding") 203 | nmt_utils.decode_and_evaluate( 204 | "infer", 205 | loaded_infer_model, 206 | sess, 207 | output_infer, 208 | ref_file=None, 209 | metrics=hparams.metrics, 210 | subword_option=hparams.subword_option, 211 | beam_width=hparams.beam_width, 212 | tgt_eos=hparams.eos, 213 | num_translations_per_input=hparams.num_translations_per_input) 214 | 215 | # Rename the file to indicate that writing is complete. 216 | tf.gfile.Rename(output_infer, output_infer_done, overwrite=True) 217 | 218 | # Job 0 is responsible for the cleanup. 219 | if jobid != 0: return 220 | 221 | # Now write all translations 222 | with codecs.getwriter("utf-8")( 223 | tf.gfile.GFile(final_output_infer, mode="wb")) as final_f: 224 | for worker_id in range(num_workers): 225 | worker_infer_done = "%s_done_%d" % (inference_output_file, worker_id) 226 | while not tf.gfile.Exists(worker_infer_done): 227 | utils.print_out(" waiting for job %d to complete."
% worker_id) 228 | time.sleep(10) 229 | 230 | with codecs.getreader("utf-8")( 231 | tf.gfile.GFile(worker_infer_done, mode="rb")) as f: 232 | for translation in f: 233 | final_f.write("%s" % translation) 234 | 235 | for worker_id in range(num_workers): 236 | worker_infer_done = "%s_done_%d" % (inference_output_file, worker_id) 237 | tf.gfile.Remove(worker_infer_done) 238 | -------------------------------------------------------------------------------- /texar/baseline_seq2seq_attn_main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The Texar Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Attentional Seq2seq. 16 | Same as examples/seq2seq_attn, except that ROUGE is also supported here. 17 | """ 18 | from __future__ import absolute_import 19 | from __future__ import print_function 20 | from __future__ import division 21 | from __future__ import unicode_literals 22 | 23 | # pylint: disable=invalid-name, too-many-arguments, too-many-locals 24 | 25 | from io import open 26 | import importlib 27 | import tensorflow as tf 28 | import texar as tx 29 | from rouge import Rouge 30 | import pdb 31 | import OT 32 | 33 | flags = tf.flags 34 | 35 | flags.DEFINE_string("config_model", "configs.config_model", "The model config.") 36 | flags.DEFINE_string("config_data", "configs.config_iwslt14", 37 | "The dataset config.") 38 | 39 | flags.DEFINE_string('output_dir', '.', 'where to keep training logs') 40 | 41 | FLAGS = flags.FLAGS 42 | 43 | config_model = importlib.import_module(FLAGS.config_model) 44 | config_data = importlib.import_module(FLAGS.config_data) 45 | 46 | if not FLAGS.output_dir.endswith('/'): 47 | FLAGS.output_dir += '/' 48 | log_dir = FLAGS.output_dir + 'training_log_baseline/' 49 | tx.utils.maybe_create_dir(log_dir) 50 | 51 | 52 | def build_model(batch, train_data): 53 | """Assembles the seq2seq model.
54 | """ 55 | source_embedder = tx.modules.WordEmbedder( 56 | vocab_size=train_data.source_vocab.size, hparams=config_model.embedder) 57 | 58 | encoder = tx.modules.BidirectionalRNNEncoder( 59 | hparams=config_model.encoder) 60 | 61 | enc_outputs, _ = encoder(source_embedder(batch['source_text_ids'])) 62 | 63 | target_embedder = tx.modules.WordEmbedder( 64 | vocab_size=train_data.target_vocab.size, hparams=config_model.embedder) 65 | 66 | decoder = tx.modules.AttentionRNNDecoder( 67 | memory=tf.concat(enc_outputs, axis=2), 68 | memory_sequence_length=batch['source_length'], 69 | vocab_size=train_data.target_vocab.size, 70 | hparams=config_model.decoder) 71 | 72 | training_outputs, _, _ = decoder( 73 | decoding_strategy='train_greedy', 74 | inputs=target_embedder(batch['target_text_ids'][:, :-1]), 75 | sequence_length=batch['target_length'] - 1) 76 | 77 | MLE_loss = tx.losses.sequence_sparse_softmax_cross_entropy( 78 | labels=batch['target_text_ids'][:, 1:], 79 | logits=training_outputs.logits, 80 | sequence_length=batch['target_length'] - 1) 81 | 82 | train_op = tx.core.get_train_op( 83 | MLE_loss, 84 | hparams=config_model.opt) 85 | 86 | 87 | start_tokens = tf.ones_like(batch['target_length']) *\ 88 | train_data.target_vocab.bos_token_id 89 | beam_search_outputs, _, _ = \ 90 | tx.modules.beam_search_decode( 91 | decoder_or_cell=decoder, 92 | embedding=target_embedder, 93 | start_tokens=start_tokens, 94 | end_token=train_data.target_vocab.eos_token_id, 95 | beam_width=config_model.beam_width, 96 | max_decoding_length=60) 97 | 98 | return train_op, beam_search_outputs 99 | 100 | 101 | def print_stdout_and_file(content, file): 102 | print(content) 103 | print(content, file=file) 104 | 105 | 106 | def main(): 107 | """Entrypoint. 108 | """ 109 | train_data = tx.data.PairedTextData(hparams=config_data.train) 110 | val_data = tx.data.PairedTextData(hparams=config_data.val) 111 | test_data = tx.data.PairedTextData(hparams=config_data.test) 112 | 113 | data_iterator = tx.data.TrainTestDataIterator( 114 | train=train_data, val=val_data, test=test_data) 115 | 116 | batch = data_iterator.get_next() 117 | 118 | train_op, infer_outputs = build_model(batch, train_data) 119 | 120 | def _train_epoch(sess, epoch_no): 121 | data_iterator.switch_to_train_data(sess) 122 | training_log_file = \ 123 | open(log_dir + 'training_log' + str(epoch_no) + '.txt', 'w', 124 | encoding='utf-8') 125 | 126 | step = 0 127 | while True: 128 | try: 129 | loss = sess.run(train_op) 130 | print("step={}, loss={:.4f}".format(step, loss), 131 | file=training_log_file) 132 | if step % config_data.observe_steps == 0: 133 | print("step={}, loss={:.4f}".format(step, loss)) 134 | training_log_file.flush() 135 | step += 1 136 | except tf.errors.OutOfRangeError: 137 | break 138 | 139 | def _eval_epoch(sess, mode, epoch_no): 140 | if mode == 'val': 141 | data_iterator.switch_to_val_data(sess) 142 | else: 143 | data_iterator.switch_to_test_data(sess) 144 | 145 | refs, hypos = [], [] 146 | while True: 147 | try: 148 | fetches = [ 149 | batch['target_text'][:, 1:], 150 | infer_outputs.predicted_ids[:, :, 0] 151 | ] 152 | feed_dict = { 153 | tx.global_mode(): tf.estimator.ModeKeys.EVAL 154 | } 155 | target_texts_ori, output_ids = \ 156 | sess.run(fetches, feed_dict=feed_dict) 157 | 158 | target_texts = tx.utils.strip_special_tokens( 159 | target_texts_ori.tolist(), is_token_list=True) 160 | target_texts = tx.utils.str_join(target_texts) 161 | output_texts = tx.utils.map_ids_to_strs( 162 | ids=output_ids, vocab=val_data.target_vocab) 163 | 
164 | tx.utils.write_paired_text( 165 | target_texts, output_texts, 166 | log_dir + mode + '_results' + str(epoch_no) + '.txt', 167 | append=True, mode='h', sep=' ||| ') 168 | 169 | for hypo, ref in zip(output_texts, target_texts): 170 | if config_data.eval_metric == 'bleu': 171 | hypos.append(hypo) 172 | refs.append([ref]) 173 | elif config_data.eval_metric == 'rouge': 174 | hypos.append(tx.utils.compat_as_text(hypo)) 175 | refs.append(tx.utils.compat_as_text(ref)) 176 | except tf.errors.OutOfRangeError: 177 | break 178 | 179 | if config_data.eval_metric == 'bleu': 180 | return tx.evals.corpus_bleu_moses( 181 | list_of_references=refs, hypotheses=hypos) 182 | elif config_data.eval_metric == 'rouge': 183 | rouge = Rouge() 184 | return rouge.get_scores(hyps=hypos, refs=refs, avg=True) 185 | 186 | def _calc_reward(score): 187 | """ 188 | Return the bleu score or the sum of (Rouge-1, Rouge-2, Rouge-L). 189 | """ 190 | if config_data.eval_metric == 'bleu': 191 | return score 192 | elif config_data.eval_metric == 'rouge': 193 | return sum([value['f'] for key, value in score.items()]) 194 | 195 | with tf.Session() as sess: 196 | sess.run(tf.global_variables_initializer()) 197 | sess.run(tf.local_variables_initializer()) 198 | sess.run(tf.tables_initializer()) 199 | 200 | best_val_score = -1. 201 | scores_file = open(log_dir + 'scores.txt', 'w', encoding='utf-8') 202 | for i in range(config_data.num_epochs): 203 | _train_epoch(sess, i) 204 | 205 | val_score = _eval_epoch(sess, 'val', i) 206 | test_score = _eval_epoch(sess, 'test', i) 207 | 208 | best_val_score = max(best_val_score, _calc_reward(val_score)) 209 | 210 | if config_data.eval_metric == 'bleu': 211 | print_stdout_and_file( 212 | 'val epoch={}, BLEU={:.4f}; best-ever={:.4f}'.format( 213 | i, val_score, best_val_score), file=scores_file) 214 | 215 | print_stdout_and_file( 216 | 'test epoch={}, BLEU={:.4f}'.format(i, test_score), 217 | file=scores_file) 218 | print_stdout_and_file('=' * 50, file=scores_file) 219 | 220 | elif config_data.eval_metric == 'rouge': 221 | print_stdout_and_file( 222 | 'valid epoch {}:'.format(i), file=scores_file) 223 | for key, value in val_score.items(): 224 | print_stdout_and_file( 225 | '{}: {}'.format(key, value), file=scores_file) 226 | print_stdout_and_file('fsum: {}; best_val_fsum: {}'.format( 227 | _calc_reward(val_score), best_val_score), file=scores_file) 228 | 229 | print_stdout_and_file( 230 | 'test epoch {}:'.format(i), file=scores_file) 231 | for key, value in test_score.items(): 232 | print_stdout_and_file( 233 | '{}: {}'.format(key, value), file=scores_file) 234 | print_stdout_and_file('=' * 110, file=scores_file) 235 | 236 | scores_file.flush() 237 | 238 | 239 | if __name__ == '__main__': 240 | main() 241 | -------------------------------------------------------------------------------- /texar/baseline_seq2seq_attn_ot.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The Texar Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Attentional Seq2seq with an additional optimal transport (OT) loss. 16 | Same as examples/seq2seq_attn, except that ROUGE is also supported here. 17 | """ 18 | from __future__ import absolute_import 19 | from __future__ import print_function 20 | from __future__ import division 21 | from __future__ import unicode_literals 22 | 23 | # pylint: disable=invalid-name, too-many-arguments, too-many-locals 24 | import os 25 | from io import open 26 | import importlib 27 | import tensorflow as tf 28 | import texar as tx 29 | from rouge import Rouge 30 | import OT 31 | import pdb 32 | 33 | GPUID = 0 34 | os.environ["CUDA_VISIBLE_DEVICES"] = str(GPUID) 35 | 36 | flags = tf.flags 37 | 38 | flags.DEFINE_string("config_model", "configs.config_model", "The model config.") 39 | flags.DEFINE_string("config_data", "configs.config_iwslt14", 40 | "The dataset config.") 41 | 42 | flags.DEFINE_string('output_dir', '.', 'where to keep training logs') 43 | 44 | FLAGS = flags.FLAGS 45 | 46 | config_model = importlib.import_module(FLAGS.config_model) 47 | config_data = importlib.import_module(FLAGS.config_data) 48 | 49 | if not FLAGS.output_dir.endswith('/'): 50 | FLAGS.output_dir += '/' 51 | log_dir = FLAGS.output_dir + 'training_log_baseline/' 52 | tx.utils.maybe_create_dir(log_dir) 53 | 54 | 55 | def build_model(batch, train_data): 56 | """Assembles the seq2seq model. 57 | """ 58 | source_embedder = tx.modules.WordEmbedder( 59 | vocab_size=train_data.source_vocab.size, hparams=config_model.embedder) 60 | 61 | encoder = tx.modules.BidirectionalRNNEncoder( 62 | hparams=config_model.encoder) 63 | 64 | enc_outputs, _ = encoder(source_embedder(batch['source_text_ids'])) 65 | 66 | target_embedder = tx.modules.WordEmbedder( 67 | vocab_size=train_data.target_vocab.size, hparams=config_model.embedder) 68 | 69 | decoder = tx.modules.AttentionRNNDecoder( 70 | memory=tf.concat(enc_outputs, axis=2), 71 | memory_sequence_length=batch['source_length'], 72 | vocab_size=train_data.target_vocab.size, 73 | hparams=config_model.decoder) 74 | 75 | training_outputs, _, _ = decoder( 76 | decoding_strategy='train_greedy', 77 | inputs=target_embedder(batch['target_text_ids'][:, :-1]), 78 | sequence_length=batch['target_length'] - 1) 79 | 80 | # Standard MLE (cross-entropy) loss 81 | MLE_loss = tx.losses.sequence_sparse_softmax_cross_entropy( 82 | labels=batch['target_text_ids'][:, 1:], 83 | logits=training_outputs.logits, 84 | sequence_length=batch['target_length'] - 1) 85 | 86 | # Key-word matching loss via optimal transport (IPOT) 87 | tgt_logits = training_outputs.logits 88 | tgt_words = target_embedder(soft_ids=tf.nn.softmax(tgt_logits)) # soft_ids expects a distribution over the vocabulary, not raw logits 89 | src_words = source_embedder(ids=batch['source_text_ids']) 90 | src_words = tf.nn.l2_normalize(src_words, 2, epsilon=1e-12) 91 | tgt_words = tf.nn.l2_normalize(tgt_words, 2, epsilon=1e-12) 92 | 93 | cosine_cost = 1 - tf.einsum( 94 | 'aij,ajk->aik', src_words, tf.transpose(tgt_words, [0,2,1])) # 1 - cosine similarity, shape [batch, src_len, tgt_len] 95 | # pdb.set_trace() 96 | OT_loss = tf.reduce_mean(OT.IPOT_distance2(cosine_cost)) 97 | 98 | Total_loss = MLE_loss + 0.1 * OT_loss 99 | 100 | train_op = tx.core.get_train_op( 101 | Total_loss, 102 | hparams=config_model.opt) 103 | 104 | 105 | start_tokens = tf.ones_like(batch['target_length']) *\ 106 | train_data.target_vocab.bos_token_id 107 | beam_search_outputs, _, _ = \ 108 | tx.modules.beam_search_decode( 109 | decoder_or_cell=decoder, 110 | embedding=target_embedder, 111 | start_tokens=start_tokens, 112 | end_token=train_data.target_vocab.eos_token_id, 113 |
beam_width=config_model.beam_width, 114 | max_decoding_length=60) 115 | 116 | return train_op, beam_search_outputs 117 | 118 | 119 | def print_stdout_and_file(content, file): 120 | print(content) 121 | print(content, file=file) 122 | 123 | 124 | def main(): 125 | """Entrypoint. 126 | """ 127 | train_data = tx.data.PairedTextData(hparams=config_data.train) 128 | val_data = tx.data.PairedTextData(hparams=config_data.val) 129 | test_data = tx.data.PairedTextData(hparams=config_data.test) 130 | # pdb.set_trace() 131 | data_iterator = tx.data.TrainTestDataIterator( 132 | train=train_data, val=val_data, test=test_data) 133 | 134 | batch = data_iterator.get_next() 135 | 136 | train_op, infer_outputs = build_model(batch, train_data) 137 | 138 | def _train_epoch(sess, epoch_no): 139 | data_iterator.switch_to_train_data(sess) 140 | training_log_file = \ 141 | open(log_dir + 'training_log' + str(epoch_no) + '.txt', 'w', 142 | encoding='utf-8') 143 | 144 | step = 0 145 | while True: 146 | try: 147 | loss = sess.run(train_op) 148 | print("step={}, loss={:.4f}".format(step, loss), 149 | file=training_log_file) 150 | if step % config_data.observe_steps == 0: 151 | print("step={}, loss={:.4f}".format(step, loss)) 152 | training_log_file.flush() 153 | step += 1 154 | except tf.errors.OutOfRangeError: 155 | break 156 | 157 | def _eval_epoch(sess, mode, epoch_no): 158 | if mode == 'val': 159 | data_iterator.switch_to_val_data(sess) 160 | else: 161 | data_iterator.switch_to_test_data(sess) 162 | 163 | refs, hypos = [], [] 164 | while True: 165 | try: 166 | fetches = [ 167 | batch['target_text'][:, 1:], 168 | infer_outputs.predicted_ids[:, :, 0] 169 | ] 170 | feed_dict = { 171 | tx.global_mode(): tf.estimator.ModeKeys.EVAL 172 | } 173 | target_texts_ori, output_ids = \ 174 | sess.run(fetches, feed_dict=feed_dict) 175 | 176 | target_texts = tx.utils.strip_special_tokens( 177 | target_texts_ori.tolist(), is_token_list=True) 178 | target_texts = tx.utils.str_join(target_texts) 179 | output_texts = tx.utils.map_ids_to_strs( 180 | ids=output_ids, vocab=val_data.target_vocab) 181 | 182 | tx.utils.write_paired_text( 183 | target_texts, output_texts, 184 | log_dir + mode + '_results' + str(epoch_no) + '.txt', 185 | append=True, mode='h', sep=' ||| ') 186 | 187 | for hypo, ref in zip(output_texts, target_texts): 188 | if config_data.eval_metric == 'bleu': 189 | hypos.append(hypo) 190 | refs.append([ref]) 191 | elif config_data.eval_metric == 'rouge': 192 | hypos.append(tx.utils.compat_as_text(hypo)) 193 | refs.append(tx.utils.compat_as_text(ref)) 194 | except tf.errors.OutOfRangeError: 195 | break 196 | 197 | if config_data.eval_metric == 'bleu': 198 | return tx.evals.corpus_bleu_moses( 199 | list_of_references=refs, hypotheses=hypos) 200 | elif config_data.eval_metric == 'rouge': 201 | rouge = Rouge() 202 | return rouge.get_scores(hyps=hypos, refs=refs, avg=True) 203 | 204 | def _calc_reward(score): 205 | """ 206 | Return the bleu score or the sum of (Rouge-1, Rouge-2, Rouge-L). 207 | """ 208 | if config_data.eval_metric == 'bleu': 209 | return score 210 | elif config_data.eval_metric == 'rouge': 211 | return sum([value['f'] for key, value in score.items()]) 212 | 213 | config = tf.ConfigProto() 214 | config.gpu_options.per_process_gpu_memory_fraction = 0.4 215 | with tf.Session(config=config) as sess: 216 | sess.run(tf.global_variables_initializer()) 217 | sess.run(tf.local_variables_initializer()) 218 | sess.run(tf.tables_initializer()) 219 | 220 | best_val_score = -1. 
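# (Note: best_val_score tracks _calc_reward(val_score) across epochs -- the raw
# BLEU value when config_data.eval_metric == 'bleu', or the sum of the
# ROUGE-1/2/L F-scores when it is 'rouge'; see _calc_reward above.)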
221 | scores_file = open(log_dir + 'scores.txt', 'w', encoding='utf-8') 222 | for i in range(config_data.num_epochs): 223 | _train_epoch(sess, i) 224 | 225 | val_score = _eval_epoch(sess, 'val', i) 226 | test_score = _eval_epoch(sess, 'test', i) 227 | 228 | best_val_score = max(best_val_score, _calc_reward(val_score)) 229 | 230 | if config_data.eval_metric == 'bleu': 231 | print_stdout_and_file( 232 | 'val epoch={}, BLEU={:.4f}; best-ever={:.4f}'.format( 233 | i, val_score, best_val_score), file=scores_file) 234 | 235 | print_stdout_and_file( 236 | 'test epoch={}, BLEU={:.4f}'.format(i, test_score), 237 | file=scores_file) 238 | print_stdout_and_file('=' * 50, file=scores_file) 239 | 240 | elif config_data.eval_metric == 'rouge': 241 | print_stdout_and_file( 242 | 'valid epoch {}:'.format(i), file=scores_file) 243 | for key, value in val_score.items(): 244 | print_stdout_and_file( 245 | '{}: {}'.format(key, value), file=scores_file) 246 | print_stdout_and_file('fsum: {}; best_val_fsum: {}'.format( 247 | _calc_reward(val_score), best_val_score), file=scores_file) 248 | 249 | print_stdout_and_file( 250 | 'test epoch {}:'.format(i), file=scores_file) 251 | for key, value in test_score.items(): 252 | print_stdout_and_file( 253 | '{}: {}'.format(key, value), file=scores_file) 254 | print_stdout_and_file('=' * 110, file=scores_file) 255 | 256 | scores_file.flush() 257 | 258 | 259 | if __name__ == '__main__': 260 | main() 261 | -------------------------------------------------------------------------------- /nmt/gnmt_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """GNMT attention sequence-to-sequence model with dynamic RNN support.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | # TODO(rzhao): Use tf.contrib.framework.nest once 1.3 is out. 24 | from tensorflow.python.util import nest 25 | 26 | from . import attention_model 27 | from . import model_helper 28 | from .utils import misc_utils as utils 29 | 30 | __all__ = ["GNMTModel"] 31 | 32 | 33 | class GNMTModel(attention_model.AttentionModel): 34 | """Sequence-to-sequence dynamic model with GNMT attention architecture. 
35 | """ 36 | 37 | def __init__(self, 38 | hparams, 39 | mode, 40 | iterator, 41 | source_vocab_table, 42 | target_vocab_table, 43 | reverse_target_vocab_table=None, 44 | scope=None, 45 | extra_args=None): 46 | super(GNMTModel, self).__init__( 47 | hparams=hparams, 48 | mode=mode, 49 | iterator=iterator, 50 | source_vocab_table=source_vocab_table, 51 | target_vocab_table=target_vocab_table, 52 | reverse_target_vocab_table=reverse_target_vocab_table, 53 | scope=scope, 54 | extra_args=extra_args) 55 | 56 | def _build_encoder(self, hparams): 57 | """Build a GNMT encoder.""" 58 | if hparams.encoder_type == "uni" or hparams.encoder_type == "bi": 59 | return super(GNMTModel, self)._build_encoder(hparams) 60 | 61 | if hparams.encoder_type != "gnmt": 62 | raise ValueError("Unknown encoder_type %s" % hparams.encoder_type) 63 | 64 | # Build GNMT encoder. 65 | num_bi_layers = 1 66 | num_uni_layers = self.num_encoder_layers - num_bi_layers 67 | utils.print_out(" num_bi_layers = %d" % num_bi_layers) 68 | utils.print_out(" num_uni_layers = %d" % num_uni_layers) 69 | 70 | iterator = self.iterator 71 | source = iterator.source 72 | if self.time_major: 73 | source = tf.transpose(source) 74 | 75 | with tf.variable_scope("encoder") as scope: 76 | dtype = scope.dtype 77 | 78 | # Look up embedding, emp_inp: [max_time, batch_size, num_units] 79 | # when time_major = True 80 | encoder_emb_inp = tf.nn.embedding_lookup(self.embedding_encoder, 81 | source) 82 | 83 | # Execute _build_bidirectional_rnn from Model class 84 | bi_encoder_outputs, bi_encoder_state = self._build_bidirectional_rnn( 85 | inputs=encoder_emb_inp, 86 | sequence_length=iterator.source_sequence_length, 87 | dtype=dtype, 88 | hparams=hparams, 89 | num_bi_layers=num_bi_layers, 90 | num_bi_residual_layers=0, # no residual connection 91 | ) 92 | 93 | uni_cell = model_helper.create_rnn_cell( 94 | unit_type=hparams.unit_type, 95 | num_units=hparams.num_units, 96 | num_layers=num_uni_layers, 97 | num_residual_layers=self.num_encoder_residual_layers, 98 | forget_bias=hparams.forget_bias, 99 | dropout=hparams.dropout, 100 | num_gpus=self.num_gpus, 101 | base_gpu=1, 102 | mode=self.mode, 103 | single_cell_fn=self.single_cell_fn) 104 | 105 | # encoder_outputs: size [max_time, batch_size, num_units] 106 | # when time_major = True 107 | encoder_outputs, encoder_state = tf.nn.dynamic_rnn( 108 | uni_cell, 109 | bi_encoder_outputs, 110 | dtype=dtype, 111 | sequence_length=iterator.source_sequence_length, 112 | time_major=self.time_major) 113 | 114 | # Pass all encoder state except the first bi-directional layer's state to 115 | # decoder. 
116 | encoder_state = (bi_encoder_state[1],) + ( 117 | (encoder_state,) if num_uni_layers == 1 else encoder_state) 118 | 119 | return encoder_outputs, encoder_state, encoder_emb_inp 120 | 121 | def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, 122 | source_sequence_length): 123 | """Build a RNN cell with GNMT attention architecture.""" 124 | # Standard attention 125 | if hparams.attention_architecture == "standard": 126 | return super(GNMTModel, self)._build_decoder_cell( 127 | hparams, encoder_outputs, encoder_state, source_sequence_length) 128 | 129 | # GNMT attention 130 | attention_option = hparams.attention 131 | attention_architecture = hparams.attention_architecture 132 | num_units = hparams.num_units 133 | beam_width = hparams.beam_width 134 | 135 | dtype = tf.float32 136 | 137 | if self.time_major: 138 | memory = tf.transpose(encoder_outputs, [1, 0, 2]) 139 | else: 140 | memory = encoder_outputs 141 | 142 | if self.mode == tf.contrib.learn.ModeKeys.INFER and beam_width > 0: 143 | memory = tf.contrib.seq2seq.tile_batch( 144 | memory, multiplier=beam_width) 145 | source_sequence_length = tf.contrib.seq2seq.tile_batch( 146 | source_sequence_length, multiplier=beam_width) 147 | encoder_state = tf.contrib.seq2seq.tile_batch( 148 | encoder_state, multiplier=beam_width) 149 | batch_size = self.batch_size * beam_width 150 | else: 151 | batch_size = self.batch_size 152 | 153 | attention_mechanism = self.attention_mechanism_fn( 154 | attention_option, num_units, memory, source_sequence_length, self.mode) 155 | 156 | cell_list = model_helper._cell_list( # pylint: disable=protected-access 157 | unit_type=hparams.unit_type, 158 | num_units=num_units, 159 | num_layers=self.num_decoder_layers, 160 | num_residual_layers=self.num_decoder_residual_layers, 161 | forget_bias=hparams.forget_bias, 162 | dropout=hparams.dropout, 163 | num_gpus=self.num_gpus, 164 | mode=self.mode, 165 | single_cell_fn=self.single_cell_fn, 166 | residual_fn=gnmt_residual_fn 167 | ) 168 | 169 | # Only wrap the bottom layer with the attention mechanism. 170 | attention_cell = cell_list.pop(0) 171 | 172 | # Only generate alignment in greedy INFER mode. 173 | alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER and 174 | beam_width == 0) 175 | attention_cell = tf.contrib.seq2seq.AttentionWrapper( 176 | attention_cell, 177 | attention_mechanism, 178 | attention_layer_size=None, # don't use attention layer. 
179 | output_attention=False, 180 | alignment_history=alignment_history, 181 | name="attention") 182 | 183 | if attention_architecture == "gnmt": 184 | cell = GNMTAttentionMultiCell( 185 | attention_cell, cell_list) 186 | elif attention_architecture == "gnmt_v2": 187 | cell = GNMTAttentionMultiCell( 188 | attention_cell, cell_list, use_new_attention=True) 189 | else: 190 | raise ValueError( 191 | "Unknown attention_architecture %s" % attention_architecture) 192 | 193 | if hparams.pass_hidden_state: 194 | decoder_initial_state = tuple( 195 | zs.clone(cell_state=es) 196 | if isinstance(zs, tf.contrib.seq2seq.AttentionWrapperState) else es 197 | for zs, es in zip( 198 | cell.zero_state(batch_size, dtype), encoder_state)) 199 | else: 200 | decoder_initial_state = cell.zero_state(batch_size, dtype) 201 | 202 | return cell, decoder_initial_state 203 | 204 | def _get_infer_summary(self, hparams): 205 | # Standard attention 206 | if hparams.attention_architecture == "standard": 207 | return super(GNMTModel, self)._get_infer_summary(hparams) 208 | 209 | # GNMT attention 210 | if hparams.beam_width > 0: 211 | return tf.no_op() 212 | return attention_model._create_attention_images_summary( 213 | self.final_context_state[0]) 214 | 215 | 216 | class GNMTAttentionMultiCell(tf.nn.rnn_cell.MultiRNNCell): 217 | """A MultiCell with GNMT attention style.""" 218 | 219 | def __init__(self, attention_cell, cells, use_new_attention=False): 220 | """Creates a GNMTAttentionMultiCell. 221 | 222 | Args: 223 | attention_cell: An instance of AttentionWrapper. 224 | cells: A list of RNNCell wrapped with AttentionInputWrapper. 225 | use_new_attention: Whether to use the attention generated from current 226 | step bottom layer's output. Default is False. 227 | """ 228 | cells = [attention_cell] + cells 229 | self.use_new_attention = use_new_attention 230 | super(GNMTAttentionMultiCell, self).__init__(cells, state_is_tuple=True) 231 | 232 | def __call__(self, inputs, state, scope=None): 233 | """Run the cell with bottom layer's attention copied to all upper layers.""" 234 | if not nest.is_sequence(state): 235 | raise ValueError( 236 | "Expected state to be a tuple of length %d, but received: %s" 237 | % (len(self.state_size), state)) 238 | 239 | with tf.variable_scope(scope or "multi_rnn_cell"): 240 | new_states = [] 241 | 242 | with tf.variable_scope("cell_0_attention"): 243 | attention_cell = self._cells[0] 244 | attention_state = state[0] 245 | cur_inp, new_attention_state = attention_cell(inputs, attention_state) 246 | new_states.append(new_attention_state) 247 | 248 | for i in range(1, len(self._cells)): 249 | with tf.variable_scope("cell_%d" % i): 250 | 251 | cell = self._cells[i] 252 | cur_state = state[i] 253 | 254 | if self.use_new_attention: 255 | cur_inp = tf.concat([cur_inp, new_attention_state.attention], -1) 256 | else: 257 | cur_inp = tf.concat([cur_inp, attention_state.attention], -1) 258 | 259 | cur_inp, new_state = cell(cur_inp, cur_state) 260 | new_states.append(new_state) 261 | 262 | return cur_inp, tuple(new_states) 263 | 264 | 265 | def gnmt_residual_fn(inputs, outputs): 266 | """Residual function that handles different inputs and outputs inner dims. 267 | 268 | Args: 269 | inputs: cell inputs, this is actual inputs concatenated with the attention 270 | vector. 
271 | outputs: cell outputs 272 | 273 | Returns: 274 | outputs + actual inputs 275 | """ 276 | def split_input(inp, out): 277 | out_dim = out.get_shape().as_list()[-1] 278 | inp_dim = inp.get_shape().as_list()[-1] 279 | return tf.split(inp, [out_dim, inp_dim - out_dim], axis=-1) 280 | actual_inputs, _ = nest.map_structure(split_input, inputs, outputs) 281 | def assert_shape_match(inp, out): 282 | inp.get_shape().assert_is_compatible_with(out.get_shape()) 283 | nest.assert_same_structure(actual_inputs, outputs) 284 | nest.map_structure(assert_shape_match, actual_inputs, outputs) 285 | return nest.map_structure(lambda inp, out: inp + out, actual_inputs, outputs) 286 | --------------------------------------------------------------------------------
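To make gnmt_residual_fn above concrete: at an upper decoder layer, the cell input is the previous layer's output concatenated with the attention vector, so the function splits the input back down to the output's width before adding the residual. A minimal sketch with hypothetical widths (not part of the repo):

```python
import tensorflow as tf

# Hypothetical widths: layer output = 4 units, attention vector = 3 units,
# so the cell input is 4 + 3 = 7 units wide.
inputs = tf.zeros([2, 7])   # [batch, output_dim + attention_dim]
outputs = tf.ones([2, 4])   # [batch, output_dim]

out_dim = outputs.get_shape().as_list()[-1]  # 4
inp_dim = inputs.get_shape().as_list()[-1]   # 7
actual_inputs, _ = tf.split(inputs, [out_dim, inp_dim - out_dim], axis=-1)
residual = actual_inputs + outputs           # [2, 4], what gnmt_residual_fn returns
```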
/nmt/scripts/rouge.py: -------------------------------------------------------------------------------- 1 | """ROUGE metric implementation. 2 | 3 | Copied from tf_seq2seq/seq2seq/metrics/rouge.py. 4 | This is a modified and slightly extended version of 5 | https://github.com/miso-belica/sumy/blob/dev/sumy/evaluation/rouge.py. 6 | """ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | from __future__ import unicode_literals 12 | 13 | import itertools 14 | import numpy as np 15 | 16 | #pylint: disable=C0103 17 | 18 | 19 | def _get_ngrams(n, text): 20 | """Calculates n-grams. 21 | 22 | Args: 23 | n: which n-grams to calculate 24 | text: An array of tokens 25 | 26 | Returns: 27 | A set of n-grams 28 | """ 29 | ngram_set = set() 30 | text_length = len(text) 31 | max_index_ngram_start = text_length - n 32 | for i in range(max_index_ngram_start + 1): 33 | ngram_set.add(tuple(text[i:i + n])) 34 | return ngram_set 35 | 36 | 37 | def _split_into_words(sentences): 38 | """Splits multiple sentences into words and flattens the result""" 39 | return list(itertools.chain(*[_.split(" ") for _ in sentences])) 40 | 41 | 42 | def _get_word_ngrams(n, sentences): 43 | """Calculates word n-grams for multiple sentences. 44 | """ 45 | assert len(sentences) > 0 46 | assert n > 0 47 | 48 | words = _split_into_words(sentences) 49 | return _get_ngrams(n, words) 50 | 51 | 52 | def _len_lcs(x, y): 53 | """ 54 | Returns the length of the Longest Common Subsequence between sequences x 55 | and y. 56 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 57 | 58 | Args: 59 | x: sequence of words 60 | y: sequence of words 61 | 62 | Returns: 63 | integer: Length of LCS between x and y 64 | """ 65 | table = _lcs(x, y) 66 | n, m = len(x), len(y) 67 | return table[n, m] 68 | 69 | 70 | def _lcs(x, y): 71 | """ 72 | Computes the length of the longest common subsequence (lcs) between two 73 | strings. The implementation below uses a dynamic programming algorithm and runs 74 | in O(nm) time where n = len(x) and m = len(y). 75 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 76 | 77 | Args: 78 | x: collection of words 79 | y: collection of words 80 | 81 | Returns: 82 | A dict mapping index pairs (i, j) to the LCS length of the prefixes x[:i], y[:j] 83 | """ 84 | n, m = len(x), len(y) 85 | table = dict() 86 | for i in range(n + 1): 87 | for j in range(m + 1): 88 | if i == 0 or j == 0: 89 | table[i, j] = 0 90 | elif x[i - 1] == y[j - 1]: 91 | table[i, j] = table[i - 1, j - 1] + 1 92 | else: 93 | table[i, j] = max(table[i - 1, j], table[i, j - 1]) 94 | return table 95 | 96 | 97 | def _recon_lcs(x, y): 98 | """ 99 | Returns the Longest Common Subsequence between x and y. 100 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 101 | 102 | Args: 103 | x: sequence of words 104 | y: sequence of words 105 | 106 | Returns: 107 | sequence: LCS of x and y 108 | """ 109 | i, j = len(x), len(y) 110 | table = _lcs(x, y) 111 | 112 | def _recon(i, j): 113 | """private recon calculation""" 114 | if i == 0 or j == 0: 115 | return [] 116 | elif x[i - 1] == y[j - 1]: 117 | return _recon(i - 1, j - 1) + [(x[i - 1], i)] 118 | elif table[i - 1, j] > table[i, j - 1]: 119 | return _recon(i - 1, j) 120 | else: 121 | return _recon(i, j - 1) 122 | 123 | recon_tuple = tuple(map(lambda x: x[0], _recon(i, j))) 124 | return recon_tuple 125 | 126 | 127 | def rouge_n(evaluated_sentences, reference_sentences, n=2): 128 | """ 129 | Computes ROUGE-N of two text collections of sentences. 130 | Source: http://research.microsoft.com/en-us/um/people/cyl/download/ 131 | papers/rouge-working-note-v1.3.1.pdf 132 | 133 | Args: 134 | evaluated_sentences: The sentences that have been picked by the summarizer 135 | reference_sentences: The sentences from the reference set 136 | n: Size of ngram. Defaults to 2. 137 | 138 | Returns: 139 | A tuple (f1, precision, recall) for ROUGE-N 140 | 141 | Raises: 142 | ValueError: raises exception if a param has len <= 0 143 | """ 144 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 145 | raise ValueError("Collections must contain at least 1 sentence.") 146 | 147 | evaluated_ngrams = _get_word_ngrams(n, evaluated_sentences) 148 | reference_ngrams = _get_word_ngrams(n, reference_sentences) 149 | reference_count = len(reference_ngrams) 150 | evaluated_count = len(evaluated_ngrams) 151 | 152 | # Gets the overlapping ngrams between evaluated and reference 153 | overlapping_ngrams = evaluated_ngrams.intersection(reference_ngrams) 154 | overlapping_count = len(overlapping_ngrams) 155 | 156 | # Handle edge case. This isn't mathematically correct, but it's good enough 157 | if evaluated_count == 0: 158 | precision = 0.0 159 | else: 160 | precision = overlapping_count / evaluated_count 161 | 162 | if reference_count == 0: 163 | recall = 0.0 164 | else: 165 | recall = overlapping_count / reference_count 166 | 167 | f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8)) 168 | 169 | # return overlapping_count / reference_count 170 | return f1_score, precision, recall 171 | 172 | 173 | def _f_p_r_lcs(llcs, m, n): 174 | """ 175 | Computes the LCS-based F-measure score 176 | Source: http://research.microsoft.com/en-us/um/people/cyl/download/papers/ 177 | rouge-working-note-v1.3.1.pdf 178 | 179 | Args: 180 | llcs: Length of LCS 181 | m: number of words in reference summary 182 | n: number of words in candidate summary 183 | 184 | Returns: 185 | Float. LCS-based F-measure score 186 | """ 187 | r_lcs = llcs / m 188 | p_lcs = llcs / n 189 | beta = p_lcs / (r_lcs + 1e-12) 190 | num = (1 + (beta**2)) * r_lcs * p_lcs 191 | denom = r_lcs + ((beta**2) * p_lcs) 192 | f_lcs = num / (denom + 1e-12) 193 | return f_lcs, p_lcs, r_lcs 194 | 195 | 196 | def rouge_l_sentence_level(evaluated_sentences, reference_sentences): 197 | """ 198 | Computes ROUGE-L (sentence level) of two text collections of sentences.
199 | http://research.microsoft.com/en-us/um/people/cyl/download/papers/ 200 | rouge-working-note-v1.3.1.pdf 201 | 202 | Calculated according to: 203 | R_lcs = LCS(X,Y)/m 204 | P_lcs = LCS(X,Y)/n 205 | F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs) 206 | 207 | where: 208 | X = reference summary 209 | Y = Candidate summary 210 | m = length of reference summary 211 | n = length of candidate summary 212 | 213 | Args: 214 | evaluated_sentences: The sentences that have been picked by the summarizer 215 | reference_sentences: The sentences from the reference set 216 | 217 | Returns: 218 | A float: F_lcs 219 | 220 | Raises: 221 | ValueError: raises exception if a param has len <= 0 222 | """ 223 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 224 | raise ValueError("Collections must contain at least 1 sentence.") 225 | reference_words = _split_into_words(reference_sentences) 226 | evaluated_words = _split_into_words(evaluated_sentences) 227 | m = len(reference_words) 228 | n = len(evaluated_words) 229 | lcs = _len_lcs(evaluated_words, reference_words) 230 | return _f_p_r_lcs(lcs, m, n) 231 | 232 | 233 | def _union_lcs(evaluated_sentences, reference_sentence): 234 | """ 235 | Returns LCS_u(r_i, C) which is the LCS score of the union longest common 236 | subsequence between reference sentence r_i and candidate summary C. For example 237 | if r_i = w1 w2 w3 w4 w5, and C contains two sentences: c1 = w1 w2 w6 w7 w8 and 238 | c2 = w1 w3 w8 w9 w5, then the longest common subsequence of r_i and c1 is 239 | "w1 w2" and the longest common subsequence of r_i and c2 is "w1 w3 w5". The 240 | union longest common subsequence of r_i, c1, and c2 is "w1 w2 w3 w5" and 241 | LCS_u(r_i, C) = 4/5. 242 | 243 | Args: 244 | evaluated_sentences: The sentences that have been picked by the summarizer 245 | reference_sentence: One of the sentences in the reference summaries 246 | 247 | Returns: 248 | float: LCS_u(r_i, C) 249 | 250 | Raises: 251 | ValueError: raises exception if a param has len <= 0 252 | """ 253 | if len(evaluated_sentences) <= 0: 254 | raise ValueError("Collections must contain at least 1 sentence.") 255 | 256 | lcs_union = set() 257 | reference_words = _split_into_words([reference_sentence]) 258 | combined_lcs_length = 0 259 | for eval_s in evaluated_sentences: 260 | evaluated_words = _split_into_words([eval_s]) 261 | lcs = set(_recon_lcs(reference_words, evaluated_words)) 262 | combined_lcs_length += len(lcs) 263 | lcs_union = lcs_union.union(lcs) 264 | 265 | union_lcs_count = len(lcs_union) 266 | union_lcs_value = union_lcs_count / combined_lcs_length 267 | return union_lcs_value 268 | 269 | 270 | def rouge_l_summary_level(evaluated_sentences, reference_sentences): 271 | """ 272 | Computes ROUGE-L (summary level) of two text collections of sentences.
273 | http://research.microsoft.com/en-us/um/people/cyl/download/papers/ 274 | rouge-working-note-v1.3.1.pdf 275 | 276 | Calculated according to: 277 | R_lcs = SUM(1, u)[LCS(r_i,C)]/m 278 | P_lcs = SUM(1, u)[LCS(r_i,C)]/n 279 | F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs) 280 | 281 | where: 282 | SUM(i,u) = SUM from i through u 283 | u = number of sentences in reference summary 284 | C = Candidate summary made up of v sentences 285 | m = number of words in reference summary 286 | n = number of words in candidate summary 287 | 288 | Args: 289 | evaluated_sentences: The sentences that have been picked by the summarizer 290 | reference_sentences: The sentences from the reference set 291 | 292 | Returns: 293 | A float: F_lcs 294 | 295 | Raises: 296 | ValueError: raises exception if a param has len <= 0 297 | """ 298 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 299 | raise ValueError("Collections must contain at least 1 sentence.") 300 | 301 | # total number of words in reference sentences 302 | m = len(_split_into_words(reference_sentences)) 303 | 304 | # total number of words in evaluated sentences 305 | n = len(_split_into_words(evaluated_sentences)) 306 | 307 | union_lcs_sum_across_all_references = 0 308 | for ref_s in reference_sentences: 309 | union_lcs_sum_across_all_references += _union_lcs(evaluated_sentences, 310 | ref_s) 311 | return _f_p_r_lcs(union_lcs_sum_across_all_references, m, n) 312 | 313 | 314 | def rouge(hypotheses, references): 315 | """Calculates average rouge scores for a list of hypotheses and 316 | references""" 317 | 318 | # Filter out hyps that are of 0 length 319 | # hyps_and_refs = zip(hypotheses, references) 320 | # hyps_and_refs = [_ for _ in hyps_and_refs if len(_[0]) > 0] 321 | # hypotheses, references = zip(*hyps_and_refs) 322 | 323 | # Calculate ROUGE-1 F1, precision, recall scores 324 | rouge_1 = [ 325 | rouge_n([hyp], [ref], 1) for hyp, ref in zip(hypotheses, references) 326 | ] 327 | rouge_1_f, rouge_1_p, rouge_1_r = map(np.mean, zip(*rouge_1)) 328 | 329 | # Calculate ROUGE-2 F1, precision, recall scores 330 | rouge_2 = [ 331 | rouge_n([hyp], [ref], 2) for hyp, ref in zip(hypotheses, references) 332 | ] 333 | rouge_2_f, rouge_2_p, rouge_2_r = map(np.mean, zip(*rouge_2)) 334 | 335 | # Calculate ROUGE-L F1, precision, recall scores 336 | rouge_l = [ 337 | rouge_l_sentence_level([hyp], [ref]) 338 | for hyp, ref in zip(hypotheses, references) 339 | ] 340 | rouge_l_f, rouge_l_p, rouge_l_r = map(np.mean, zip(*rouge_l)) 341 | 342 | return { 343 | "rouge_1/f_score": rouge_1_f, 344 | "rouge_1/r_score": rouge_1_r, 345 | "rouge_1/p_score": rouge_1_p, 346 | "rouge_2/f_score": rouge_2_f, 347 | "rouge_2/r_score": rouge_2_r, 348 | "rouge_2/p_score": rouge_2_p, 349 | "rouge_l/f_score": rouge_l_f, 350 | "rouge_l/r_score": rouge_l_r, 351 | "rouge_l/p_score": rouge_l_p, 352 | } 353 | --------------------------------------------------------------------------------
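For orientation, a minimal usage sketch of the module above (hypothetical sentences; assumes the repo root is on PYTHONPATH so the module imports as nmt.scripts.rouge):

```python
from nmt.scripts.rouge import rouge

scores = rouge(hypotheses=["the cat sat on the mat"],
               references=["the cat is on the mat"])
print(scores["rouge_1/f_score"], scores["rouge_2/f_score"], scores["rouge_l/f_score"])
```

Each value is the mean of the per-pair sentence-level scores, keyed exactly as in the dict returned at the end of rouge().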
/texar/utils/raml_samples_generation/process_samples.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from nltk.translate.bleu_score import sentence_bleu 3 | from nltk.translate.bleu_score import SmoothingFunction 4 | import sys 5 | import re 6 | import argparse 7 | import torch 8 | from util import read_corpus 9 | import numpy as np 10 | from scipy.misc import comb 11 | from vocab import Vocab, VocabEntry 12 | import math 13 | from rouge import Rouge 14 | 15 | 16 | def is_valid_sample(sent): 17 | tokens = sent.split(' ') 18 | return len(tokens) >= 1 and len(tokens) < 50 19 | 20 | 21 | def sample_from_model(args): 22 | para_data = args.parallel_data 23 | sample_file = args.sample_file 24 | output = args.output 25 | 26 | tgt_sent_pattern = re.compile(r'^\[(\d+)\] (.*?)$') 27 | para_data = [l.strip().split(' ||| ') for l in open(para_data)] 28 | 29 | f_out = open(output, 'w') 30 | f = open(sample_file) 31 | f.readline() 32 | for src_sent, tgt_sent in para_data: 33 | line = f.readline().strip() 34 | assert line.startswith('****') 35 | line = f.readline().strip() 36 | print(line) 37 | assert line.startswith('target:') 38 | 39 | tgt_sent2 = line[len('target:'):] 40 | assert tgt_sent == tgt_sent2 41 | 42 | line = f.readline().strip() # samples 43 | 44 | tgt_sent = ' '.join(tgt_sent.split(' ')[1:-1]) 45 | tgt_samples = set() 46 | for i in range(1, 101): 47 | line = f.readline().rstrip('\n') 48 | m = tgt_sent_pattern.match(line) 49 | 50 | assert m, line 51 | assert int(m.group(1)) == i 52 | 53 | sampled_tgt_sent = m.group(2).strip() 54 | 55 | if is_valid_sample(sampled_tgt_sent): 56 | tgt_samples.add(sampled_tgt_sent) 57 | 58 | line = f.readline().strip() 59 | assert line.startswith('****') 60 | 61 | tgt_samples.add(tgt_sent) 62 | tgt_samples = list(tgt_samples) 63 | 64 | assert len(tgt_samples) > 0 65 | 66 | tgt_ref_tokens = tgt_sent.split(' ') 67 | bleu_scores = [] 68 | for tgt_sample in tgt_samples: 69 | bleu_score = sentence_bleu([tgt_ref_tokens], tgt_sample.split(' ')) 70 | bleu_scores.append(bleu_score) 71 | 72 | tgt_ranks = sorted(range(len(tgt_samples)), key=lambda i: bleu_scores[i], reverse=True) 73 | 74 | print('%d samples' % len(tgt_samples)) 75 | 76 | print('*' * 50, file=f_out) 77 | print('source: ' + src_sent, file=f_out) 78 | print('%d samples' % len(tgt_samples), file=f_out) 79 | for i in tgt_ranks: 80 | print('%s ||| %f' % (tgt_samples[i], bleu_scores[i]), file=f_out) 81 | print('*' * 50, file=f_out) 82 | 83 | f_out.close() 84 | 85 | 86 | def get_new_ngram(ngram, n, vocab): 87 | """ 88 | Replace ngram `ngram` with a newly sampled ngram of the same length. 89 | """ 90 | 91 | new_ngram_wids = [np.random.randint(3, len(vocab)) for i in range(n)] 92 | new_ngram = [vocab.id2word[wid] for wid in new_ngram_wids] 93 | 94 | return new_ngram 95 | 96 | 97 | def sample_ngram(args): 98 | src_sents = read_corpus(args.src, 'src') 99 | tgt_sents = read_corpus(args.tgt, 'src') # do not read in <s> and </s> 100 | f_out = open(args.output, 'w') 101 | 102 | vocab = torch.load(args.vocab) 103 | tgt_vocab = vocab.tgt 104 | 105 | smooth_bleu = args.smooth_bleu 106 | sm_func = None 107 | if smooth_bleu: 108 | sm_func = SmoothingFunction().method3 109 | 110 | for src_sent, tgt_sent in zip(src_sents, tgt_sents): 111 | src_sent = ' '.join(src_sent) 112 | 113 | tgt_len = len(tgt_sent) 114 | tgt_samples = [] 115 | tgt_samples_distort_rates = [] # how many unigrams are replaced 116 | 117 | # generate 100 samples 118 | 119 | # append itself 120 | tgt_samples.append(tgt_sent) 121 | tgt_samples_distort_rates.append(0) 122 | 123 | for sid in range(args.sample_size - 1): 124 | n = np.random.randint(1, min(tgt_len, args.max_ngram_size + 1)) # we do not replace the last token: it must be a period! 125 | 126 | idx = np.random.randint(tgt_len - n) 127 | ngram = tgt_sent[idx: idx+n] 128 | new_ngram = get_new_ngram(ngram, n, tgt_vocab) 129 | 130 | sampled_tgt_sent = list(tgt_sent) 131 | sampled_tgt_sent[idx: idx+n] = new_ngram 132 | 133 | # compute the probability of this sample 134 | # prob = 1.
/ args.max_ngram_size * 1. / (tgt_len - 1 + n) * 1 / (len(tgt_vocab) ** n) 135 | 136 | tgt_samples.append(sampled_tgt_sent) 137 | tgt_samples_distort_rates.append(n) 138 | 139 | # compute bleu scores or edit distances and rank the samples by bleu scores 140 | rewards = [] 141 | for tgt_sample, tgt_sample_distort_rate in zip(tgt_samples, tgt_samples_distort_rates): 142 | if args.reward == 'bleu': 143 | reward = sentence_bleu([tgt_sent], tgt_sample, smoothing_function=sm_func) 144 | elif args.reward == 'rouge': 145 | rouge = Rouge() 146 | scores = rouge.get_scores(hyps=[' '.join(tgt_sample).decode('utf-8')], refs=[' '.join(tgt_sent).decode('utf-8')], avg=True) 147 | reward = sum([value['f'] for key, value in scores.items()]) 148 | else: 149 | reward = -tgt_sample_distort_rate 150 | 151 | rewards.append(reward) 152 | 153 | tgt_ranks = sorted(range(len(tgt_samples)), key=lambda i: rewards[i], reverse=True) 154 | # convert list of tokens into a string 155 | tgt_samples = [' '.join(tgt_sample) for tgt_sample in tgt_samples] 156 | 157 | print('*' * 50, file=f_out) 158 | print('source: ' + src_sent, file=f_out) 159 | print('%d samples' % len(tgt_samples), file=f_out) 160 | for i in tgt_ranks: 161 | print('%s ||| %f' % (tgt_samples[i], rewards[i]), file=f_out) 162 | print('*' * 50, file=f_out) 163 | 164 | f_out.close() 165 | 166 | 167 | def sample_ngram_adapt(args): 168 | src_sents = read_corpus(args.src, 'src') 169 | tgt_sents = read_corpus(args.tgt, 'src') # do not read in <s> and </s> 170 | f_out = open(args.output, 'w') 171 | 172 | vocab = torch.load(args.vocab) 173 | tgt_vocab = vocab.tgt 174 | 175 | max_len = max([len(tgt_sent) for tgt_sent in tgt_sents]) + 1 176 | 177 | for src_sent, tgt_sent in zip(src_sents, tgt_sents): 178 | src_sent = ' '.join(src_sent) 179 | 180 | tgt_len = len(tgt_sent) 181 | tgt_samples = [] 182 | 183 | # generate 100 samples 184 | 185 | # append itself 186 | tgt_samples.append(tgt_sent) 187 | 188 | for sid in range(args.sample_size - 1): 189 | max_n = min(tgt_len - 1, 4) 190 | bias_n = int(max_n * tgt_len / max_len) + 1 191 | assert 1 <= bias_n <= 4, 'bias_n={}, not in [1,4], max_n={}, tgt_len={}, max_len={}'.format(bias_n, max_n, tgt_len, max_len) 192 | 193 | p = [1.0/(max_n + 5)] * max_n 194 | p[bias_n - 1] = 1 - p[0] * (max_n - 1) 195 | assert abs(sum(p) - 1) < 1e-10, 'sum(p) != 1' 196 | 197 | n = np.random.choice(np.arange(1, int(max_n + 1)), p=p) # we do not replace the last token: it must be a period!
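# (Worked example, hypothetical values: max_n = 4 and bias_n = 2 give
# p = [1/9, 1/9, 1/9, 1/9] with p[1] raised to 1 - 3/9 = 2/3, i.e.
# p = [1/9, 2/3, 1/9, 1/9] -- a proper distribution biased toward bias_n.)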
198 | assert n < tgt_len, 'n={}, tgt_len={}'.format(n, tgt_len) 199 | 200 | idx = np.random.randint(tgt_len - n) 201 | ngram = tgt_sent[idx: idx+n] 202 | new_ngram = get_new_ngram(ngram, n, tgt_vocab) 203 | 204 | sampled_tgt_sent = list(tgt_sent) 205 | sampled_tgt_sent[idx: idx+n] = new_ngram 206 | 207 | tgt_samples.append(sampled_tgt_sent) 208 | 209 | # compute bleu scores and rank the samples by bleu scores 210 | bleu_scores = [] 211 | for tgt_sample in tgt_samples: 212 | bleu_score = sentence_bleu([tgt_sent], tgt_sample) 213 | bleu_scores.append(bleu_score) 214 | 215 | tgt_ranks = sorted(range(len(tgt_samples)), key=lambda i: bleu_scores[i], reverse=True) 216 | # convert list of tokens into a string 217 | tgt_samples = [' '.join(tgt_sample) for tgt_sample in tgt_samples] 218 | 219 | print('*' * 50, file=f_out) 220 | print('source: ' + src_sent, file=f_out) 221 | print('%d samples' % len(tgt_samples), file=f_out) 222 | for i in tgt_ranks: 223 | print('%s ||| %f' % (tgt_samples[i], bleu_scores[i]), file=f_out) 224 | print('*' * 50, file=f_out) 225 | 226 | f_out.close() 227 | 228 | 229 | def sample_from_hamming_distance_payoff_distribution(args): 230 | src_sents = read_corpus(args.src, 'src') 231 | tgt_sents = read_corpus(args.tgt, 'src') # do not read in <s> and </s> 232 | f_out = open(args.output, 'w') 233 | 234 | vocab = torch.load(args.vocab) 235 | tgt_vocab = vocab.tgt 236 | 237 | payoff_prob, Z_qs = generate_hamming_distance_payoff_distribution(max(len(sent) for sent in tgt_sents), 238 | vocab_size=len(vocab.tgt), 239 | tau=args.temp) 240 | 241 | for src_sent, tgt_sent in zip(src_sents, tgt_sents): 242 | tgt_samples = [] # make sure the ground truth y* is in the samples 243 | tgt_sent_len = len(tgt_sent) - 3 # remove <s> and </s> and the ending period . 244 | tgt_ref_tokens = tgt_sent[1:-1] 245 | bleu_scores = [] 246 | 247 | # sample edit distances 248 | e_samples = np.random.choice(range(tgt_sent_len + 1), p=payoff_prob[tgt_sent_len], size=args.sample_size, 249 | replace=True) 250 | 251 | for i, e in enumerate(e_samples): 252 | if e > 0: 253 | # sample a new tgt_sent $y$ 254 | old_word_pos = np.random.choice(range(1, tgt_sent_len + 1), size=e, replace=False) 255 | new_words = [vocab.tgt.id2word[wid] for wid in np.random.randint(3, len(vocab.tgt), size=e)] 256 | new_tgt_sent = list(tgt_sent) 257 | for pos, word in zip(old_word_pos, new_words): 258 | new_tgt_sent[pos] = word 259 | 260 | bleu_score = sentence_bleu([tgt_ref_tokens], new_tgt_sent[1:-1]) 261 | bleu_scores.append(bleu_score) 262 | else: 263 | new_tgt_sent = list(tgt_sent) 264 | bleu_scores.append(1.) 265 | 266 | # print('y: %s' % ' '.join(new_tgt_sent)) 267 | tgt_samples.append(new_tgt_sent) 268 | 269 | 270 | def generate_hamming_distance_payoff_distribution(max_sent_len, vocab_size, tau=1.): 271 | """Compute the q distribution for Hamming Distance (substitution only) as in the RAML paper.""" 272 | probs = dict() 273 | Z_qs = dict() 274 | for sent_len in range(1, max_sent_len + 1): 275 | counts = [1.]
# e = 0, count = 1 276 | for e in range(1, sent_len + 1): 277 | # apply the rescaling trick as in https://gist.github.com/norouzi/8c4d244922fa052fa8ec18d8af52d366 278 | count = comb(sent_len, e) * math.exp(-e / tau) * ((vocab_size - 1) ** (e - e / tau)) 279 | counts.append(count) 280 | 281 | Z_qs[sent_len] = Z_q = sum(counts) 282 | prob = [count / Z_q for count in counts] 283 | probs[sent_len] = prob 284 | 285 | # print('sent_len=%d, %s' % (sent_len, prob)) 286 | 287 | return probs, Z_qs 288 | 289 | 290 | if __name__ == '__main__': 291 | parser = argparse.ArgumentParser() 292 | parser.add_argument('--mode', choices=['sample_from_model', 'sample_ngram_adapt', 'sample_ngram'], required=True) 293 | parser.add_argument('--vocab', type=str) 294 | parser.add_argument('--src', type=str) 295 | parser.add_argument('--tgt', type=str) 296 | parser.add_argument('--parallel_data', type=str) 297 | parser.add_argument('--sample_file', type=str) 298 | parser.add_argument('--output', type=str, required=True) 299 | parser.add_argument('--sample_size', type=int, default=100) 300 | parser.add_argument('--reward', choices=['bleu', 'edit_dist', 'rouge'], default='bleu') 301 | parser.add_argument('--max_ngram_size', type=int, default=4) 302 | parser.add_argument('--temp', type=float, default=0.5) 303 | parser.add_argument('--smooth_bleu', action='store_true', default=False) 304 | 305 | args = parser.parse_args() 306 | 307 | if args.mode == 'sample_ngram': 308 | sample_ngram(args) 309 | elif args.mode == 'sample_from_model': 310 | sample_from_model(args) 311 | elif args.mode == 'sample_ngram_adapt': 312 | sample_ngram_adapt(args) 313 | -------------------------------------------------------------------------------- /nmt/utils/iterator_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for iterator_utils.py""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf 23 | 24 | from tensorflow.python.ops import lookup_ops 25 | 26 | from ..utils import iterator_utils 27 | 28 | 29 | class IteratorUtilsTest(tf.test.TestCase): 30 | 31 | def testGetIterator(self): 32 | tf.set_random_seed(1) 33 | tgt_vocab_table = src_vocab_table = lookup_ops.index_table_from_tensor( 34 | tf.constant(["a", "b", "c", "eos", "sos"])) 35 | src_dataset = tf.data.Dataset.from_tensor_slices( 36 | tf.constant(["f e a g", "c c a", "d", "c a"])) 37 | tgt_dataset = tf.data.Dataset.from_tensor_slices( 38 | tf.constant(["c c", "a b", "", "b c"])) 39 | hparams = tf.contrib.training.HParams( 40 | random_seed=3, 41 | num_buckets=5, 42 | eos="eos", 43 | sos="sos") 44 | batch_size = 2 45 | src_max_len = 3 46 | iterator = iterator_utils.get_iterator( 47 | src_dataset=src_dataset, 48 | tgt_dataset=tgt_dataset, 49 | src_vocab_table=src_vocab_table, 50 | tgt_vocab_table=tgt_vocab_table, 51 | batch_size=batch_size, 52 | sos=hparams.sos, 53 | eos=hparams.eos, 54 | random_seed=hparams.random_seed, 55 | num_buckets=hparams.num_buckets, 56 | src_max_len=src_max_len, 57 | reshuffle_each_iteration=False) 58 | table_initializer = tf.tables_initializer() 59 | source = iterator.source 60 | target_input = iterator.target_input 61 | target_output = iterator.target_output 62 | src_seq_len = iterator.source_sequence_length 63 | tgt_seq_len = iterator.target_sequence_length 64 | self.assertEqual([None, None], source.shape.as_list()) 65 | self.assertEqual([None, None], target_input.shape.as_list()) 66 | self.assertEqual([None, None], target_output.shape.as_list()) 67 | self.assertEqual([None], src_seq_len.shape.as_list()) 68 | self.assertEqual([None], tgt_seq_len.shape.as_list()) 69 | with self.test_session() as sess: 70 | sess.run(table_initializer) 71 | sess.run(iterator.initializer) 72 | 73 | (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = ( 74 | sess.run((source, src_seq_len, target_input, target_output, 75 | tgt_seq_len))) 76 | self.assertAllEqual( 77 | [[-1, -1, 0], # "f" == unknown, "e" == unknown, a 78 | [2, 0, 3]], # c a eos -- eos is padding 79 | source_v) 80 | self.assertAllEqual([3, 2], src_len_v) 81 | self.assertAllEqual( 82 | [[4, 2, 2], # sos c c 83 | [4, 1, 2]], # sos b c 84 | target_input_v) 85 | self.assertAllEqual( 86 | [[2, 2, 3], # c c eos 87 | [1, 2, 3]], # b c eos 88 | target_output_v) 89 | self.assertAllEqual([3, 3], tgt_len_v) 90 | 91 | (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = ( 92 | sess.run((source, src_seq_len, target_input, target_output, 93 | tgt_seq_len))) 94 | self.assertAllEqual( 95 | [[2, 2, 0]], # c c a 96 | source_v) 97 | self.assertAllEqual([3], src_len_v) 98 | self.assertAllEqual( 99 | [[4, 0, 1]], # sos a b 100 | target_input_v) 101 | self.assertAllEqual( 102 | [[0, 1, 3]], # a b eos 103 | target_output_v) 104 | self.assertAllEqual([3], tgt_len_v) 105 | 106 | with self.assertRaisesOpError("End of sequence"): 107 | sess.run(source) 108 | 109 | def testGetIteratorWithShard(self): 110 | tf.set_random_seed(1) 111 | tgt_vocab_table = src_vocab_table = lookup_ops.index_table_from_tensor( 112 | tf.constant(["a", "b", "c", "eos", "sos"])) 113 | src_dataset = tf.data.Dataset.from_tensor_slices( 114 | tf.constant(["c c a", "f e a g", "d", "c a"])) 
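# (Note: with num_shards=2 and shard_index=1 passed to get_iterator below, only
# every second zipped example survives sharding -- here elements 1 and 3, i.e.
# ("f e a g", "c c") and ("c a", "b c") -- which is what the assertions at the
# end of this test expect, assuming Dataset.shard-style sharding inside
# get_iterator.)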
115 | tgt_dataset = tf.data.Dataset.from_tensor_slices( 116 | tf.constant(["a b", "c c", "", "b c"])) 117 | hparams = tf.contrib.training.HParams( 118 | random_seed=3, 119 | num_buckets=5, 120 | eos="eos", 121 | sos="sos") 122 | batch_size = 2 123 | src_max_len = 3 124 | iterator = iterator_utils.get_iterator( 125 | src_dataset=src_dataset, 126 | tgt_dataset=tgt_dataset, 127 | src_vocab_table=src_vocab_table, 128 | tgt_vocab_table=tgt_vocab_table, 129 | batch_size=batch_size, 130 | sos=hparams.sos, 131 | eos=hparams.eos, 132 | random_seed=hparams.random_seed, 133 | num_buckets=hparams.num_buckets, 134 | src_max_len=src_max_len, 135 | num_shards=2, 136 | shard_index=1, 137 | reshuffle_each_iteration=False) 138 | table_initializer = tf.tables_initializer() 139 | source = iterator.source 140 | target_input = iterator.target_input 141 | target_output = iterator.target_output 142 | src_seq_len = iterator.source_sequence_length 143 | tgt_seq_len = iterator.target_sequence_length 144 | self.assertEqual([None, None], source.shape.as_list()) 145 | self.assertEqual([None, None], target_input.shape.as_list()) 146 | self.assertEqual([None, None], target_output.shape.as_list()) 147 | self.assertEqual([None], src_seq_len.shape.as_list()) 148 | self.assertEqual([None], tgt_seq_len.shape.as_list()) 149 | with self.test_session() as sess: 150 | sess.run(table_initializer) 151 | sess.run(iterator.initializer) 152 | 153 | (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = ( 154 | sess.run((source, src_seq_len, target_input, target_output, 155 | tgt_seq_len))) 156 | self.assertAllEqual( 157 | [[-1, -1, 0], # "f" == unknown, "e" == unknown, a 158 | [2, 0, 3]], # c a eos -- eos is padding 159 | source_v) 160 | self.assertAllEqual([3, 2], src_len_v) 161 | self.assertAllEqual( 162 | [[4, 2, 2], # sos c c 163 | [4, 1, 2]], # sos b c 164 | target_input_v) 165 | self.assertAllEqual( 166 | [[2, 2, 3], # c c eos 167 | [1, 2, 3]], # b c eos 168 | target_output_v) 169 | self.assertAllEqual([3, 3], tgt_len_v) 170 | 171 | with self.assertRaisesOpError("End of sequence"): 172 | sess.run(source) 173 | 174 | def testGetIteratorWithSkipCount(self): 175 | tf.set_random_seed(1) 176 | tgt_vocab_table = src_vocab_table = lookup_ops.index_table_from_tensor( 177 | tf.constant(["a", "b", "c", "eos", "sos"])) 178 | src_dataset = tf.data.Dataset.from_tensor_slices( 179 | tf.constant(["c a", "c c a", "d", "f e a g"])) 180 | tgt_dataset = tf.data.Dataset.from_tensor_slices( 181 | tf.constant(["b c", "a b", "", "c c"])) 182 | hparams = tf.contrib.training.HParams( 183 | random_seed=3, 184 | num_buckets=5, 185 | eos="eos", 186 | sos="sos") 187 | batch_size = 2 188 | src_max_len = 3 189 | skip_count = tf.placeholder(shape=(), dtype=tf.int64) 190 | iterator = iterator_utils.get_iterator( 191 | src_dataset=src_dataset, 192 | tgt_dataset=tgt_dataset, 193 | src_vocab_table=src_vocab_table, 194 | tgt_vocab_table=tgt_vocab_table, 195 | batch_size=batch_size, 196 | sos=hparams.sos, 197 | eos=hparams.eos, 198 | random_seed=hparams.random_seed, 199 | num_buckets=hparams.num_buckets, 200 | src_max_len=src_max_len, 201 | skip_count=skip_count, 202 | reshuffle_each_iteration=False) 203 | table_initializer = tf.tables_initializer() 204 | source = iterator.source 205 | target_input = iterator.target_input 206 | target_output = iterator.target_output 207 | src_seq_len = iterator.source_sequence_length 208 | tgt_seq_len = iterator.target_sequence_length 209 | self.assertEqual([None, None], source.shape.as_list()) 210 | 
self.assertEqual([None, None], target_input.shape.as_list()) 211 | self.assertEqual([None, None], target_output.shape.as_list()) 212 | self.assertEqual([None], src_seq_len.shape.as_list()) 213 | self.assertEqual([None], tgt_seq_len.shape.as_list()) 214 | with self.test_session() as sess: 215 | sess.run(table_initializer) 216 | sess.run(iterator.initializer, feed_dict={skip_count: 3}) 217 | 218 | (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = ( 219 | sess.run((source, src_seq_len, target_input, target_output, 220 | tgt_seq_len))) 221 | self.assertAllEqual( 222 | [[-1, -1, 0]], # "f" == unknown, "e" == unknown, a 223 | source_v) 224 | self.assertAllEqual([3], src_len_v) 225 | self.assertAllEqual( 226 | [[4, 2, 2]], # sos c c 227 | target_input_v) 228 | self.assertAllEqual( 229 | [[2, 2, 3]], # c c eos 230 | target_output_v) 231 | self.assertAllEqual([3], tgt_len_v) 232 | 233 | with self.assertRaisesOpError("End of sequence"): 234 | sess.run(source) 235 | 236 | # Re-init iterator with skip_count=0. 237 | sess.run(iterator.initializer, feed_dict={skip_count: 0}) 238 | 239 | (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = ( 240 | sess.run((source, src_seq_len, target_input, target_output, 241 | tgt_seq_len))) 242 | self.assertAllEqual( 243 | [[2, 0, 3], # c a eos -- eos is padding 244 | [-1, -1, 0]], # "f" == unknown, "e" == unknown, a 245 | source_v) 246 | self.assertAllEqual([2, 3], src_len_v) 247 | self.assertAllEqual( 248 | [[4, 1, 2], # sos b c 249 | [4, 2, 2]], # sos c c 250 | target_input_v) 251 | self.assertAllEqual( 252 | [[1, 2, 3], # b c eos 253 | [2, 2, 3]], # c c eos 254 | target_output_v) 255 | self.assertAllEqual([3, 3], tgt_len_v) 256 | 257 | (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = ( 258 | sess.run((source, src_seq_len, target_input, target_output, 259 | tgt_seq_len))) 260 | self.assertAllEqual( 261 | [[2, 2, 0]], # c c a 262 | source_v) 263 | self.assertAllEqual([3], src_len_v) 264 | self.assertAllEqual( 265 | [[4, 0, 1]], # sos a b 266 | target_input_v) 267 | self.assertAllEqual( 268 | [[0, 1, 3]], # a b eos 269 | target_output_v) 270 | self.assertAllEqual([3], tgt_len_v) 271 | 272 | with self.assertRaisesOpError("End of sequence"): 273 | sess.run(source) 274 | 275 | 276 | def testGetInferIterator(self): 277 | src_vocab_table = lookup_ops.index_table_from_tensor( 278 | tf.constant(["a", "b", "c", "eos", "sos"])) 279 | src_dataset = tf.data.Dataset.from_tensor_slices( 280 | tf.constant(["c c a", "c a", "d", "f e a g"])) 281 | hparams = tf.contrib.training.HParams( 282 | random_seed=3, 283 | eos="eos", 284 | sos="sos") 285 | batch_size = 2 286 | src_max_len = 3 287 | iterator = iterator_utils.get_infer_iterator( 288 | src_dataset=src_dataset, 289 | src_vocab_table=src_vocab_table, 290 | batch_size=batch_size, 291 | eos=hparams.eos, 292 | src_max_len=src_max_len) 293 | table_initializer = tf.tables_initializer() 294 | source = iterator.source 295 | seq_len = iterator.source_sequence_length 296 | self.assertEqual([None, None], source.shape.as_list()) 297 | self.assertEqual([None], seq_len.shape.as_list()) 298 | with self.test_session() as sess: 299 | sess.run(table_initializer) 300 | sess.run(iterator.initializer) 301 | 302 | (source_v, seq_len_v) = sess.run((source, seq_len)) 303 | self.assertAllEqual( 304 | [[2, 2, 0], # c c a 305 | [2, 0, 3]], # c a eos 306 | source_v) 307 | self.assertAllEqual([3, 2], seq_len_v) 308 | 309 | (source_v, seq_len_v) = sess.run((source, seq_len)) 310 | self.assertAllEqual( 311 | 
[[-1, 3, 3], # "d" == unknown, eos eos 312 | [-1, -1, 0]], # "f" == unknown, "e" == unknown, a 313 | source_v) 314 | self.assertAllEqual([1, 3], seq_len_v) 315 | 316 | with self.assertRaisesOpError("End of sequence"): 317 | sess.run((source, seq_len)) 318 | 319 | 320 | if __name__ == "__main__": 321 | tf.test.main() 322 | -------------------------------------------------------------------------------- /nmt/model_helper.py: -------------------------------------------------------------------------------- 1 | """Utility functions for building models.""" 2 | from __future__ import print_function 3 | 4 | import collections 5 | import six 6 | import os 7 | import time 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | 12 | from tensorflow.python.ops import lookup_ops 13 | 14 | from .utils import iterator_utils 15 | from .utils import misc_utils as utils 16 | from .utils import vocab_utils 17 | 18 | 19 | __all__ = [ 20 | "get_initializer", "get_device_str", "create_train_model", 21 | "create_eval_model", "create_infer_model", 22 | "create_emb_for_encoder_and_decoder", "create_rnn_cell", "gradient_clip", 23 | "create_or_load_model", "load_model", "avg_checkpoints", 24 | "compute_perplexity" 25 | ] 26 | 27 | # If a vocab size is greater than this value, put the embedding on cpu instead 28 | VOCAB_SIZE_THRESHOLD_CPU = 50000 29 | 30 | 31 | def get_initializer(init_op, seed=None, init_weight=None): 32 | """Create an initializer. init_weight is only for uniform.""" 33 | if init_op == "uniform": 34 | assert init_weight 35 | return tf.random_uniform_initializer( 36 | -init_weight, init_weight, seed=seed) 37 | elif init_op == "glorot_normal": 38 | return tf.keras.initializers.glorot_normal( 39 | seed=seed) 40 | elif init_op == "glorot_uniform": 41 | return tf.keras.initializers.glorot_uniform( 42 | seed=seed) 43 | else: 44 | raise ValueError("Unknown init_op %s" % init_op) 45 | 46 | 47 | def get_device_str(device_id, num_gpus): 48 | """Return a device string for multi-GPU setup.""" 49 | if num_gpus == 0: 50 | return "/cpu:0" 51 | device_str_output = "/gpu:%d" % (device_id % num_gpus) 52 | return device_str_output 53 | 54 | 55 | class ExtraArgs(collections.namedtuple( 56 | "ExtraArgs", ("single_cell_fn", "model_device_fn", 57 | "attention_mechanism_fn"))): 58 | pass 59 | 60 | 61 | class TrainModel( 62 | collections.namedtuple("TrainModel", ("graph", "model", "iterator", 63 | "skip_count_placeholder"))): 64 | pass 65 | 66 | 67 | def create_train_model( 68 | model_creator, hparams, scope=None, num_workers=1, jobid=0, 69 | extra_args=None): 70 | """Create train graph, model, and iterator.""" 71 | src_file = "%s.%s" % (hparams.train_prefix, hparams.src) 72 | tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt) 73 | src_vocab_file = hparams.src_vocab_file 74 | tgt_vocab_file = hparams.tgt_vocab_file 75 | 76 | graph = tf.Graph() 77 | 78 | with graph.as_default(), tf.container(scope or "train"): 79 | src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables( 80 | src_vocab_file, tgt_vocab_file, hparams.share_vocab) 81 | 82 | src_dataset = tf.data.TextLineDataset(src_file) 83 | tgt_dataset = tf.data.TextLineDataset(tgt_file) 84 | skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64) 85 | 86 | iterator = iterator_utils.get_iterator( 87 | src_dataset, 88 | tgt_dataset, 89 | src_vocab_table, 90 | tgt_vocab_table, 91 | batch_size=hparams.batch_size, 92 | sos=hparams.sos, 93 | eos=hparams.eos, 94 | random_seed=hparams.random_seed, 95 | num_buckets=hparams.num_buckets, 96 | 
src_max_len=hparams.src_max_len, 97 | tgt_max_len=hparams.tgt_max_len, 98 | skip_count=skip_count_placeholder, 99 | num_shards=num_workers, 100 | shard_index=jobid) 101 | 102 | # Note: One can set model_device_fn to 103 | # `tf.train.replica_device_setter(ps_tasks)` for distributed training. 104 | model_device_fn = None 105 | if extra_args: model_device_fn = extra_args.model_device_fn 106 | with tf.device(model_device_fn): 107 | model = model_creator( 108 | hparams, 109 | iterator=iterator, 110 | mode=tf.contrib.learn.ModeKeys.TRAIN, 111 | source_vocab_table=src_vocab_table, 112 | target_vocab_table=tgt_vocab_table, 113 | scope=scope, 114 | extra_args=extra_args) 115 | 116 | return TrainModel( 117 | graph=graph, 118 | model=model, 119 | iterator=iterator, 120 | skip_count_placeholder=skip_count_placeholder) 121 | 122 | 123 | class EvalModel( 124 | collections.namedtuple("EvalModel", 125 | ("graph", "model", "src_file_placeholder", 126 | "tgt_file_placeholder", "iterator"))): 127 | pass 128 | 129 | 130 | def create_eval_model(model_creator, hparams, scope=None, extra_args=None): 131 | """Create eval graph, model, src/tgt file holders, and iterator.""" 132 | src_vocab_file = hparams.src_vocab_file 133 | tgt_vocab_file = hparams.tgt_vocab_file 134 | graph = tf.Graph() 135 | 136 | with graph.as_default(), tf.container(scope or "eval"): 137 | src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables( 138 | src_vocab_file, tgt_vocab_file, hparams.share_vocab) 139 | src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string) 140 | tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string) 141 | src_dataset = tf.data.TextLineDataset(src_file_placeholder) 142 | tgt_dataset = tf.data.TextLineDataset(tgt_file_placeholder) 143 | iterator = iterator_utils.get_iterator( 144 | src_dataset, 145 | tgt_dataset, 146 | src_vocab_table, 147 | tgt_vocab_table, 148 | hparams.batch_size, 149 | sos=hparams.sos, 150 | eos=hparams.eos, 151 | random_seed=hparams.random_seed, 152 | num_buckets=hparams.num_buckets, 153 | src_max_len=hparams.src_max_len_infer, 154 | tgt_max_len=hparams.tgt_max_len_infer) 155 | model = model_creator( 156 | hparams, 157 | iterator=iterator, 158 | mode=tf.contrib.learn.ModeKeys.EVAL, 159 | source_vocab_table=src_vocab_table, 160 | target_vocab_table=tgt_vocab_table, 161 | scope=scope, 162 | extra_args=extra_args) 163 | return EvalModel( 164 | graph=graph, 165 | model=model, 166 | src_file_placeholder=src_file_placeholder, 167 | tgt_file_placeholder=tgt_file_placeholder, 168 | iterator=iterator) 169 | 170 | 171 | class InferModel( 172 | collections.namedtuple("InferModel", 173 | ("graph", "model", "src_placeholder", 174 | "batch_size_placeholder", "iterator"))): 175 | pass 176 | 177 | 178 | def create_infer_model(model_creator, hparams, scope=None, extra_args=None): 179 | """Create inference model.""" 180 | graph = tf.Graph() 181 | src_vocab_file = hparams.src_vocab_file 182 | tgt_vocab_file = hparams.tgt_vocab_file 183 | 184 | with graph.as_default(), tf.container(scope or "infer"): 185 | src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables( 186 | src_vocab_file, tgt_vocab_file, hparams.share_vocab) 187 | reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file( 188 | tgt_vocab_file, default_value=vocab_utils.UNK) 189 | 190 | src_placeholder = tf.placeholder(shape=[None], dtype=tf.string) 191 | batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64) 192 | 193 | src_dataset = tf.data.Dataset.from_tensor_slices( 194 | src_placeholder)
195 | iterator = iterator_utils.get_infer_iterator( 196 | src_dataset, 197 | src_vocab_table, 198 | batch_size=batch_size_placeholder, 199 | eos=hparams.eos, 200 | src_max_len=hparams.src_max_len_infer) 201 | model = model_creator( 202 | hparams, 203 | iterator=iterator, 204 | mode=tf.contrib.learn.ModeKeys.INFER, 205 | source_vocab_table=src_vocab_table, 206 | target_vocab_table=tgt_vocab_table, 207 | reverse_target_vocab_table=reverse_tgt_vocab_table, 208 | scope=scope, 209 | extra_args=extra_args) 210 | return InferModel( 211 | graph=graph, 212 | model=model, 213 | src_placeholder=src_placeholder, 214 | batch_size_placeholder=batch_size_placeholder, 215 | iterator=iterator) 216 | 217 | 218 | def _get_embed_device(vocab_size): 219 | """Decide on which device to place an embed matrix given its vocab size.""" 220 | if vocab_size > VOCAB_SIZE_THRESHOLD_CPU: 221 | return "/cpu:0" 222 | else: 223 | return "/gpu:0" 224 | 225 | 226 | def _create_pretrained_emb_from_txt( 227 | vocab_file, embed_file, num_trainable_tokens=3, dtype=tf.float32, 228 | scope=None): 229 | """Load a pretrained embedding from embed_file and return an embedding matrix. 230 | 231 | Args: 232 | embed_file: Path to a GloVe-formatted embedding txt file. 233 | num_trainable_tokens: Make the first n tokens in the vocab file trainable 234 | variables. Default is 3, which covers "<unk>", "<s>" and "</s>". 235 | """ 236 | vocab, _ = vocab_utils.load_vocab(vocab_file) 237 | trainable_tokens = vocab[:num_trainable_tokens] 238 | 239 | utils.print_out("# Using pretrained embedding: %s." % embed_file) 240 | utils.print_out(" with trainable tokens: ") 241 | 242 | emb_dict, emb_size = vocab_utils.load_embed_txt(embed_file) 243 | for token in trainable_tokens: 244 | utils.print_out(" %s" % token) 245 | if token not in emb_dict: 246 | emb_dict[token] = [0.0] * emb_size 247 | 248 | emb_mat = np.array( 249 | [emb_dict[token] for token in vocab], dtype=dtype.as_numpy_dtype()) 250 | emb_mat = tf.constant(emb_mat) 251 | emb_mat_const = tf.slice(emb_mat, [num_trainable_tokens, 0], [-1, -1]) 252 | with tf.variable_scope(scope or "pretrain_embeddings", dtype=dtype) as scope: 253 | with tf.device(_get_embed_device(num_trainable_tokens)): 254 | emb_mat_var = tf.get_variable( 255 | "emb_mat_var", [num_trainable_tokens, emb_size]) 256 | return tf.concat([emb_mat_var, emb_mat_const], 0) 257 | 258 | 259 | def _create_or_load_embed(embed_name, vocab_file, embed_file, 260 | vocab_size, embed_size, dtype): 261 | """Create a new or load an existing embedding matrix.""" 262 | if vocab_file and embed_file: 263 | embedding = _create_pretrained_emb_from_txt(vocab_file, embed_file) 264 | else: 265 | with tf.device(_get_embed_device(vocab_size)): 266 | embedding = tf.get_variable( 267 | embed_name, [vocab_size, embed_size], dtype) 268 | return embedding 269 | 270 | 271 | def create_emb_for_encoder_and_decoder(share_vocab, 272 | src_vocab_size, 273 | tgt_vocab_size, 274 | src_embed_size, 275 | tgt_embed_size, 276 | dtype=tf.float32, 277 | num_partitions=0, 278 | src_vocab_file=None, 279 | tgt_vocab_file=None, 280 | src_embed_file=None, 281 | tgt_embed_file=None, 282 | scope=None): 283 | """Create embedding matrices for both encoder and decoder. 284 | 285 | Args: 286 | share_vocab: A boolean. Whether to share embedding matrix for both 287 | encoder and decoder. 288 | src_vocab_size: An integer. The source vocab size. 289 | tgt_vocab_size: An integer. The target vocab size. 290 | src_embed_size: An integer. The embedding dimension for the encoder's 291 | embedding. 292 | tgt_embed_size: An integer. The embedding dimension for the decoder's 293 | embedding. 294 | dtype: dtype of the embedding matrix. Defaults to float32. 295 | num_partitions: number of partitions used for the embedding vars. 296 | scope: VariableScope for the created subgraph. Defaults to "embeddings". 297 | 298 | Returns: 299 | embedding_encoder: Encoder's embedding matrix. 300 | embedding_decoder: Decoder's embedding matrix. 301 | 302 | Raises: 303 | ValueError: if share_vocab is set but source and target have different 304 | vocab sizes. 305 | """ 306 | 307 | if num_partitions <= 1: 308 | partitioner = None 309 | else: 310 | # Note: num_partitions > 1 is required for distributed training because 311 | # embedding_lookup tries to colocate a single-partition embedding variable 312 | # with its lookup ops, which may cause embedding variables to be placed on 313 | # worker jobs. 314 | partitioner = tf.fixed_size_partitioner(num_partitions) 315 | 316 | if (src_embed_file or tgt_embed_file) and partitioner: 317 | raise ValueError( 318 | "Can't set num_partitions > 1 when using pretrained embedding") 319 | 320 | with tf.variable_scope( 321 | scope or "embeddings", dtype=dtype, partitioner=partitioner) as scope: 322 | # Share embedding 323 | if share_vocab: 324 | if src_vocab_size != tgt_vocab_size: 325 | raise ValueError("Share embedding but different src/tgt vocab sizes" 326 | " %d vs. %d" % (src_vocab_size, tgt_vocab_size)) 327 | assert src_embed_size == tgt_embed_size 328 | utils.print_out("# Use the same embedding for source and target") 329 | vocab_file = src_vocab_file or tgt_vocab_file 330 | embed_file = src_embed_file or tgt_embed_file 331 | 332 | embedding_encoder = _create_or_load_embed( 333 | "embedding_share", vocab_file, embed_file, 334 | src_vocab_size, src_embed_size, dtype) 335 | embedding_decoder = embedding_encoder 336 | else: 337 | with tf.variable_scope("encoder", partitioner=partitioner): 338 | embedding_encoder = _create_or_load_embed( 339 | "embedding_encoder", src_vocab_file, src_embed_file, 340 | src_vocab_size, src_embed_size, dtype) 341 | 342 | with tf.variable_scope("decoder", partitioner=partitioner): 343 | embedding_decoder = _create_or_load_embed( 344 | "embedding_decoder", tgt_vocab_file, tgt_embed_file, 345 | tgt_vocab_size, tgt_embed_size, dtype) 346 | 347 | return embedding_encoder, embedding_decoder 348 | 349 | 350 | def _single_cell(unit_type, num_units, forget_bias, dropout, mode, 351 | residual_connection=False, device_str=None, residual_fn=None): 352 | """Create an instance of a single RNN cell.""" 353 | # dropout (= 1 - keep_prob) is set to 0 during eval and infer 354 | dropout = dropout if mode == tf.contrib.learn.ModeKeys.TRAIN else 0.0 355 | 356 | # Cell Type 357 | if unit_type == "lstm": 358 | utils.print_out(" LSTM, forget_bias=%g" % forget_bias, new_line=False) 359 | single_cell = tf.contrib.rnn.BasicLSTMCell( 360 | num_units, 361 | forget_bias=forget_bias) 362 | elif unit_type == "gru": 363 | utils.print_out(" GRU", new_line=False) 364 | single_cell = tf.contrib.rnn.GRUCell(num_units) 365 | elif unit_type == "layer_norm_lstm": 366 | utils.print_out(" Layer Normalized LSTM, forget_bias=%g" % forget_bias, 367 | new_line=False) 368 | single_cell = tf.contrib.rnn.LayerNormBasicLSTMCell( 369 | num_units, 370 | forget_bias=forget_bias, 371 | layer_norm=True) 372 | elif unit_type == "nas": 373 | utils.print_out(" NASCell", new_line=False) 374 | single_cell = tf.contrib.rnn.NASCell(num_units) 375 | else: 376 | raise ValueError("Unknown unit type %s!"
% unit_type) 377 | 378 | # Dropout (= 1 - keep_prob) 379 | if dropout > 0.0: 380 | single_cell = tf.contrib.rnn.DropoutWrapper( 381 | cell=single_cell, input_keep_prob=(1.0 - dropout)) 382 | utils.print_out(" %s, dropout=%g " %(type(single_cell).__name__, dropout), 383 | new_line=False) 384 | 385 | # Residual 386 | if residual_connection: 387 | single_cell = tf.contrib.rnn.ResidualWrapper( 388 | single_cell, residual_fn=residual_fn) 389 | utils.print_out(" %s" % type(single_cell).__name__, new_line=False) 390 | 391 | # Device Wrapper 392 | if device_str: 393 | single_cell = tf.contrib.rnn.DeviceWrapper(single_cell, device_str) 394 | utils.print_out(" %s, device=%s" % 395 | (type(single_cell).__name__, device_str), new_line=False) 396 | 397 | return single_cell 398 | 399 | 400 | def _cell_list(unit_type, num_units, num_layers, num_residual_layers, 401 | forget_bias, dropout, mode, num_gpus, base_gpu=0, 402 | single_cell_fn=None, residual_fn=None): 403 | """Create a list of RNN cells.""" 404 | if not single_cell_fn: 405 | single_cell_fn = _single_cell 406 | 407 | # Multi-GPU 408 | cell_list = [] 409 | for i in range(num_layers): 410 | utils.print_out(" cell %d" % i, new_line=False) 411 | single_cell = single_cell_fn( 412 | unit_type=unit_type, 413 | num_units=num_units, 414 | forget_bias=forget_bias, 415 | dropout=dropout, 416 | mode=mode, 417 | residual_connection=(i >= num_layers - num_residual_layers), 418 | device_str=get_device_str(i + base_gpu, num_gpus), 419 | residual_fn=residual_fn 420 | ) 421 | utils.print_out("") 422 | cell_list.append(single_cell) 423 | 424 | return cell_list 425 | 426 | 427 | def create_rnn_cell(unit_type, num_units, num_layers, num_residual_layers, 428 | forget_bias, dropout, mode, num_gpus, base_gpu=0, 429 | single_cell_fn=None): 430 | """Create multi-layer RNN cell. 431 | 432 | Args: 433 | unit_type: string representing the unit type, i.e. "lstm". 434 | num_units: the depth of each unit. 435 | num_layers: number of cells. 436 | num_residual_layers: Number of residual layers from top to bottom. For 437 | example, if `num_layers=4` and `num_residual_layers=2`, the last 2 RNN 438 | cells in the returned list will be wrapped with `ResidualWrapper`. 439 | forget_bias: the initial forget bias of the RNNCell(s). 440 | dropout: floating point value between 0.0 and 1.0: 441 | the probability of dropout. this is ignored if `mode != TRAIN`. 442 | mode: either tf.contrib.learn.TRAIN/EVAL/INFER 443 | num_gpus: The number of gpus to use when performing round-robin 444 | placement of layers. 445 | base_gpu: The gpu device id to use for the first RNN cell in the 446 | returned list. The i-th RNN cell will use `(base_gpu + i) % num_gpus` 447 | as its device id. 448 | single_cell_fn: allow for adding customized cell. 449 | When not specified, we default to model_helper._single_cell 450 | Returns: 451 | An `RNNCell` instance. 452 | """ 453 | cell_list = _cell_list(unit_type=unit_type, 454 | num_units=num_units, 455 | num_layers=num_layers, 456 | num_residual_layers=num_residual_layers, 457 | forget_bias=forget_bias, 458 | dropout=dropout, 459 | mode=mode, 460 | num_gpus=num_gpus, 461 | base_gpu=base_gpu, 462 | single_cell_fn=single_cell_fn) 463 | 464 | if len(cell_list) == 1: # Single layer. 
465 | return cell_list[0] 466 | else: # Multiple layers 467 | return tf.contrib.rnn.MultiRNNCell(cell_list) 468 | 469 | 470 | def gradient_clip(gradients, max_gradient_norm): 471 | """Clip gradients of a model.""" 472 | clipped_gradients, gradient_norm = tf.clip_by_global_norm( 473 | gradients, max_gradient_norm) 474 | gradient_norm_summary = [tf.summary.scalar("grad_norm", gradient_norm)] 475 | gradient_norm_summary.append( 476 | tf.summary.scalar("clipped_gradient", tf.global_norm(clipped_gradients))) 477 | 478 | return clipped_gradients, gradient_norm_summary, gradient_norm 479 | 480 | 481 | def load_model(model, ckpt, session, name): 482 | start_time = time.time() 483 | model.saver.restore(session, ckpt) 484 | session.run(tf.tables_initializer()) 485 | utils.print_out( 486 | " loaded %s model parameters from %s, time %.2fs" % 487 | (name, ckpt, time.time() - start_time)) 488 | return model 489 | 490 | 491 | def avg_checkpoints(model_dir, num_last_checkpoints, global_step, 492 | global_step_name): 493 | """Average the last N checkpoints in the model_dir.""" 494 | checkpoint_state = tf.train.get_checkpoint_state(model_dir) 495 | if not checkpoint_state: 496 | utils.print_out("# No checkpoint file found in directory: %s" % model_dir) 497 | return None 498 | 499 | # Checkpoints are ordered from oldest to newest. 500 | checkpoints = ( 501 | checkpoint_state.all_model_checkpoint_paths[-num_last_checkpoints:]) 502 | 503 | if len(checkpoints) < num_last_checkpoints: 504 | utils.print_out( 505 | "# Skipping checkpoint averaging because not enough checkpoints are " 506 | "available." 507 | ) 508 | return None 509 | 510 | avg_model_dir = os.path.join(model_dir, "avg_checkpoints") 511 | if not tf.gfile.Exists(avg_model_dir): 512 | utils.print_out( 513 | "# Creating new directory %s for saving averaged checkpoints." % 514 | avg_model_dir) 515 | tf.gfile.MakeDirs(avg_model_dir) 516 | 517 | utils.print_out("# Reading and averaging variables in checkpoints:") 518 | var_list = tf.contrib.framework.list_variables(checkpoints[0]) 519 | var_values, var_dtypes = {}, {} 520 | for (name, shape) in var_list: 521 | if name != global_step_name: 522 | var_values[name] = np.zeros(shape) 523 | 524 | for checkpoint in checkpoints: 525 | utils.print_out(" %s" % checkpoint) 526 | reader = tf.contrib.framework.load_checkpoint(checkpoint) 527 | for name in var_values: 528 | tensor = reader.get_tensor(name) 529 | var_dtypes[name] = tensor.dtype 530 | var_values[name] += tensor 531 | 532 | for name in var_values: 533 | var_values[name] /= len(checkpoints) 534 | 535 | # Build a graph with the same variables as in the checkpoints, and save the 536 | # averaged variables into the avg_model_dir. 537 | with tf.Graph().as_default(): 538 | tf_vars = [ 539 | tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[v]) 540 | for v in var_values 541 | ] 542 | 543 | placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars] 544 | assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)] 545 | global_step_var = tf.Variable( 546 | global_step, name=global_step_name, trainable=False) 547 | saver = tf.train.Saver(tf.global_variables()) 548 | 549 | with tf.Session() as sess: 550 | sess.run(tf.global_variables_initializer()) 551 | for p, assign_op, (name, value) in zip(placeholders, assign_ops, 552 | six.iteritems(var_values)): 553 | sess.run(assign_op, {p: value}) 554 | 555 | # Use the built saver to save the averaged checkpoint. Only keep one 556 | # checkpoint; the best checkpoint will be moved to avg_best_metric_dir. 557 | saver.save( 558 | sess, 559 | os.path.join(avg_model_dir, "translate.ckpt")) 560 | 561 | return avg_model_dir 562 | 563 | 564 | def create_or_load_model(model, model_dir, session, name): 565 | """Create translation model and initialize or load parameters in session.""" 566 | latest_ckpt = tf.train.latest_checkpoint(model_dir) 567 | if latest_ckpt: 568 | model = load_model(model, latest_ckpt, session, name) 569 | else: 570 | start_time = time.time() 571 | session.run(tf.global_variables_initializer()) 572 | session.run(tf.tables_initializer()) 573 | utils.print_out(" created %s model with fresh parameters, time %.2fs" % 574 | (name, time.time() - start_time)) 575 | 576 | global_step = model.global_step.eval(session=session) 577 | return model, global_step 578 | 579 | 580 | def compute_perplexity(model, sess, name): 581 | """Compute perplexity of the output of the model. 582 | 583 | Args: 584 | model: model used to compute perplexity. 585 | sess: tensorflow session to use. 586 | name: label for the eval set, used in logging. 587 | 588 | Returns: 589 | The perplexity of the eval outputs. 590 | """ 591 | total_loss = 0 592 | total_predict_count = 0 593 | start_time = time.time() 594 | 595 | while True: 596 | try: 597 | loss, predict_count, batch_size = model.eval(sess) 598 | total_loss += loss * batch_size 599 | total_predict_count += predict_count 600 | except tf.errors.OutOfRangeError: 601 | break 602 | 603 | perplexity = utils.safe_exp(total_loss / total_predict_count) 604 | utils.print_time(" eval %s: perplexity %.2f" % (name, perplexity), 605 | start_time) 606 | return perplexity 607 | --------------------------------------------------------------------------------