├── nmt
│   ├── __init__.py
│   ├── scripts
│   │   ├── __init__.py
│   │   ├── download_iwslt15.sh
│   │   ├── bleu.py
│   │   ├── wmt16_en_de.sh
│   │   └── rouge.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── misc_utils_test.py
│   │   ├── vocab_utils_test.py
│   │   ├── evaluation_utils_test.py
│   │   ├── standard_hparams_utils.py
│   │   ├── nmt_utils.py
│   │   ├── vocab_utils.py
│   │   ├── common_test_utils.py
│   │   ├── misc_utils.py
│   │   ├── evaluation_utils.py
│   │   ├── iterator_utils.py
│   │   └── iterator_utils_test.py
│   ├── .gitignore
│   ├── standard_hparams
│   │   ├── iwslt15.json
│   │   ├── wmt16.json
│   │   ├── mine.json
│   │   ├── wmt16_gnmt_4_layer.json
│   │   └── wmt16_gnmt_8_layer.json
│   ├── OT.py
│   ├── nmt_test.py
│   ├── attention_model.py
│   ├── inference_test.py
│   ├── inference.py
│   ├── gnmt_model.py
│   └── model_helper.py
├── texar
│   ├── configs
│   │   ├── __init__.py
│   │   ├── config_model.py
│   │   ├── config_giga.py
│   │   └── config_iwslt14.py
│   ├── requirements.txt
│   ├── utils
│   │   ├── raml_samples_generation
│   │   │   ├── README.md
│   │   │   ├── gen_samples_giga.sh
│   │   │   ├── gen_samples_iwslt14.sh
│   │   │   ├── util.py
│   │   │   ├── vocab.py
│   │   │   └── process_samples.py
│   │   └── prepare_data.py
│   ├── README.md
│   ├── OT.py
│   ├── baseline_seq2seq_attn_main.py
│   └── baseline_seq2seq_attn_ot.py
├── .gitattributes
├── image
│   ├── alg.JPG
│   ├── draw2.jpg
│   └── model.JPG
├── LICENSE
├── .gitignore
└── README.md
/nmt/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/nmt/scripts/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/nmt/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/texar/configs/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/texar/requirements.txt:
--------------------------------------------------------------------------------
1 | rouge==0.2.1
--------------------------------------------------------------------------------
/nmt/.gitignore:
--------------------------------------------------------------------------------
1 | bazel-bin
2 | bazel-genfiles
3 | bazel-out
4 | bazel-testlogs
5 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/image/alg.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiqunChen0606/OT-Seq2Seq/HEAD/image/alg.JPG
--------------------------------------------------------------------------------
/image/draw2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiqunChen0606/OT-Seq2Seq/HEAD/image/draw2.jpg
--------------------------------------------------------------------------------
/image/model.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiqunChen0606/OT-Seq2Seq/HEAD/image/model.JPG
--------------------------------------------------------------------------------
/texar/utils/raml_samples_generation/README.md:
--------------------------------------------------------------------------------
1 | ## Augmented Data Generation for RAML Algorithm
2 |
3 | The code here is mainly copied from [pcyin's github](https://github.com/pcyin/pytorch_nmt), with slight changes to support ```rouge``` as the reward. Note that the downloadable datasets also include pre-generated samples.
4 | 
5 | You may tune the hyperparameters in ```gen_samples_giga.sh``` or ```gen_samples_iwslt14.sh``` and run commands like ```bash gen_samples_giga.sh``` to start generation.
6 |
--------------------------------------------------------------------------------
/texar/utils/raml_samples_generation/gen_samples_giga.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | train_src="../../data/giga/train.article"
4 | train_tgt="../../data/giga/train.title"
5 |
6 | python vocab.py \
7 | --src_vocab_size 30424 \
8 | --tgt_vocab_size 23738 \
9 | --train_src ${train_src} \
10 | --train_tgt ${train_tgt} \
11 | --include_singleton \
12 | --output giga_vocab.bin
13 |
14 | python process_samples.py \
15 | --mode sample_ngram \
16 | --vocab giga_vocab.bin \
17 | --src ${train_src} \
18 | --tgt ${train_tgt} \
19 | --sample_size 10 \
20 | --reward rouge \
21 | --output samples_giga.txt
22 |
--------------------------------------------------------------------------------
/texar/utils/raml_samples_generation/gen_samples_iwslt14.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | train_src="../../data/iwslt14/train.de"
4 | train_tgt="../../data/iwslt14/train.en"
5 |
6 | python vocab.py \
7 | --src_vocab_size 32007 \
8 | --tgt_vocab_size 22820 \
9 | --train_src ${train_src} \
10 | --train_tgt ${train_tgt} \
11 | --include_singleton \
12 | --output iwslt14_vocab.bin
13 |
14 | python process_samples.py \
15 | --mode sample_ngram \
16 | --vocab iwslt14_vocab.bin \
17 | --src ${train_src} \
18 | --tgt ${train_tgt} \
19 | --sample_size 10 \
20 | --reward bleu \
21 | --output samples_iwslt14.txt
22 |
--------------------------------------------------------------------------------
/nmt/standard_hparams/iwslt15.json:
--------------------------------------------------------------------------------
1 | {
2 | "attention": "scaled_luong",
3 | "attention_architecture": "standard",
4 | "batch_size": 128,
5 | "colocate_gradients_with_ops": true,
6 | "dropout": 0.2,
7 | "encoder_type": "bi",
8 | "eos": "",
9 | "forget_bias": 1.0,
10 | "infer_batch_size": 32,
11 | "init_weight": 0.1,
12 | "learning_rate": 1.0,
13 | "max_gradient_norm": 5.0,
14 | "metrics": ["bleu"],
15 | "num_buckets": 5,
16 | "num_layers": 2,
17 | "num_train_steps": 12000,
18 | "decay_scheme": "luong234",
19 | "num_units": 512,
20 | "optimizer": "sgd",
21 | "residual": false,
22 | "share_vocab": false,
23 | "subword_option": "",
24 | "sos": "",
25 | "src_max_len": 50,
26 | "src_max_len_infer": null,
27 | "steps_per_external_eval": null,
28 | "steps_per_stats": 100,
29 | "tgt_max_len": 50,
30 | "tgt_max_len_infer": null,
31 | "time_major": true,
32 | "unit_type": "lstm",
33 | "beam_width": 10
34 | }
35 |
--------------------------------------------------------------------------------
/nmt/standard_hparams/wmt16.json:
--------------------------------------------------------------------------------
1 | {
2 | "attention": "normed_bahdanau",
3 | "attention_architecture": "standard",
4 | "batch_size": 128,
5 | "colocate_gradients_with_ops": true,
6 | "dropout": 0.2,
7 | "encoder_type": "bi",
8 | "eos": "",
9 | "forget_bias": 1.0,
10 | "infer_batch_size": 32,
11 | "init_weight": 0.1,
12 | "learning_rate": 1.0,
13 | "max_gradient_norm": 5.0,
14 | "metrics": ["bleu"],
15 | "num_buckets": 5,
16 | "num_layers": 4,
17 | "num_train_steps": 340000,
18 | "decay_scheme": "luong10",
19 | "num_units": 1024,
20 | "optimizer": "sgd",
21 | "residual": false,
22 | "share_vocab": false,
23 | "subword_option": "bpe",
24 | "sos": "",
25 | "src_max_len": 50,
26 | "src_max_len_infer": null,
27 | "steps_per_external_eval": null,
28 | "steps_per_stats": 100,
29 | "tgt_max_len": 50,
30 | "tgt_max_len_infer": null,
31 | "time_major": true,
32 | "unit_type": "lstm",
33 | "beam_width": 10
34 | }
35 |
--------------------------------------------------------------------------------
/texar/configs/config_model.py:
--------------------------------------------------------------------------------
1 | num_units = 256
2 | beam_width = 5
3 | decoder_layers = 1
4 | dropout = 0.2
5 |
6 | embedder = {
7 | 'dim': num_units
8 | }
9 | encoder = {
10 | 'rnn_cell_fw': {
11 | 'kwargs': {
12 | 'num_units': num_units
13 | },
14 | 'dropout': {
15 | 'input_keep_prob': 1. - dropout
16 | }
17 | }
18 | }
19 | decoder = {
20 | 'rnn_cell': {
21 | 'kwargs': {
22 | 'num_units': num_units
23 | },
24 | 'dropout': {
25 | 'input_keep_prob': 1. - dropout
26 | },
27 | 'num_layers': decoder_layers
28 | },
29 | 'attention': {
30 | 'kwargs': {
31 | 'num_units': num_units,
32 | },
33 | 'attention_layer_size': num_units
34 | }
35 | }
36 | opt = {
37 | 'optimizer': {
38 | 'type': 'AdamOptimizer',
39 | 'kwargs': {
40 | 'learning_rate': 0.001,
41 | },
42 | },
43 | }
--------------------------------------------------------------------------------
/nmt/standard_hparams/mine.json:
--------------------------------------------------------------------------------
1 | {
2 | "attention": "scaled_luong",
3 | "attention_architecture": "standard",
4 | "batch_size": 128,
5 | "colocate_gradients_with_ops": true,
6 | "dropout": 0.2,
7 | "encoder_type": "bi",
8 | "eos": "",
9 | "forget_bias": 1.0,
10 | "infer_batch_size": 32,
11 | "init_weight": 0.1,
12 | "learning_rate": 1.0,
13 | "max_gradient_norm": 5.0,
14 | "metrics": ["bleu"],
15 | "num_buckets": 5,
16 | "num_layers": 2,
17 | "num_train_steps": 340000,
18 | "decay_scheme": "luong10",
19 | "num_units": 1024,
20 | "optimizer": "sgd",
21 | "residual": true,
22 | "share_vocab": false,
23 | "subword_option": "bpe",
24 | "sos": "",
25 | "src_max_len": 50,
26 | "src_max_len_infer": null,
27 | "steps_per_external_eval": null,
28 | "steps_per_stats": 100,
29 | "tgt_max_len": 50,
30 | "tgt_max_len_infer": null,
31 | "time_major": true,
32 | "unit_type": "lstm",
33 | "beam_width": 10,
34 | "length_penalty_weight": 1.0
35 | }
36 |
--------------------------------------------------------------------------------
/nmt/standard_hparams/wmt16_gnmt_4_layer.json:
--------------------------------------------------------------------------------
1 | {
2 | "attention": "normed_bahdanau",
3 | "attention_architecture": "gnmt_v2",
4 | "batch_size": 128,
5 | "colocate_gradients_with_ops": true,
6 | "dropout": 0.2,
7 | "encoder_type": "gnmt",
8 | "eos": "",
9 | "forget_bias": 1.0,
10 | "infer_batch_size": 32,
11 | "init_weight": 0.1,
12 | "learning_rate": 1.0,
13 | "max_gradient_norm": 5.0,
14 | "metrics": ["bleu"],
15 | "num_buckets": 5,
16 | "num_layers": 4,
17 | "num_train_steps": 340000,
18 | "decay_scheme": "luong10",
19 | "num_units": 1024,
20 | "optimizer": "sgd",
21 | "residual": true,
22 | "share_vocab": false,
23 | "subword_option": "bpe",
24 | "sos": "",
25 | "src_max_len": 50,
26 | "src_max_len_infer": null,
27 | "steps_per_external_eval": null,
28 | "steps_per_stats": 100,
29 | "tgt_max_len": 50,
30 | "tgt_max_len_infer": null,
31 | "time_major": true,
32 | "unit_type": "lstm",
33 | "beam_width": 10,
34 | "length_penalty_weight": 1.0
35 | }
--------------------------------------------------------------------------------
/nmt/standard_hparams/wmt16_gnmt_8_layer.json:
--------------------------------------------------------------------------------
1 | {
2 | "attention": "normed_bahdanau",
3 | "attention_architecture": "gnmt_v2",
4 | "batch_size": 128,
5 | "colocate_gradients_with_ops": true,
6 | "dropout": 0.2,
7 | "encoder_type": "gnmt",
8 | "eos": "",
9 | "forget_bias": 1.0,
10 | "infer_batch_size": 32,
11 | "init_weight": 0.1,
12 | "learning_rate": 1.0,
13 | "max_gradient_norm": 5.0,
14 | "metrics": ["bleu"],
15 | "num_buckets": 5,
16 | "num_layers": 8,
17 | "num_train_steps": 340000,
18 | "decay_scheme": "luong10",
19 | "num_units": 1024,
20 | "optimizer": "sgd",
21 | "residual": true,
22 | "share_vocab": false,
23 | "subword_option": "bpe",
24 | "sos": "",
25 | "source_reverse": false,
26 | "src_max_len": 50,
27 | "src_max_len_infer": null,
28 | "steps_per_external_eval": null,
29 | "steps_per_stats": 100,
30 | "tgt_max_len": 50,
31 | "tgt_max_len_infer": null,
32 | "time_major": true,
33 | "unit_type": "lstm",
34 | "beam_width": 10,
35 | "length_penalty_weight": 1.0
36 | }
37 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 LiqunChen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/nmt/scripts/download_iwslt15.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | # Download small-scale IWSLT15 Vietnamese to English translation data for NMT
3 | # model training.
4 | #
5 | # Usage:
6 | # ./download_iwslt15.sh path-to-output-dir
7 | #
8 | # If output directory is not specified, "./iwslt15" will be used as the default
9 | # output directory.
10 | OUT_DIR="${1:-iwslt15}"
11 | SITE_PREFIX="https://nlp.stanford.edu/projects/nmt/data"
12 |
13 | mkdir -v -p $OUT_DIR
14 |
15 | # Download the small IWSLT15 dataset from the Stanford NLP website.
16 | echo "Download training dataset train.en and train.vi."
17 | curl -o "$OUT_DIR/train.en" "$SITE_PREFIX/iwslt15.en-vi/train.en"
18 | curl -o "$OUT_DIR/train.vi" "$SITE_PREFIX/iwslt15.en-vi/train.vi"
19 |
20 | echo "Download dev dataset tst2012.en and tst2012.vi."
21 | curl -o "$OUT_DIR/tst2012.en" "$SITE_PREFIX/iwslt15.en-vi/tst2012.en"
22 | curl -o "$OUT_DIR/tst2012.vi" "$SITE_PREFIX/iwslt15.en-vi/tst2012.vi"
23 |
24 | echo "Download test dataset tst2013.en and tst2013.vi."
25 | curl -o "$OUT_DIR/tst2013.en" "$SITE_PREFIX/iwslt15.en-vi/tst2013.en"
26 | curl -o "$OUT_DIR/tst2013.vi" "$SITE_PREFIX/iwslt15.en-vi/tst2013.vi"
27 |
28 | echo "Download vocab file vocab.en and vocab.vi."
29 | curl -o "$OUT_DIR/vocab.en" "$SITE_PREFIX/iwslt15.en-vi/vocab.en"
30 | curl -o "$OUT_DIR/vocab.vi" "$SITE_PREFIX/iwslt15.en-vi/vocab.vi"
31 |
--------------------------------------------------------------------------------
/texar/configs/config_giga.py:
--------------------------------------------------------------------------------
1 | num_epochs = 20
2 | observe_steps = 500
3 |
4 | eval_metric = 'rouge'
5 |
6 | batch_size = 64
7 | source_vocab_file = './data/giga/vocab.article'
8 | target_vocab_file = './data/giga/vocab.title'
9 |
10 | train = {
11 | 'batch_size': batch_size,
12 | 'allow_smaller_final_batch': False,
13 | 'source_dataset': {
14 | "files": 'data/giga/train.article',
15 | 'vocab_file': source_vocab_file
16 | },
17 | 'target_dataset': {
18 | 'files': 'data/giga/train.title',
19 | 'vocab_file': target_vocab_file
20 | }
21 | }
22 | val = {
23 | 'batch_size': batch_size,
24 | 'shuffle': False,
25 | 'allow_smaller_final_batch': True,
26 | 'source_dataset': {
27 | "files": 'data/giga/valid.article',
28 | 'vocab_file': source_vocab_file,
29 | },
30 | 'target_dataset': {
31 | 'files': 'data/giga/valid.title',
32 | 'vocab_file': target_vocab_file,
33 | }
34 | }
35 | test = {
36 | 'batch_size': batch_size,
37 | 'shuffle': False,
38 | 'allow_smaller_final_batch': True,
39 | 'source_dataset': {
40 | "files": 'data/giga/test.article',
41 | 'vocab_file': source_vocab_file,
42 | },
43 | 'target_dataset': {
44 | 'files': 'data/giga/test.title',
45 | 'vocab_file': target_vocab_file,
46 | }
47 | }
48 |
--------------------------------------------------------------------------------
/nmt/OT.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | 
4 | def cost_matrix(x, y):
5 |     """Returns the pairwise cosine-distance matrix between rows of x and y."""
6 |     # NOTE: cosine distance is chosen as the transport cost.
7 |     x = tf.nn.l2_normalize(x, 1, epsilon=1e-12)
8 |     y = tf.nn.l2_normalize(y, 1, epsilon=1e-12)
9 |     cos_dis = 1 - tf.matmul(x, y, transpose_b=True)
10 |     return cos_dis
11 | 
12 | 
13 | def IPOT(C, n, m, beta=0.5):
14 |     """Inexact proximal point method for OT (IPOT): returns the transport plan T."""
15 |     sigma = tf.ones([m, 1]) / tf.cast(m, tf.float32)  # uniform column marginal
16 |     T = tf.ones([n, m])
17 |     A = tf.exp(-C / beta)  # proximal kernel
18 |     for t in range(50):  # outer proximal-point iterations
19 |         Q = tf.multiply(A, T)
20 |         for k in range(1):  # inner Sinkhorn-style updates
21 |             delta = 1 / (tf.cast(n, tf.float32) * tf.matmul(Q, sigma))
22 |             sigma = 1 / (
23 |                 tf.cast(m, tf.float32) * tf.matmul(Q, delta, transpose_a=True))
24 |         T = tf.matmul(tf.matmul(tf.diag(tf.squeeze(delta)), Q),
25 |                       tf.diag(tf.squeeze(sigma)))
26 |     return T
27 | 
28 | 
29 | def IPOT_distance(C, n, m):
30 |     """OT distance <C, T> induced by the IPOT transport plan."""
31 |     T = IPOT(C, n, m)
32 |     distance = tf.trace(tf.matmul(C, T, transpose_a=True))
33 |     return distance
34 | 
--------------------------------------------------------------------------------
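A minimal usage sketch for the module above (TF 1.x; the shapes, random inputs, and flat import are illustrative assumptions, not the repository's actual call sites):

```
# Hypothetical driver for nmt/OT.py: compute the IPOT transport distance
# between two small random embedding matrices.
import numpy as np
import tensorflow as tf

from OT import cost_matrix, IPOT_distance

n, m, d = 6, 8, 16                      # sequence lengths and embedding dim
x = tf.placeholder(tf.float32, [n, d])  # e.g. generated-token embeddings
y = tf.placeholder(tf.float32, [m, d])  # e.g. reference-token embeddings

C = cost_matrix(x, y)                   # [n, m] cosine-cost matrix
dist = IPOT_distance(C, n, m)           # scalar <C, T*>

with tf.Session() as sess:
    print(sess.run(dist, feed_dict={x: np.random.randn(n, d),
                                    y: np.random.randn(m, d)}))
```
--------------------------------------------------------------------------------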
/texar/configs/config_iwslt14.py:
--------------------------------------------------------------------------------
1 | num_epochs = 50 # the best epoch occurs within 10 epochs in most cases
2 | observe_steps = 500
3 |
4 | eval_metric = 'bleu'
5 |
6 | batch_size = 64
7 | source_vocab_file = './data/iwslt14/vocab.de'
8 | target_vocab_file = './data/iwslt14/vocab.en'
9 |
10 | train = {
11 | 'batch_size': batch_size,
12 | 'shuffle': True,
13 | 'allow_smaller_final_batch': False,
14 | 'source_dataset': {
15 | "files": 'data/iwslt14/train.de',
16 | 'vocab_file': source_vocab_file,
17 | 'max_seq_length': 50
18 | },
19 | 'target_dataset': {
20 | 'files': 'data/iwslt14/train.en',
21 | 'vocab_file': target_vocab_file,
22 | 'max_seq_length': 50
23 | }
24 | }
25 | val = {
26 | 'batch_size': batch_size,
27 | 'shuffle': False,
28 | 'allow_smaller_final_batch': True,
29 | 'source_dataset': {
30 | "files": 'data/iwslt14/valid.de',
31 | 'vocab_file': source_vocab_file,
32 | },
33 | 'target_dataset': {
34 | 'files': 'data/iwslt14/valid.en',
35 | 'vocab_file': target_vocab_file,
36 | }
37 | }
38 | test = {
39 | 'batch_size': batch_size,
40 | 'shuffle': False,
41 | 'allow_smaller_final_batch': True,
42 | 'source_dataset': {
43 | "files": 'data/iwslt14/test.de',
44 | 'vocab_file': source_vocab_file,
45 | },
46 | 'target_dataset': {
47 | 'files': 'data/iwslt14/test.en',
48 | 'vocab_file': target_vocab_file,
49 | }
50 | }
51 |
--------------------------------------------------------------------------------
/nmt/utils/misc_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Tests for vocab_utils."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import tensorflow as tf
23 |
24 | from ..utils import misc_utils
25 |
26 |
27 | class MiscUtilsTest(tf.test.TestCase):
28 |
29 | def testFormatBpeText(self):
30 | bpe_line = (
31 | b"En@@ ough to make already reluc@@ tant men hesitate to take screening"
32 | b" tests ."
33 | )
34 | expected_result = (
35 | b"Enough to make already reluctant men hesitate to take screening tests"
36 | b" ."
37 | )
38 | self.assertEqual(expected_result,
39 | misc_utils.format_bpe_text(bpe_line.split(b" ")))
40 |
41 | def testFormatSPMText(self):
42 | spm_line = u"\u2581This \u2581is \u2581a \u2581 te st .".encode("utf-8")
43 | expected_result = "This is a test."
44 | self.assertEqual(expected_result,
45 | misc_utils.format_spm_text(spm_line.split(b" ")))
46 |
47 |
48 | if __name__ == "__main__":
49 | tf.test.main()
50 |
--------------------------------------------------------------------------------
/texar/utils/prepare_data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The Texar Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Downloads data.
15 | """
16 | import tensorflow as tf
17 | import texar as tx
18 |
19 | # pylint: disable=invalid-name
20 |
21 | flags = tf.flags
22 |
23 | flags.DEFINE_string("data", "iwslt14", "Data to download [iwslt14|toy_copy]")
24 |
25 | FLAGS = flags.FLAGS
26 |
27 |
28 | def prepare_data():
29 | """Downloads data.
30 | """
31 | if FLAGS.data == 'giga':
32 | tx.data.maybe_download(
33 | urls='https://drive.google.com/file/d/'
34 | '12RZs7QFwjj6dfuYNQ_0Ah-ccH1xFDMD5/view?usp=sharing',
35 | path='./',
36 | filenames='giga.zip',
37 | extract=True)
38 | elif FLAGS.data == 'iwslt14':
39 | tx.data.maybe_download(
40 | urls='https://drive.google.com/file/d/'
41 | '1y4mUWXRS2KstgHopCS9koZ42ENOh6Yb9/view?usp=sharing',
42 | path='./',
43 | filenames='iwslt14.zip',
44 | extract=True)
45 | else:
46 | raise ValueError('Unknown data: {}'.format(FLAGS.data))
47 |
48 |
49 | def main():
50 | """Entrypoint.
51 | """
52 | prepare_data()
53 |
54 |
55 | if __name__ == '__main__':
56 | main()
--------------------------------------------------------------------------------
/texar/utils/raml_samples_generation/util.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import numpy as np
3 |
4 | def read_corpus(file_path, source):
5 |     data = []
6 |     for line in open(file_path):
7 |         sent = line.strip().split(' ')
8 |         # only append <s> and </s> to the target sentence
9 |         if source == 'tgt':
10 |             sent = ['<s>'] + sent + ['</s>']
11 |         data.append(sent)
12 | 
13 |     return data
14 | 
15 | 
16 | def batch_slice(data, batch_size, sort=True):
17 |     batch_num = int(np.ceil(len(data) / float(batch_size)))
18 |     for i in range(batch_num):
19 |         cur_batch_size = batch_size if i < batch_num - 1 else len(data) - batch_size * i
20 |         src_sents = [data[i * batch_size + b][0] for b in range(cur_batch_size)]
21 |         tgt_sents = [data[i * batch_size + b][1] for b in range(cur_batch_size)]
22 | 
23 |         if sort:
24 |             src_ids = sorted(range(cur_batch_size), key=lambda src_id: len(src_sents[src_id]), reverse=True)
25 |             src_sents = [src_sents[src_id] for src_id in src_ids]
26 |             tgt_sents = [tgt_sents[src_id] for src_id in src_ids]
27 | 
28 |         yield src_sents, tgt_sents
29 | 
30 | 
31 | def data_iter(data, batch_size, shuffle=True):
32 |     """
33 |     Randomly permute the data, bucket it by source length, and partition it
34 |     into batches in which source-sentence length is decreasing.
35 |     """
36 |     buckets = defaultdict(list)
37 |     for pair in data:
38 |         src_sent = pair[0]
39 |         buckets[len(src_sent)].append(pair)
40 | 
41 |     batched_data = []
42 |     for src_len in buckets:
43 |         tuples = buckets[src_len]
44 |         if shuffle:
45 |             np.random.shuffle(tuples)
46 |         batched_data.extend(list(batch_slice(tuples, batch_size)))
47 | 
48 |     if shuffle:
49 |         np.random.shuffle(batched_data)
50 |     for batch in batched_data:
51 |         yield batch
52 | 
--------------------------------------------------------------------------------
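An illustrative driver for `data_iter` above (the toy sentence pairs are invented for demonstration): batches are bucketed by source length, shuffled, and each batch is sorted by decreasing source length.

```
# Toy demonstration of util.data_iter (hypothetical data).
from util import data_iter

src_sents = [['ich', 'mag', 'katzen'], ['hallo'], ['wie', 'geht', 'es', 'dir']]
tgt_sents = [['<s>', 'i', 'like', 'cats', '</s>'],
             ['<s>', 'hello', '</s>'],
             ['<s>', 'how', 'are', 'you', '</s>']]
data = list(zip(src_sents, tgt_sents))

for src_batch, tgt_batch in data_iter(data, batch_size=2):
    # within a batch, source sentences come in decreasing length
    print([len(s) for s in src_batch], [len(t) for t in tgt_batch])
```
--------------------------------------------------------------------------------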
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .nox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 | .hypothesis/
49 | .pytest_cache/
50 |
51 | # Translations
52 | *.mo
53 | *.pot
54 |
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 | db.sqlite3
59 |
60 | # Flask stuff:
61 | instance/
62 | .webassets-cache
63 |
64 | # Scrapy stuff:
65 | .scrapy
66 |
67 | # Sphinx documentation
68 | docs/_build/
69 |
70 | # PyBuilder
71 | target/
72 |
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 |
76 | # IPython
77 | profile_default/
78 | ipython_config.py
79 |
80 | # pyenv
81 | .python-version
82 |
83 | # celery beat schedule file
84 | celerybeat-schedule
85 |
86 | # SageMath parsed files
87 | *.sage.py
88 |
89 | # Environments
90 | .env
91 | .venv
92 | env/
93 | venv/
94 | ENV/
95 | env.bak/
96 | venv.bak/
97 |
98 | # Spyder project settings
99 | .spyderproject
100 | .spyproject
101 |
102 | # Rope project settings
103 | .ropeproject
104 |
105 | # mkdocs documentation
106 | /site
107 |
108 | # mypy
109 | .mypy_cache/
110 | .dmypy.json
111 | dmypy.json
112 |
113 | # Pyre type checker
114 | .pyre/
115 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OT-Seq2Seq
2 | This is the repository for the ICLR 2019 paper [Improving Sequence-to-Sequence
3 | Learning via Optimal Transport](https://arxiv.org/pdf/1901.06283.pdf).
4 |
5 | ## Usage ##
6 | Folder [nmt](./nmt) is built upon [GoogleNMT](https://github.com/tensorflow/nmt).
7 | Please follow the instructions in that repo for dataset downloading and code testing.
8 |
9 | Folder [texar](./texar) is built upon [Texar](https://github.com/asyml/texar).
10 | For details about the summarization and translation tasks, please follow this [link](./texar).
11 |
12 | ## Brief introduction ##
13 | 
14 | We present a novel Seq2Seq learning scheme that leverages optimal transport (OT) to construct a sequence-level loss. Specifically, the OT objective aims to find an optimal matching of similar words/phrases between two sequences, providing a way to promote their semantic similarity (Kusner et al., 2015). Compared with RL-based and adversarial schemes, our approach offers: (i) semantic invariance, allowing better preservation of sequence-level semantic information; and (ii) improved robustness, since neither the REINFORCE gradient nor a mini-max game is involved. The OT loss allows end-to-end supervised training and acts as an effective sequence-level regularizer for the MLE loss.
15 |
16 | OT can be easily applied to any Seq2Seq learning framework; the overall framework is shown here:
17 | 
18 |
19 | Therefore, the training algorithm can be represented as:
20 | 
21 |
22 |
23 | ### Reference
24 | If you are interested in our paper and want to further improve the model, please cite it with the following BibTeX entry:
25 | ```
26 | @article{chen2019improving,
27 | title={Improving Sequence-to-Sequence Learning via Optimal Transport},
28 | author={Chen, Liqun and Zhang, Yizhe and Zhang, Ruiyi and Tao, Chenyang and Gan, Zhe and Zhang, Haichao and Li, Bai and Shen, Dinghan and Chen, Changyou and Carin, Lawrence},
29 | journal={ICLR},
30 | year={2019}
31 | }
32 | ```
--------------------------------------------------------------------------------
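To make the objective described in the README concrete: the sequence-level OT term is simply added to the usual cross-entropy loss. Below is a minimal sketch using `cost_matrix` and `IPOT_distance` from this repository's `OT.py` (the function name, the `gamma` weight, and the embedding arguments are hypothetical, not the repository's exact variables):

```
# Sketch of the sequence-level objective: MLE loss regularized by the IPOT
# optimal-transport distance between generated and reference embeddings.
import tensorflow as tf

from OT import cost_matrix, IPOT_distance

def mle_plus_ot_loss(mle_loss, gen_embeds, ref_embeds, n, m, gamma=0.1):
    """mle_loss: scalar; gen_embeds: [n, d]; ref_embeds: [m, d]."""
    C = cost_matrix(gen_embeds, ref_embeds)  # cosine transport cost
    ot_loss = IPOT_distance(C, n, m)         # OT distance <C, T*>
    return mle_loss + gamma * ot_loss        # OT term acts as a regularizer
```
--------------------------------------------------------------------------------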
/nmt/utils/vocab_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Tests for vocab_utils."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import codecs
23 | import os
24 | import tensorflow as tf
25 |
26 | from ..utils import vocab_utils
27 |
28 |
29 | class VocabUtilsTest(tf.test.TestCase):
30 |
31 | def testCheckVocab(self):
32 | # Create a vocab file
33 | vocab_dir = os.path.join(tf.test.get_temp_dir(), "vocab_dir")
34 | os.makedirs(vocab_dir)
35 | vocab_file = os.path.join(vocab_dir, "vocab_file")
36 | vocab = ["a", "b", "c"]
37 | with codecs.getwriter("utf-8")(tf.gfile.GFile(vocab_file, "wb")) as f:
38 | for word in vocab:
39 | f.write("%s\n" % word)
40 |
41 | # Call vocab_utils
42 | out_dir = os.path.join(tf.test.get_temp_dir(), "out_dir")
43 | os.makedirs(out_dir)
44 | vocab_size, new_vocab_file = vocab_utils.check_vocab(
45 | vocab_file, out_dir)
46 |
47 | # Assert: we expect the code to add <unk>, <s>, </s> and
48 | # create a new vocab file
49 | self.assertEqual(len(vocab) + 3, vocab_size)
50 | self.assertEqual(os.path.join(out_dir, "vocab_file"), new_vocab_file)
51 | new_vocab, _ = vocab_utils.load_vocab(new_vocab_file)
52 | self.assertEqual(
53 | [vocab_utils.UNK, vocab_utils.SOS, vocab_utils.EOS] + vocab, new_vocab)
54 |
55 |
56 | if __name__ == "__main__":
57 | tf.test.main()
58 |
--------------------------------------------------------------------------------
/nmt/utils/evaluation_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Tests for evaluation_utils.py."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import tensorflow as tf
23 |
24 | from ..utils import evaluation_utils
25 |
26 |
27 | class EvaluationUtilsTest(tf.test.TestCase):
28 |
29 | def testEvaluate(self):
30 | output = "nmt/testdata/deen_output"
31 | ref_bpe = "nmt/testdata/deen_ref_bpe"
32 | ref_spm = "nmt/testdata/deen_ref_spm"
33 |
34 | expected_bleu_score = 22.5855084573
35 | expected_rouge_score = 50.8429782599
36 |
37 | bpe_bleu_score = evaluation_utils.evaluate(
38 | ref_bpe, output, "bleu", "bpe")
39 | bpe_rouge_score = evaluation_utils.evaluate(
40 | ref_bpe, output, "rouge", "bpe")
41 |
42 | self.assertAlmostEqual(expected_bleu_score, bpe_bleu_score)
43 | self.assertAlmostEqual(expected_rouge_score, bpe_rouge_score)
44 |
45 | spm_bleu_score = evaluation_utils.evaluate(
46 | ref_spm, output, "bleu", "spm")
47 | spm_rouge_score = evaluation_utils.evaluate(
48 | ref_spm, output, "rouge", "spm")
49 |
50 | self.assertAlmostEqual(expected_rouge_score, spm_rouge_score)
51 | self.assertAlmostEqual(expected_bleu_score, spm_bleu_score)
52 |
53 | def testAccuracy(self):
54 | pred_output = "nmt/testdata/pred_output"
55 | label_ref = "nmt/testdata/label_ref"
56 |
57 | expected_accuracy_score = 60.00
58 |
59 | accuracy_score = evaluation_utils.evaluate(
60 | label_ref, pred_output, "accuracy")
61 | self.assertAlmostEqual(expected_accuracy_score, accuracy_score)
62 |
63 | def testWordAccuracy(self):
64 | pred_output = "nmt/testdata/pred_output"
65 | label_ref = "nmt/testdata/label_ref"
66 |
67 | expected_word_accuracy_score = 60.00
68 |
69 | word_accuracy_score = evaluation_utils.evaluate(
70 | label_ref, pred_output, "word_accuracy")
71 | self.assertAlmostEqual(expected_word_accuracy_score, word_accuracy_score)
72 |
73 |
74 | if __name__ == "__main__":
75 | tf.test.main()
76 |
--------------------------------------------------------------------------------
/texar/README.md:
--------------------------------------------------------------------------------
1 | This code is largely borrowed from [Texar](https://texar.readthedocs.io/en/latest/).
2 | Please install the requirements first to run the code.
3 |
4 | # Sequence Generation #
5 | This example provides implementations of some classic and advanced training algorithms that tackle exposure bias. The base model is an attentional seq2seq model.
6 |
7 | * **Maximum Likelihood (MLE)**: attentional seq2seq model with maximum likelihood training.
8 | * **Maximum Likelihood (MLE) + Optimal Transport (OT)**: described in [OT-seq2seq](https://arxiv.org/pdf/1901.06283.pdf); we use the n-gram-replacement sampling approach of [(Ma et al., 2017)](https://arxiv.org/abs/1705.07136).
9 |
10 | ## Usage ##
11 |
12 | ### Dataset ###
13 |
14 | Two example datasets are provided:
15 |
16 | * iwslt14: The benchmark [IWSLT2014](https://sites.google.com/site/iwsltevaluation2014/home) (de-en) machine translation dataset, following [(Ranzato et al., 2015)](https://arxiv.org/pdf/1511.06732.pdf) for data pre-processing.
17 | * gigaword: The benchmark [GIGAWORD](https://catalog.ldc.upenn.edu/LDC2003T05) text summarization dataset. We sampled 200K out of the 3.8M pre-processed training examples provided by [(Rush et al., 2015)](https://www.aclweb.org/anthology/D/D15/D15-1044.pdf) for the sake of training efficiency, and used the refined validation and test sets provided by [(Zhou et al., 2017)](https://arxiv.org/pdf/1704.07073.pdf).
18 |
19 | Download the data with the following commands:
20 |
21 | ```
22 | python utils/prepare_data.py --data iwslt14
23 | python utils/prepare_data.py --data giga
24 | ```
25 |
26 | ### Train the models ###
27 |
28 | #### Baseline Attentional Seq2seq with OT
29 |
30 | ```
31 | python baseline_seq2seq_attn_main.py \
32 | --config_model configs.config_model \
33 | --config_data configs.config_iwslt14
34 | ```
35 |
36 | Here:
37 | * `--config_model` specifies the model config. Do not include the `.py` suffix.
38 | * `--config_data` specifies the data config.
39 |
40 | [configs.config_model.py](./configs/config_model.py) specifies a single-layer seq2seq model with Luong attention and bi-directional RNN encoder. Hyperparameters taking default values can be omitted from the config file.
41 |
42 |
43 | For demonstration purposes, [configs.config_model_full.py](./configs/config_model_full.py) gives all possible hyperparameters for the model. The two config files lead to the same model.
44 |
45 | ## Results ##
46 |
47 | ### Machine Translation
48 | | Model | BLEU Score |
49 | | -----------| -------|
50 | | MLE | 26.44 ± 0.18 |
51 | | Scheduled Sampling | 26.76 ± 0.17 |
52 | | RAML | 27.22 ± 0.14 |
53 | | Interpolation | 27.82 ± 0.11 |
54 | | MLE + OT | 27.79 ± 0.12 |
55 |
56 | ### Text Summarization
57 | | Model | Rouge-1 | Rouge-2 | Rouge-L |
58 | | -----------| -------|-------|-------|
59 | | MLE | 36.11 ± 0.21 | 16.39 ± 0.16 | 32.32 ± 0.19 |
60 | | RAML | 36.30 ± 0.24 | 16.69 ± 0.20 | 32.49 ± 0.17 |
61 | | Interpolation | 36.72 ± 0.29 |16.99 ± 0.17 | 32.95 ± 0.33|
62 | | MLE + OT | 36.82 ± 0.25 | 17.35 ± 0.10 | 33.35 ± 0.14 |
63 |
64 |
65 |
--------------------------------------------------------------------------------
/texar/OT.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | 
4 | 
5 | def cost_matrix(x, y):
6 |     """Returns the pairwise cosine-distance matrix between rows of x and y."""
7 |     # NOTE: cosine distance is chosen as the transport cost; an elementwise
8 |     # L1/Euclidean cost would be a drop-in alternative.
9 |     x = tf.nn.l2_normalize(x, 1, epsilon=1e-12)
10 |     y = tf.nn.l2_normalize(y, 1, epsilon=1e-12)
11 |     cos_dis = 1 - tf.matmul(x, y, transpose_b=True)
12 |     return cos_dis
13 | 
14 | 
15 | def IPOT(C, n, m, beta=0.5):
16 |     """IPOT proximal-point iteration: returns the transport plan T."""
17 |     sigma = tf.ones([m, 1]) / tf.cast(m, tf.float32)  # uniform column marginal
18 |     T = tf.ones([n, m])
19 |     A = tf.exp(-C / beta)  # proximal kernel
20 |     for t in range(50):  # outer proximal-point iterations
21 |         Q = tf.multiply(A, T)
22 |         for k in range(1):  # inner Sinkhorn-style updates
23 |             delta = 1 / (tf.cast(n, tf.float32) * tf.matmul(Q, sigma))
24 |             sigma = 1 / (
25 |                 tf.cast(m, tf.float32) * tf.matmul(Q, delta, transpose_a=True))
26 |         T = tf.matmul(tf.matmul(tf.diag(tf.squeeze(delta)), Q),
27 |                       tf.diag(tf.squeeze(sigma)))
28 |     return T
29 | 
30 | 
31 | def IPOT_np(C, beta=0.5):
32 |     """NumPy version of the IPOT iteration."""
33 |     n, m = C.shape[0], C.shape[1]
34 |     sigma = np.ones([m, 1]) / m
35 |     T = np.ones([n, m])
36 |     A = np.exp(-C / beta)
37 |     for t in range(20):
38 |         Q = np.multiply(A, T)
39 |         for k in range(1):
40 |             delta = 1 / (n * (Q @ sigma))
41 |             sigma = 1 / (m * (Q.T @ delta))
42 |         T = np.diag(np.squeeze(delta)) @ Q @ np.diag(np.squeeze(sigma))
43 |     return T
44 | 
45 | 
46 | def IPOT_distance(C, n, m):
47 |     """OT distance <C, T> induced by the IPOT transport plan."""
48 |     T = IPOT(C, n, m)
49 |     distance = tf.trace(tf.matmul(C, T, transpose_a=True))
50 |     return distance
51 | 
52 | 
53 | def shape_list(x):
54 |     """Return list of dims, statically where possible."""
55 |     x = tf.convert_to_tensor(x)
56 |     # If the rank is unknown, fall back to the fully dynamic shape.
57 |     if x.get_shape().dims is None:
58 |         return tf.shape(x)
59 |     static = x.get_shape().as_list()
60 |     shape = tf.shape(x)
61 |     ret = []
62 |     for i in range(len(static)):
63 |         dim = static[i]
64 |         if dim is None:
65 |             dim = shape[i]
66 |         ret.append(dim)
67 |     return ret
68 | 
69 | 
70 | def IPOT_distance2(C, beta=1, t_steps=10, k_steps=1):
71 |     """Batched IPOT distance for a [b, n, m] cost tensor."""
72 |     b, n, m = shape_list(C)
73 |     sigma = tf.ones([b, m, 1]) / tf.cast(m, tf.float32)  # [b, m, 1]
74 |     T = tf.ones([b, n, m])
75 |     A = tf.exp(-C / beta)  # [b, n, m]
76 |     for t in range(t_steps):
77 |         Q = A * T  # [b, n, m]
78 |         for k in range(k_steps):
79 |             delta = 1 / (tf.cast(n, tf.float32) * tf.matmul(Q, sigma))  # [b, n, 1]
80 |             sigma = 1 / (tf.cast(m, tf.float32) *
81 |                          tf.matmul(Q, delta, transpose_a=True))  # [b, m, 1]
82 |         T = delta * Q * tf.transpose(sigma, [0, 2, 1])  # [b, n, m]
83 |     distance = tf.trace(tf.matmul(C, T, transpose_a=True))
84 |     return distance
--------------------------------------------------------------------------------
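A quick NumPy sanity check for the proximal-point iteration above (sizes are arbitrary): the returned plan approximately couples the two uniform marginals.

```
# Sanity-check sketch: on a random cost matrix, the IPOT_np transport plan
# should have row sums close to 1/n and column sums close to 1/m.
import numpy as np

from OT import IPOT_np

n, m = 5, 7
C = np.random.rand(n, m)
T = IPOT_np(C, beta=0.5)

print(T.sum(axis=1))  # each of the n row sums is ~ 1/n
print(T.sum(axis=0))  # each of the m column sums is ~ 1/m
```
--------------------------------------------------------------------------------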
/nmt/utils/standard_hparams_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """standard hparams utils."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import tensorflow as tf
23 |
24 |
25 | def create_standard_hparams():
26 | return tf.contrib.training.HParams(
27 | # Data
28 | src="",
29 | tgt="",
30 | train_prefix="",
31 | dev_prefix="",
32 | test_prefix="",
33 | vocab_prefix="",
34 | embed_prefix="",
35 | out_dir="",
36 |
37 | # Networks
38 | num_units=512,
39 | num_layers=2,
40 | num_encoder_layers=2,
41 | num_decoder_layers=2,
42 | dropout=0.2,
43 | unit_type="lstm",
44 | encoder_type="bi",
45 | residual=False,
46 | time_major=True,
47 | num_embeddings_partitions=0,
48 |
49 | # Attention mechanisms
50 | attention="scaled_luong",
51 | attention_architecture="standard",
52 | output_attention=True,
53 | pass_hidden_state=True,
54 |
55 | # Train
56 | optimizer="sgd",
57 | batch_size=128,
58 | init_op="uniform",
59 | init_weight=0.1,
60 | max_gradient_norm=5.0,
61 | learning_rate=1.0,
62 | warmup_steps=0,
63 | warmup_scheme="t2t",
64 | decay_scheme="luong234",
65 | colocate_gradients_with_ops=True,
66 | num_train_steps=12000,
67 |
68 | # Data constraints
69 | num_buckets=5,
70 | max_train=0,
71 | src_max_len=50,
72 | tgt_max_len=50,
73 | src_max_len_infer=0,
74 | tgt_max_len_infer=0,
75 |
76 | # Data format
77 | sos="",
78 | eos="",
79 | subword_option="",
80 | check_special_token=True,
81 |
82 | # Misc
83 | forget_bias=1.0,
84 | num_gpus=1,
85 | epoch_step=0, # record where we were within an epoch.
86 | steps_per_stats=100,
87 | steps_per_external_eval=0,
88 | share_vocab=False,
89 | metrics=["bleu"],
90 | log_device_placement=False,
91 | random_seed=None,
92 | # only enable beam search during inference when beam_width > 0.
93 | beam_width=0,
94 | length_penalty_weight=0.0,
95 | override_loaded_hparams=True,
96 | num_keep_ckpts=5,
97 | avg_ckpts=False,
98 |
99 | # For inference
100 | inference_indices=None,
101 | infer_batch_size=32,
102 | sampling_temperature=0.0,
103 | num_translations_per_input=1,
104 | )
105 |
--------------------------------------------------------------------------------
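A short sketch of consuming these defaults (assuming TF 1.x, where `set_hparam` is the standard mutator on `tf.contrib.training.HParams`; the flat import path is illustrative, as the module normally lives in the `nmt.utils` package):

```
# Start from the standard hparams and override a few values, e.g. to enable
# beam search at inference time (beam_width > 0).
from standard_hparams_utils import create_standard_hparams

hparams = create_standard_hparams()
hparams.set_hparam("beam_width", 10)
hparams.set_hparam("length_penalty_weight", 1.0)
print(hparams.beam_width)  # 10
```
--------------------------------------------------------------------------------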
/nmt/nmt_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for nmt.py, train.py and inference.py."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import argparse
22 | import os
23 |
24 | import tensorflow as tf
25 |
26 | from . import inference
27 | from . import nmt
28 | from . import train
29 |
30 |
31 | def _update_flags(flags, test_name):
32 | """Update flags for basic training."""
33 | flags.num_train_steps = 100
34 | flags.steps_per_stats = 5
35 | flags.src = "en"
36 | flags.tgt = "vi"
37 | flags.train_prefix = ("nmt/testdata/"
38 | "iwslt15.tst2013.100")
39 | flags.vocab_prefix = ("nmt/testdata/"
40 | "iwslt15.vocab.100")
41 | flags.dev_prefix = ("nmt/testdata/"
42 | "iwslt15.tst2013.100")
43 | flags.test_prefix = ("nmt/testdata/"
44 | "iwslt15.tst2013.100")
45 | flags.out_dir = os.path.join(tf.test.get_temp_dir(), test_name)
46 |
47 |
48 | class NMTTest(tf.test.TestCase):
49 |
50 | def testTrain(self):
51 | """Test the training loop is functional with basic hparams."""
52 | nmt_parser = argparse.ArgumentParser()
53 | nmt.add_arguments(nmt_parser)
54 | FLAGS, unparsed = nmt_parser.parse_known_args()
55 |
56 | _update_flags(FLAGS, "nmt_train_test")
57 |
58 | default_hparams = nmt.create_hparams(FLAGS)
59 |
60 | train_fn = train.train
61 | nmt.run_main(FLAGS, default_hparams, train_fn, None)
62 |
63 |
64 | def testTrainWithAvgCkpts(self):
65 | """Test the training loop is functional with basic hparams."""
66 | nmt_parser = argparse.ArgumentParser()
67 | nmt.add_arguments(nmt_parser)
68 | FLAGS, unparsed = nmt_parser.parse_known_args()
69 |
70 | _update_flags(FLAGS, "nmt_train_test_avg_ckpts")
71 | FLAGS.avg_ckpts = True
72 |
73 | default_hparams = nmt.create_hparams(FLAGS)
74 |
75 | train_fn = train.train
76 | nmt.run_main(FLAGS, default_hparams, train_fn, None)
77 |
78 |
79 | def testInference(self):
80 | """Test inference is function with basic hparams."""
81 | nmt_parser = argparse.ArgumentParser()
82 | nmt.add_arguments(nmt_parser)
83 | FLAGS, unparsed = nmt_parser.parse_known_args()
84 |
85 | _update_flags(FLAGS, "nmt_train_infer")
86 |
87 | # Train one step so we have a checkpoint.
88 | FLAGS.num_train_steps = 1
89 | default_hparams = nmt.create_hparams(FLAGS)
90 | train_fn = train.train
91 | nmt.run_main(FLAGS, default_hparams, train_fn, None)
92 |
93 | # Update FLAGS for inference.
94 | FLAGS.inference_input_file = ("nmt/testdata/"
95 | "iwslt15.tst2013.100.en")
96 | FLAGS.inference_output_file = os.path.join(FLAGS.out_dir, "output")
97 | FLAGS.inference_ref_file = ("nmt/testdata/"
98 | "iwslt15.tst2013.100.vi")
99 |
100 | default_hparams = nmt.create_hparams(FLAGS)
101 |
102 | inference_fn = inference.inference
103 | nmt.run_main(FLAGS, default_hparams, None, inference_fn)
104 |
105 |
106 | if __name__ == "__main__":
107 | tf.test.main()
108 |
--------------------------------------------------------------------------------
/nmt/utils/nmt_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Utility functions specifically for NMT."""
17 | from __future__ import print_function
18 |
19 | import codecs
20 | import time
21 | import numpy as np
22 | import tensorflow as tf
23 |
24 | from ..utils import evaluation_utils
25 | from ..utils import misc_utils as utils
26 |
27 | __all__ = ["decode_and_evaluate", "get_translation"]
28 |
29 |
30 | def decode_and_evaluate(name,
31 | model,
32 | sess,
33 | trans_file,
34 | ref_file,
35 | metrics,
36 | subword_option,
37 | beam_width,
38 | tgt_eos,
39 | num_translations_per_input=1,
40 | decode=True):
41 | """Decode a test set and compute a score according to the evaluation task."""
42 | # Decode
43 | if decode:
44 | utils.print_out(" decoding to output %s." % trans_file)
45 |
46 | start_time = time.time()
47 | num_sentences = 0
48 | with codecs.getwriter("utf-8")(
49 | tf.gfile.GFile(trans_file, mode="wb")) as trans_f:
50 | trans_f.write("") # Write empty string to ensure file is created.
51 |
52 | num_translations_per_input = max(
53 | min(num_translations_per_input, beam_width), 1)
54 | while True:
55 | try:
56 | nmt_outputs, _ = model.decode(sess)
57 | if beam_width == 0:
58 | nmt_outputs = np.expand_dims(nmt_outputs, 0)
59 |
60 | batch_size = nmt_outputs.shape[1]
61 | num_sentences += batch_size
62 |
63 | for sent_id in range(batch_size):
64 | for beam_id in range(num_translations_per_input):
65 | translation = get_translation(
66 | nmt_outputs[beam_id],
67 | sent_id,
68 | tgt_eos=tgt_eos,
69 | subword_option=subword_option)
70 | trans_f.write((translation + b"\n").decode("utf-8"))
71 | except tf.errors.OutOfRangeError:
72 | utils.print_time(
73 | " done, num sentences %d, num translations per input %d" %
74 | (num_sentences, num_translations_per_input), start_time)
75 | break
76 |
77 | # Evaluation
78 | evaluation_scores = {}
79 | if ref_file and tf.gfile.Exists(trans_file):
80 | for metric in metrics:
81 | score = evaluation_utils.evaluate(
82 | ref_file,
83 | trans_file,
84 | metric,
85 | subword_option=subword_option)
86 | evaluation_scores[metric] = score
87 | utils.print_out(" %s %s: %.1f" % (metric, name, score))
88 |
89 | return evaluation_scores
90 |
91 |
92 | def get_translation(nmt_outputs, sent_id, tgt_eos, subword_option):
93 | """Given batch decoding outputs, select a sentence and turn to text."""
94 | if tgt_eos: tgt_eos = tgt_eos.encode("utf-8")
95 | # Select a sentence
96 | output = nmt_outputs[sent_id, :].tolist()
97 |
98 | # If there is an eos symbol in outputs, cut them at that point.
99 | if tgt_eos and tgt_eos in output:
100 | output = output[:output.index(tgt_eos)]
101 |
102 | if subword_option == "bpe": # BPE
103 | translation = utils.format_bpe_text(output)
104 | elif subword_option == "spm": # SPM
105 | translation = utils.format_spm_text(output)
106 | else:
107 | translation = utils.format_text(output)
108 |
109 | return translation
110 |
--------------------------------------------------------------------------------
/texar/utils/raml_samples_generation/vocab.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | from collections import Counter
4 | from itertools import chain
5 | 
6 | import torch
7 | 
8 | from util import read_corpus
9 | 
10 | 
11 | class VocabEntry(object):
12 |     def __init__(self):
13 |         self.word2id = dict()
14 |         self.unk_id = 3
15 |         self.word2id['<pad>'] = 0
16 |         self.word2id['<s>'] = 1
17 |         self.word2id['</s>'] = 2
18 |         self.word2id['<unk>'] = 3
19 | 
20 |         self.id2word = {v: k for k, v in self.word2id.items()}
21 | 
22 |     def __getitem__(self, word):
23 |         return self.word2id.get(word, self.unk_id)
24 | 
25 |     def __contains__(self, word):
26 |         return word in self.word2id
27 | 
28 |     def __setitem__(self, key, value):
29 |         raise ValueError('vocabulary is readonly')
30 | 
31 |     def __len__(self):
32 |         return len(self.word2id)
33 | 
34 |     def __repr__(self):
35 |         return 'Vocabulary[size=%d]' % len(self)
36 | 
37 |     def add(self, word):
38 |         if word not in self:
39 |             wid = self.word2id[word] = len(self)
40 |             self.id2word[wid] = word
41 |             return wid
42 |         else:
43 |             return self[word]
44 | 
45 |     @staticmethod
46 |     def from_corpus(corpus, size, remove_singleton=True):
47 |         vocab_entry = VocabEntry()
48 | 
49 |         word_freq = Counter(chain(*corpus))
50 |         non_singletons = [w for w in word_freq if word_freq[w] > 1]
51 |         print('number of word types: %d, number of word types w/ frequency > 1: %d' % (len(word_freq), len(non_singletons)))
52 | 
53 |         top_k_words = sorted(word_freq.keys(), reverse=True, key=word_freq.get)[:size]
54 | 
55 |         for word in top_k_words:
56 |             if len(vocab_entry) < size:
57 |                 if not (word_freq[word] == 1 and remove_singleton):
58 |                     vocab_entry.add(word)
59 | 
60 |         return vocab_entry
61 | 
62 | 
63 | class Vocab(object):
64 |     def __init__(self, src_sents, tgt_sents, src_vocab_size, tgt_vocab_size, remove_singleton=True):
65 |         assert len(src_sents) == len(tgt_sents)
66 | 
67 |         print('initialize source vocabulary ..')
68 |         self.src = VocabEntry.from_corpus(src_sents, src_vocab_size, remove_singleton=remove_singleton)
69 | 
70 |         print('initialize target vocabulary ..')
71 |         self.tgt = VocabEntry.from_corpus(tgt_sents, tgt_vocab_size, remove_singleton=remove_singleton)
72 | 
73 |     def __repr__(self):
74 |         return 'Vocab(source %d words, target %d words)' % (len(self.src), len(self.tgt))
75 | 
76 | 
77 | if __name__ == '__main__':
78 |     parser = argparse.ArgumentParser()
79 |     parser.add_argument('--src_vocab_size', default=50000, type=int, help='source vocabulary size')
80 |     parser.add_argument('--tgt_vocab_size', default=50000, type=int, help='target vocabulary size')
81 |     parser.add_argument('--include_singleton', action='store_true', default=False, help='whether to include singletons in the vocabulary (default=False)')
82 | 
83 |     parser.add_argument('--train_src', type=str, required=True, help='file of source sentences')
84 |     parser.add_argument('--train_tgt', type=str, required=True, help='file of target sentences')
85 | 
86 |     parser.add_argument('--output', default='vocab.bin', type=str, help='output vocabulary file')
87 | 
88 |     args = parser.parse_args()
89 | 
90 |     print('read in source sentences: %s' % args.train_src)
91 |     print('read in target sentences: %s' % args.train_tgt)
92 | 
93 |     src_sents = read_corpus(args.train_src, source='src')
94 |     tgt_sents = read_corpus(args.train_tgt, source='tgt')
95 | 
96 |     vocab = Vocab(src_sents, tgt_sents, args.src_vocab_size, args.tgt_vocab_size, remove_singleton=not args.include_singleton)
97 |     print('generated vocabulary, source %d words, target %d words' % (len(vocab.src), len(vocab.tgt)))
98 | 
99 |     torch.save(vocab, args.output)
100 |     print('vocabulary saved to %s' % args.output)
--------------------------------------------------------------------------------
/nmt/utils/vocab_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Utility to handle vocabularies."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import codecs
23 | import os
24 | import tensorflow as tf
25 |
26 | from tensorflow.python.ops import lookup_ops
27 |
28 | from ..utils import misc_utils as utils
29 |
30 |
31 | UNK = ""
32 | SOS = ""
33 | EOS = ""
34 | UNK_ID = 0
35 |
36 |
37 | def load_vocab(vocab_file):
38 | vocab = []
39 | with codecs.getreader("utf-8")(tf.gfile.GFile(vocab_file, "rb")) as f:
40 | vocab_size = 0
41 | for word in f:
42 | vocab_size += 1
43 | vocab.append(word.strip())
44 | return vocab, vocab_size
45 |
46 |
47 | def check_vocab(vocab_file, out_dir, check_special_token=True, sos=None,
48 | eos=None, unk=None):
49 |   """Check that vocab_file exists and starts with the special tokens; rewrite it if not."""
50 | if tf.gfile.Exists(vocab_file):
51 | utils.print_out("# Vocab file %s exists" % vocab_file)
52 | vocab, vocab_size = load_vocab(vocab_file)
53 | if check_special_token:
54 | # Verify if the vocab starts with unk, sos, eos
55 | # If not, prepend those tokens & generate a new vocab file
56 | if not unk: unk = UNK
57 | if not sos: sos = SOS
58 | if not eos: eos = EOS
59 | assert len(vocab) >= 3
60 | if vocab[0] != unk or vocab[1] != sos or vocab[2] != eos:
61 | utils.print_out("The first 3 vocab words [%s, %s, %s]"
62 | " are not [%s, %s, %s]" %
63 | (vocab[0], vocab[1], vocab[2], unk, sos, eos))
64 | vocab = [unk, sos, eos] + vocab
65 | vocab_size += 3
66 | new_vocab_file = os.path.join(out_dir, os.path.basename(vocab_file))
67 | with codecs.getwriter("utf-8")(
68 | tf.gfile.GFile(new_vocab_file, "wb")) as f:
69 | for word in vocab:
70 | f.write("%s\n" % word)
71 | vocab_file = new_vocab_file
72 | else:
73 | raise ValueError("vocab_file '%s' does not exist." % vocab_file)
74 |
75 | vocab_size = len(vocab)
76 | return vocab_size, vocab_file
77 |
78 |
79 | def create_vocab_tables(src_vocab_file, tgt_vocab_file, share_vocab):
80 | """Creates vocab tables for src_vocab_file and tgt_vocab_file."""
81 | src_vocab_table = lookup_ops.index_table_from_file(
82 | src_vocab_file, default_value=UNK_ID)
83 | if share_vocab:
84 | tgt_vocab_table = src_vocab_table
85 | else:
86 | tgt_vocab_table = lookup_ops.index_table_from_file(
87 | tgt_vocab_file, default_value=UNK_ID)
88 | return src_vocab_table, tgt_vocab_table
89 |
90 |
91 | def load_embed_txt(embed_file):
92 | """Load embed_file into a python dictionary.
93 |
94 |   Note: the embed_file should be a GloVe-formatted txt file. Assuming
95 | embed_size=5, for example:
96 |
97 | the -0.071549 0.093459 0.023738 -0.090339 0.056123
98 | to 0.57346 0.5417 -0.23477 -0.3624 0.4037
99 | and 0.20327 0.47348 0.050877 0.002103 0.060547
100 |
101 | Args:
102 | embed_file: file path to the embedding file.
103 | Returns:
104 | a dictionary that maps word to vector, and the size of embedding dimensions.
105 | """
106 | emb_dict = dict()
107 | emb_size = None
108 | with codecs.getreader("utf-8")(tf.gfile.GFile(embed_file, 'rb')) as f:
109 | for line in f:
110 | tokens = line.strip().split(" ")
111 | word = tokens[0]
112 | vec = list(map(float, tokens[1:]))
113 | emb_dict[word] = vec
114 | if emb_size:
115 |         assert emb_size == len(vec), "All embedding sizes should be the same."
116 | else:
117 | emb_size = len(vec)
118 | return emb_dict, emb_size
119 |
--------------------------------------------------------------------------------
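A minimal usage sketch of the vocab-checking flow above (an illustration of ours, not repo code; the paths are hypothetical, and it assumes the repo root is on PYTHONPATH with a TF 1.x install):

    from nmt.utils import vocab_utils

    # vocab.en holds one token per line. check_vocab() prepends <unk>, <s>,
    # </s> and writes a corrected copy under out_dir if they are missing.
    vocab_size, vocab_file = vocab_utils.check_vocab(
        "/tmp/vocab.en", out_dir="/tmp/nmt_model", check_special_token=True)

    # Lookup tables map tokens to ids; run tf.tables_initializer() before use.
    src_table, tgt_table = vocab_utils.create_vocab_tables(
        vocab_file, vocab_file, share_vocab=True)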
/nmt/scripts/bleu.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Python implementation of BLEU and smooth-BLEU.
17 |
18 | This module provides a Python implementation of BLEU and smooth-BLEU.
19 | Smooth BLEU is computed following the method outlined in the paper:
20 | Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic
21 | evaluation metrics for machine translation. COLING 2004.
22 | """
23 |
24 | import collections
25 | import math
26 |
27 |
28 | def _get_ngrams(segment, max_order):
29 |   """Extracts all n-grams up to a given maximum order from an input segment.
30 |
31 | Args:
32 | segment: text segment from which n-grams will be extracted.
33 | max_order: maximum length in tokens of the n-grams returned by this
34 |       method.
35 |
36 | Returns:
37 |     The Counter containing all n-grams up to max_order in segment
38 | with a count of how many times each n-gram occurred.
39 | """
40 | ngram_counts = collections.Counter()
41 | for order in range(1, max_order + 1):
42 | for i in range(0, len(segment) - order + 1):
43 | ngram = tuple(segment[i:i+order])
44 | ngram_counts[ngram] += 1
45 | return ngram_counts
46 |
47 |
48 | def compute_bleu(reference_corpus, translation_corpus, max_order=4,
49 | smooth=False):
50 | """Computes BLEU score of translated segments against one or more references.
51 |
52 | Args:
53 | reference_corpus: list of lists of references for each translation. Each
54 | reference should be tokenized into a list of tokens.
55 | translation_corpus: list of translations to score. Each translation
56 | should be tokenized into a list of tokens.
57 | max_order: Maximum n-gram order to use when computing BLEU score.
58 | smooth: Whether or not to apply Lin et al. 2004 smoothing.
59 |
60 | Returns:
61 |     6-tuple with the BLEU score, n-gram precisions, brevity penalty, length
62 |     ratio, translation length, and reference length.
63 | """
64 | matches_by_order = [0] * max_order
65 | possible_matches_by_order = [0] * max_order
66 | reference_length = 0
67 | translation_length = 0
68 | for (references, translation) in zip(reference_corpus,
69 | translation_corpus):
70 | reference_length += min(len(r) for r in references)
71 | translation_length += len(translation)
72 |
73 | merged_ref_ngram_counts = collections.Counter()
74 | for reference in references:
75 | merged_ref_ngram_counts |= _get_ngrams(reference, max_order)
76 | translation_ngram_counts = _get_ngrams(translation, max_order)
77 | overlap = translation_ngram_counts & merged_ref_ngram_counts
78 | for ngram in overlap:
79 | matches_by_order[len(ngram)-1] += overlap[ngram]
80 | for order in range(1, max_order+1):
81 | possible_matches = len(translation) - order + 1
82 | if possible_matches > 0:
83 | possible_matches_by_order[order-1] += possible_matches
84 |
85 | precisions = [0] * max_order
86 | for i in range(0, max_order):
87 | if smooth:
88 | precisions[i] = ((matches_by_order[i] + 1.) /
89 | (possible_matches_by_order[i] + 1.))
90 | else:
91 | if possible_matches_by_order[i] > 0:
92 | precisions[i] = (float(matches_by_order[i]) /
93 | possible_matches_by_order[i])
94 | else:
95 | precisions[i] = 0.0
96 |
97 | if min(precisions) > 0:
98 | p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions)
99 | geo_mean = math.exp(p_log_sum)
100 | else:
101 | geo_mean = 0
102 |
103 | ratio = float(translation_length) / reference_length
104 |
105 | if ratio > 1.0:
106 | bp = 1.
107 | else:
108 | bp = math.exp(1 - 1. / ratio)
109 |
110 | bleu = geo_mean * bp
111 |
112 | return (bleu, precisions, bp, ratio, translation_length, reference_length)
113 |
--------------------------------------------------------------------------------
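As a quick sanity check, a hypothetical call to compute_bleu (not part of the repo): reference_corpus is a list of reference lists, one inner list per translation, and every sentence is pre-tokenized.

    from nmt.scripts.bleu import compute_bleu

    references = [[["the", "cat", "sat", "on", "the", "mat"]]]
    translations = [["the", "cat", "sat", "on", "mat"]]
    bleu, precisions, bp, ratio, trans_len, ref_len = compute_bleu(
        references, translations, max_order=4, smooth=True)
    print("BLEU %.4f, brevity penalty %.4f" % (bleu, bp))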
/nmt/utils/common_test_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Common utility functions for tests."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import tensorflow as tf
23 |
24 | from tensorflow.python.ops import lookup_ops
25 |
26 | from ..utils import iterator_utils
27 | from ..utils import standard_hparams_utils
28 |
29 |
30 | def create_test_hparams(unit_type="lstm",
31 | encoder_type="uni",
32 | num_layers=4,
33 | attention="",
34 | attention_architecture=None,
35 | use_residual=False,
36 | inference_indices=None,
37 | num_translations_per_input=1,
38 | beam_width=0,
39 | init_op="uniform"):
40 | """Create training and inference test hparams."""
41 | num_residual_layers = 0
42 | if use_residual:
43 | # TODO(rzhao): Put num_residual_layers computation logic into
44 | # `model_utils.py`, so we can also test it here.
45 | num_residual_layers = 2
46 |
47 | standard_hparams = standard_hparams_utils.create_standard_hparams()
48 |
49 | # Networks
50 | standard_hparams.num_units = 5
51 | standard_hparams.num_encoder_layers = num_layers
52 | standard_hparams.num_decoder_layers = num_layers
53 | standard_hparams.dropout = 0.5
54 | standard_hparams.unit_type = unit_type
55 | standard_hparams.encoder_type = encoder_type
56 | standard_hparams.residual = use_residual
57 | standard_hparams.num_residual_layers = num_residual_layers
58 |
59 | # Attention mechanisms
60 | standard_hparams.attention = attention
61 | standard_hparams.attention_architecture = attention_architecture
62 |
63 | # Train
64 | standard_hparams.init_op = init_op
65 | standard_hparams.num_train_steps = 1
66 | standard_hparams.decay_scheme = ""
67 |
68 | # Infer
69 | standard_hparams.tgt_max_len_infer = 100
70 | standard_hparams.beam_width = beam_width
71 | standard_hparams.num_translations_per_input = num_translations_per_input
72 |
73 | # Misc
74 | standard_hparams.forget_bias = 0.0
75 | standard_hparams.random_seed = 3
76 |
77 | # Vocab
78 | standard_hparams.src_vocab_size = 5
79 | standard_hparams.tgt_vocab_size = 5
80 | standard_hparams.eos = "eos"
81 | standard_hparams.sos = "sos"
82 | standard_hparams.src_vocab_file = ""
83 | standard_hparams.tgt_vocab_file = ""
84 | standard_hparams.src_embed_file = ""
85 | standard_hparams.tgt_embed_file = ""
86 |
87 | # For inference.py test
88 | standard_hparams.subword_option = "bpe"
89 | standard_hparams.src = "src"
90 | standard_hparams.tgt = "tgt"
91 | standard_hparams.src_max_len = 400
92 | standard_hparams.tgt_eos_id = 0
93 | standard_hparams.inference_indices = inference_indices
94 | return standard_hparams
95 |
96 |
97 | def create_test_iterator(hparams, mode):
98 | """Create test iterator."""
99 | src_vocab_table = lookup_ops.index_table_from_tensor(
100 | tf.constant([hparams.eos, "a", "b", "c", "d"]))
101 | tgt_vocab_mapping = tf.constant([hparams.sos, hparams.eos, "a", "b", "c"])
102 | tgt_vocab_table = lookup_ops.index_table_from_tensor(tgt_vocab_mapping)
103 | if mode == tf.contrib.learn.ModeKeys.INFER:
104 | reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_tensor(
105 | tgt_vocab_mapping)
106 |
107 | src_dataset = tf.data.Dataset.from_tensor_slices(
108 | tf.constant(["a a b b c", "a b b"]))
109 |
110 | if mode != tf.contrib.learn.ModeKeys.INFER:
111 | tgt_dataset = tf.data.Dataset.from_tensor_slices(
112 | tf.constant(["a b c b c", "a b c b"]))
113 | return (
114 | iterator_utils.get_iterator(
115 | src_dataset=src_dataset,
116 | tgt_dataset=tgt_dataset,
117 | src_vocab_table=src_vocab_table,
118 | tgt_vocab_table=tgt_vocab_table,
119 | batch_size=hparams.batch_size,
120 | sos=hparams.sos,
121 | eos=hparams.eos,
122 | random_seed=hparams.random_seed,
123 | num_buckets=hparams.num_buckets),
124 | src_vocab_table,
125 | tgt_vocab_table)
126 | else:
127 | return (
128 | iterator_utils.get_infer_iterator(
129 | src_dataset=src_dataset,
130 | src_vocab_table=src_vocab_table,
131 | eos=hparams.eos,
132 | batch_size=hparams.batch_size),
133 | src_vocab_table,
134 | tgt_vocab_table,
135 | reverse_tgt_vocab_table)
136 |
--------------------------------------------------------------------------------
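A hedged example of how the unit tests drive these helpers (argument values are illustrative):

    from nmt.utils import common_test_utils

    hparams = common_test_utils.create_test_hparams(
        unit_type="lstm",
        encoder_type="uni",
        num_layers=2,
        attention="scaled_luong",
        attention_architecture="standard")
    print(hparams.num_units)  # 5 -- the tiny network size used in tests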
/nmt/utils/misc_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Generally useful utility functions."""
17 | from __future__ import print_function
18 |
19 | import codecs
20 | import collections
21 | import json
22 | import math
23 | import os
24 | import sys
25 | import time
26 |
27 | import numpy as np
28 | import tensorflow as tf
29 |
30 |
31 | def check_tensorflow_version():
32 | min_tf_version = "1.4.0-dev20171024"
33 | if tf.__version__ < min_tf_version:
34 |     raise EnvironmentError("Tensorflow version must be >= %s" % min_tf_version)
35 |
36 |
37 | def safe_exp(value):
38 | """Exponentiation with catching of overflow error."""
39 | try:
40 | ans = math.exp(value)
41 | except OverflowError:
42 | ans = float("inf")
43 | return ans
44 |
45 |
46 | def print_time(s, start_time):
47 | """Take a start time, print elapsed duration, and return a new time."""
48 | print("%s, time %ds, %s." % (s, (time.time() - start_time), time.ctime()))
49 | sys.stdout.flush()
50 | return time.time()
51 |
52 |
53 | def print_out(s, f=None, new_line=True):
54 |   """Similar to print, but with support for flushing and writing to a file."""
55 | if isinstance(s, bytes):
56 | s = s.decode("utf-8")
57 |
58 | if f:
59 | f.write(s.encode("utf-8"))
60 | if new_line:
61 | f.write(b"\n")
62 |
63 | # stdout
64 | out_s = s.encode("utf-8")
65 | if not isinstance(out_s, str):
66 | out_s = out_s.decode("utf-8")
67 | print(out_s, end="", file=sys.stdout)
68 |
69 | if new_line:
70 | sys.stdout.write("\n")
71 | sys.stdout.flush()
72 |
73 |
74 | def print_hparams(hparams, skip_patterns=None, header=None):
75 |   """Print hparams, optionally skipping keys that match any of skip_patterns."""
76 | if header: print_out("%s" % header)
77 | values = hparams.values()
78 | for key in sorted(values.keys()):
79 | if not skip_patterns or all(
80 | [skip_pattern not in key for skip_pattern in skip_patterns]):
81 | print_out(" %s=%s" % (key, str(values[key])))
82 |
83 |
84 | def load_hparams(model_dir):
85 | """Load hparams from an existing model directory."""
86 | hparams_file = os.path.join(model_dir, "hparams")
87 | if tf.gfile.Exists(hparams_file):
88 | print_out("# Loading hparams from %s" % hparams_file)
89 | with codecs.getreader("utf-8")(tf.gfile.GFile(hparams_file, "rb")) as f:
90 | try:
91 | hparams_values = json.load(f)
92 | hparams = tf.contrib.training.HParams(**hparams_values)
93 | except ValueError:
94 | print_out(" can't load hparams file")
95 | return None
96 | return hparams
97 | else:
98 | return None
99 |
100 |
101 | def maybe_parse_standard_hparams(hparams, hparams_path):
102 | """Override hparams values with existing standard hparams config."""
103 | if not hparams_path:
104 | return hparams
105 |
106 | if tf.gfile.Exists(hparams_path):
107 | print_out("# Loading standard hparams from %s" % hparams_path)
108 | with tf.gfile.GFile(hparams_path, "r") as f:
109 | hparams.parse_json(f.read())
110 |
111 | return hparams
112 |
113 |
114 | def save_hparams(out_dir, hparams):
115 | """Save hparams."""
116 | hparams_file = os.path.join(out_dir, "hparams")
117 | print_out(" saving hparams to %s" % hparams_file)
118 | with codecs.getwriter("utf-8")(tf.gfile.GFile(hparams_file, "wb")) as f:
119 | f.write(hparams.to_json())
120 |
121 |
122 | def debug_tensor(s, msg=None, summarize=10):
123 | """Print the shape and value of a tensor at test time. Return a new tensor."""
124 | if not msg:
125 | msg = s.name
126 | return tf.Print(s, [tf.shape(s), s], msg + " ", summarize=summarize)
127 |
128 |
129 | def add_summary(summary_writer, global_step, tag, value):
130 | """Add a new summary to the current summary_writer.
131 | Useful to log things that are not part of the training graph, e.g., tag=BLEU.
132 | """
133 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
134 | summary_writer.add_summary(summary, global_step)
135 |
136 |
137 | def get_config_proto(log_device_placement=False, allow_soft_placement=True,
138 | num_intra_threads=0, num_inter_threads=0):
139 | # GPU options:
140 | # https://www.tensorflow.org/versions/r0.10/how_tos/using_gpu/index.html
141 | config_proto = tf.ConfigProto(
142 | log_device_placement=log_device_placement,
143 | allow_soft_placement=allow_soft_placement)
144 | config_proto.gpu_options.allow_growth = True
145 |
146 | # CPU threads options
147 | if num_intra_threads:
148 | config_proto.intra_op_parallelism_threads = num_intra_threads
149 | if num_inter_threads:
150 | config_proto.inter_op_parallelism_threads = num_inter_threads
151 |
152 | return config_proto
153 |
154 |
155 | def format_text(words):
156 |   """Convert a sequence of words into a sentence."""
157 | if (not hasattr(words, "__len__") and # for numpy array
158 | not isinstance(words, collections.Iterable)):
159 | words = [words]
160 | return b" ".join(words)
161 |
162 |
163 | def format_bpe_text(symbols, delimiter=b"@@"):
164 | """Convert a sequence of bpe words into sentence."""
165 | words = []
166 | word = b""
167 | if isinstance(symbols, str):
168 | symbols = symbols.encode()
169 | delimiter_len = len(delimiter)
170 | for symbol in symbols:
171 | if len(symbol) >= delimiter_len and symbol[-delimiter_len:] == delimiter:
172 | word += symbol[:-delimiter_len]
173 | else: # end of a word
174 | word += symbol
175 | words.append(word)
176 | word = b""
177 | return b" ".join(words)
178 |
179 |
180 | def format_spm_text(symbols):
181 | """Decode a text in SPM (https://github.com/google/sentencepiece) format."""
182 | return u"".join(format_text(symbols).decode("utf-8").split()).replace(
183 | u"\u2581", u" ").strip().encode("utf-8")
184 |
--------------------------------------------------------------------------------
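To illustrate the BPE re-joining rule in format_bpe_text (an example of ours, not from the repo): a piece ending in the "@@" delimiter is glued to the piece that follows it.

    from nmt.utils.misc_utils import format_bpe_text

    symbols = [b"new@@", b"stest", b"2016"]
    print(format_bpe_text(symbols))  # b'newstest 2016'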
/nmt/utils/evaluation_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Utility for evaluating various tasks, e.g., translation & summarization."""
17 | import codecs
18 | import os
19 | import re
20 | import subprocess
21 |
22 | import tensorflow as tf
23 |
24 | from ..scripts import bleu
25 | from ..scripts import rouge
26 |
27 |
28 | __all__ = ["evaluate"]
29 |
30 |
31 | def evaluate(ref_file, trans_file, metric, subword_option=None):
32 | """Pick a metric and evaluate depending on task."""
33 | # BLEU scores for translation task
34 | if metric.lower() == "bleu":
35 | evaluation_score = _bleu(ref_file, trans_file,
36 | subword_option=subword_option)
37 | # ROUGE scores for summarization tasks
38 | elif metric.lower() == "rouge":
39 | evaluation_score = _rouge(ref_file, trans_file,
40 | subword_option=subword_option)
41 | elif metric.lower() == "accuracy":
42 | evaluation_score = _accuracy(ref_file, trans_file)
43 | elif metric.lower() == "word_accuracy":
44 | evaluation_score = _word_accuracy(ref_file, trans_file)
45 | else:
46 | raise ValueError("Unknown metric %s" % metric)
47 |
48 | return evaluation_score
49 |
50 |
51 | def _clean(sentence, subword_option):
52 | """Clean and handle BPE or SPM outputs."""
53 | sentence = sentence.strip()
54 |
55 | # BPE
56 | if subword_option == "bpe":
57 | sentence = re.sub("@@ ", "", sentence)
58 |
59 | # SPM
60 | elif subword_option == "spm":
61 | sentence = u"".join(sentence.split()).replace(u"\u2581", u" ").lstrip()
62 |
63 | return sentence
64 |
65 |
66 | # Follow //transconsole/localization/machine_translation/metrics/bleu_calc.py
67 | def _bleu(ref_file, trans_file, subword_option=None):
68 |   """Compute BLEU scores, handling BPE."""
69 | max_order = 4
70 | smooth = False
71 |
72 | ref_files = [ref_file]
73 | reference_text = []
74 | for reference_filename in ref_files:
75 | with codecs.getreader("utf-8")(
76 | tf.gfile.GFile(reference_filename, "rb")) as fh:
77 | reference_text.append(fh.readlines())
78 |
79 | per_segment_references = []
80 | for references in zip(*reference_text):
81 | reference_list = []
82 | for reference in references:
83 | reference = _clean(reference, subword_option)
84 | reference_list.append(reference.split(" "))
85 | per_segment_references.append(reference_list)
86 |
87 | translations = []
88 | with codecs.getreader("utf-8")(tf.gfile.GFile(trans_file, "rb")) as fh:
89 | for line in fh:
90 | line = _clean(line, subword_option=None)
91 | translations.append(line.split(" "))
92 |
93 | # bleu_score, precisions, bp, ratio, translation_length, reference_length
94 | bleu_score, _, _, _, _, _ = bleu.compute_bleu(
95 | per_segment_references, translations, max_order, smooth)
96 | return 100 * bleu_score
97 |
98 |
99 | def _rouge(ref_file, summarization_file, subword_option=None):
100 |   """Compute ROUGE scores, handling BPE."""
101 |
102 | references = []
103 | with codecs.getreader("utf-8")(tf.gfile.GFile(ref_file, "rb")) as fh:
104 | for line in fh:
105 | references.append(_clean(line, subword_option))
106 |
107 | hypotheses = []
108 | with codecs.getreader("utf-8")(
109 | tf.gfile.GFile(summarization_file, "rb")) as fh:
110 | for line in fh:
111 | hypotheses.append(_clean(line, subword_option=None))
112 |
113 | rouge_score_map = rouge.rouge(hypotheses, references)
114 | return 100 * rouge_score_map["rouge_l/f_score"]
115 |
116 |
117 | def _accuracy(label_file, pred_file):
118 |   """Compute accuracy; each line contains a label."""
119 |
120 | with codecs.getreader("utf-8")(tf.gfile.GFile(label_file, "rb")) as label_fh:
121 | with codecs.getreader("utf-8")(tf.gfile.GFile(pred_file, "rb")) as pred_fh:
122 | count = 0.0
123 | match = 0.0
124 | for label in label_fh:
125 | label = label.strip()
126 | pred = pred_fh.readline().strip()
127 | if label == pred:
128 | match += 1
129 | count += 1
130 | return 100 * match / count
131 |
132 |
133 | def _word_accuracy(label_file, pred_file):
134 |   """Compute accuracy on a per-word basis."""
135 |
136 | with codecs.getreader("utf-8")(tf.gfile.GFile(label_file, "r")) as label_fh:
137 | with codecs.getreader("utf-8")(tf.gfile.GFile(pred_file, "r")) as pred_fh:
138 | total_acc, total_count = 0., 0.
139 | for sentence in label_fh:
140 | labels = sentence.strip().split(" ")
141 | preds = pred_fh.readline().strip().split(" ")
142 | match = 0.0
143 | for pos in range(min(len(labels), len(preds))):
144 | label = labels[pos]
145 | pred = preds[pos]
146 | if label == pred:
147 | match += 1
148 | total_acc += 100 * match / max(len(labels), len(preds))
149 | total_count += 1
150 | return total_acc / total_count
151 |
152 |
153 | def _moses_bleu(multi_bleu_script, tgt_test, trans_file, subword_option=None):
154 | """Compute BLEU scores using Moses multi-bleu.perl script."""
155 |
156 | # TODO(thangluong): perform rewrite using python
157 | # BPE
158 | if subword_option == "bpe":
159 | debpe_tgt_test = tgt_test + ".debpe"
160 | if not os.path.exists(debpe_tgt_test):
161 | # TODO(thangluong): not use shell=True, can be a security hazard
162 | subprocess.call("cp %s %s" % (tgt_test, debpe_tgt_test), shell=True)
163 |       subprocess.call("sed -i 's/@@ //g' %s" % (debpe_tgt_test),
164 |                       shell=True)
165 | tgt_test = debpe_tgt_test
166 | elif subword_option == "spm":
167 | despm_tgt_test = tgt_test + ".despm"
168 | if not os.path.exists(despm_tgt_test):
169 |       subprocess.call("cp %s %s" % (tgt_test, despm_tgt_test), shell=True)
170 |       subprocess.call("sed -i 's/ //g' %s" % (despm_tgt_test), shell=True)
171 |       subprocess.call(u"sed -i 's/^\u2581//g' %s" % (despm_tgt_test), shell=True)
172 |       subprocess.call(u"sed -i 's/\u2581/ /g' %s" % (despm_tgt_test), shell=True)
173 | tgt_test = despm_tgt_test
174 | cmd = "%s %s < %s" % (multi_bleu_script, tgt_test, trans_file)
175 |
176 | # subprocess
177 | # TODO(thangluong): not use shell=True, can be a security hazard
178 |   bleu_output = subprocess.check_output(cmd, shell=True).decode("utf-8")
179 |
180 | # extract BLEU score
181 | m = re.search("BLEU = (.+?),", bleu_output)
182 | bleu_score = float(m.group(1))
183 |
184 | return bleu_score
185 |
--------------------------------------------------------------------------------
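A minimal sketch of the evaluate() entry point (file paths are hypothetical):

    from nmt.utils import evaluation_utils

    score = evaluation_utils.evaluate(
        "/tmp/newstest2016.de",   # reference file
        "/tmp/output_infer",      # system output, one translation per line
        metric="bleu",
        subword_option="bpe")     # undo "@@ " BPE joins in the references
    print("BLEU: %.2f" % score)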
/nmt/scripts/wmt16_en_de.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Copyright 2017 Google Inc.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | set -e
18 |
19 | BASE_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )/.." && pwd )"
20 |
21 | OUTPUT_DIR="${1:-wmt16_de_en}"
22 | echo "Writing to ${OUTPUT_DIR}. To change this, pass a different output directory as the first argument."
23 |
24 | OUTPUT_DIR_DATA="${OUTPUT_DIR}/data"
25 | mkdir -p $OUTPUT_DIR_DATA
26 |
27 | echo "Downloading Europarl v7. This may take a while..."
28 | curl -o ${OUTPUT_DIR_DATA}/europarl-v7-de-en.tgz \
29 | http://www.statmt.org/europarl/v7/de-en.tgz
30 |
31 | echo "Downloading Common Crawl corpus. This may take a while..."
32 | curl -o ${OUTPUT_DIR_DATA}/common-crawl.tgz \
33 | http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz
34 |
35 | echo "Downloading News Commentary v11. This may take a while..."
36 | curl -o ${OUTPUT_DIR_DATA}/nc-v11.tgz \
37 | http://data.statmt.org/wmt16/translation-task/training-parallel-nc-v11.tgz
38 |
39 | echo "Downloading dev/test sets"
40 | curl -o ${OUTPUT_DIR_DATA}/dev.tgz \
41 | http://data.statmt.org/wmt16/translation-task/dev.tgz
42 | curl -o ${OUTPUT_DIR_DATA}/test.tgz \
43 | http://data.statmt.org/wmt16/translation-task/test.tgz
44 |
45 | # Extract everything
46 | echo "Extracting all files..."
47 | mkdir -p "${OUTPUT_DIR_DATA}/europarl-v7-de-en"
48 | tar -xvzf "${OUTPUT_DIR_DATA}/europarl-v7-de-en.tgz" -C "${OUTPUT_DIR_DATA}/europarl-v7-de-en"
49 | mkdir -p "${OUTPUT_DIR_DATA}/common-crawl"
50 | tar -xvzf "${OUTPUT_DIR_DATA}/common-crawl.tgz" -C "${OUTPUT_DIR_DATA}/common-crawl"
51 | mkdir -p "${OUTPUT_DIR_DATA}/nc-v11"
52 | tar -xvzf "${OUTPUT_DIR_DATA}/nc-v11.tgz" -C "${OUTPUT_DIR_DATA}/nc-v11"
53 | mkdir -p "${OUTPUT_DIR_DATA}/dev"
54 | tar -xvzf "${OUTPUT_DIR_DATA}/dev.tgz" -C "${OUTPUT_DIR_DATA}/dev"
55 | mkdir -p "${OUTPUT_DIR_DATA}/test"
56 | tar -xvzf "${OUTPUT_DIR_DATA}/test.tgz" -C "${OUTPUT_DIR_DATA}/test"
57 |
58 | # Concatenate Training data
59 | cat "${OUTPUT_DIR_DATA}/europarl-v7-de-en/europarl-v7.de-en.en" \
60 | "${OUTPUT_DIR_DATA}/common-crawl/commoncrawl.de-en.en" \
61 | "${OUTPUT_DIR_DATA}/nc-v11/training-parallel-nc-v11/news-commentary-v11.de-en.en" \
62 | > "${OUTPUT_DIR}/train.en"
63 | wc -l "${OUTPUT_DIR}/train.en"
64 |
65 | cat "${OUTPUT_DIR_DATA}/europarl-v7-de-en/europarl-v7.de-en.de" \
66 | "${OUTPUT_DIR_DATA}/common-crawl/commoncrawl.de-en.de" \
67 | "${OUTPUT_DIR_DATA}/nc-v11/training-parallel-nc-v11/news-commentary-v11.de-en.de" \
68 | > "${OUTPUT_DIR}/train.de"
69 | wc -l "${OUTPUT_DIR}/train.de"
70 |
71 | # Clone Moses
72 | if [ ! -d "${OUTPUT_DIR}/mosesdecoder" ]; then
73 | echo "Cloning moses for data processing"
74 | git clone https://github.com/moses-smt/mosesdecoder.git "${OUTPUT_DIR}/mosesdecoder"
75 | fi
76 |
77 | # Convert SGM files
78 | # Convert newstest2014 data into raw text format
79 | ${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \
80 | < ${OUTPUT_DIR_DATA}/dev/dev/newstest2014-deen-src.de.sgm \
81 | > ${OUTPUT_DIR_DATA}/dev/dev/newstest2014.de
82 | ${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \
83 | < ${OUTPUT_DIR_DATA}/dev/dev/newstest2014-deen-ref.en.sgm \
84 | > ${OUTPUT_DIR_DATA}/dev/dev/newstest2014.en
85 |
86 | # Convert newstest2015 data into raw text format
87 | ${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \
88 | < ${OUTPUT_DIR_DATA}/dev/dev/newstest2015-deen-src.de.sgm \
89 | > ${OUTPUT_DIR_DATA}/dev/dev/newstest2015.de
90 | ${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \
91 | < ${OUTPUT_DIR_DATA}/dev/dev/newstest2015-deen-ref.en.sgm \
92 | > ${OUTPUT_DIR_DATA}/dev/dev/newstest2015.en
93 |
94 | # Convert newstest2016 data into raw text format
95 | ${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \
96 | < ${OUTPUT_DIR_DATA}/test/test/newstest2016-deen-src.de.sgm \
97 | > ${OUTPUT_DIR_DATA}/test/test/newstest2016.de
98 | ${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \
99 | < ${OUTPUT_DIR_DATA}/test/test/newstest2016-deen-ref.en.sgm \
100 | > ${OUTPUT_DIR_DATA}/test/test/newstest2016.en
101 |
102 | # Copy dev/test data to output dir
103 | cp ${OUTPUT_DIR_DATA}/dev/dev/newstest20*.de ${OUTPUT_DIR}
104 | cp ${OUTPUT_DIR_DATA}/dev/dev/newstest20*.en ${OUTPUT_DIR}
105 | cp ${OUTPUT_DIR_DATA}/test/test/newstest20*.de ${OUTPUT_DIR}
106 | cp ${OUTPUT_DIR_DATA}/test/test/newstest20*.en ${OUTPUT_DIR}
107 |
108 | # Tokenize data
109 | for f in ${OUTPUT_DIR}/*.de; do
110 | echo "Tokenizing $f..."
111 | ${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/tokenizer.perl -q -l de -threads 8 < $f > ${f%.*}.tok.de
112 | done
113 |
114 | for f in ${OUTPUT_DIR}/*.en; do
115 | echo "Tokenizing $f..."
116 | ${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/tokenizer.perl -q -l en -threads 8 < $f > ${f%.*}.tok.en
117 | done
118 |
119 | # Clean train corpora
120 | for f in ${OUTPUT_DIR}/train.tok.en; do
121 | fbase=${f%.*}
122 | echo "Cleaning ${fbase}..."
123 | ${OUTPUT_DIR}/mosesdecoder/scripts/training/clean-corpus-n.perl $fbase de en "${fbase}.clean" 1 80
124 | done
125 |
126 | # Generate Subword Units (BPE)
127 | # Clone Subword NMT
128 | if [ ! -d "${OUTPUT_DIR}/subword-nmt" ]; then
129 | git clone https://github.com/rsennrich/subword-nmt.git "${OUTPUT_DIR}/subword-nmt"
130 | fi
131 |
132 | # Learn Shared BPE
133 | for merge_ops in 32000; do
134 | echo "Learning BPE with merge_ops=${merge_ops}. This may take a while..."
135 | cat "${OUTPUT_DIR}/train.tok.clean.de" "${OUTPUT_DIR}/train.tok.clean.en" | \
136 | ${OUTPUT_DIR}/subword-nmt/learn_bpe.py -s $merge_ops > "${OUTPUT_DIR}/bpe.${merge_ops}"
137 |
138 | echo "Apply BPE with merge_ops=${merge_ops} to tokenized files..."
139 | for lang in en de; do
140 | for f in ${OUTPUT_DIR}/*.tok.${lang} ${OUTPUT_DIR}/*.tok.clean.${lang}; do
141 | outfile="${f%.*}.bpe.${merge_ops}.${lang}"
142 | ${OUTPUT_DIR}/subword-nmt/apply_bpe.py -c "${OUTPUT_DIR}/bpe.${merge_ops}" < $f > "${outfile}"
143 | echo ${outfile}
144 | done
145 | done
146 |
147 | # Create vocabulary file for BPE
148 |   echo -e "<unk>\n<s>\n</s>" > "${OUTPUT_DIR}/vocab.bpe.${merge_ops}"
149 | cat "${OUTPUT_DIR}/train.tok.clean.bpe.${merge_ops}.en" "${OUTPUT_DIR}/train.tok.clean.bpe.${merge_ops}.de" | \
150 | ${OUTPUT_DIR}/subword-nmt/get_vocab.py | cut -f1 -d ' ' >> "${OUTPUT_DIR}/vocab.bpe.${merge_ops}"
151 |
152 | done
153 |
154 | # Duplicate vocab file with language suffix
155 | cp "${OUTPUT_DIR}/vocab.bpe.32000" "${OUTPUT_DIR}/vocab.bpe.32000.en"
156 | cp "${OUTPUT_DIR}/vocab.bpe.32000" "${OUTPUT_DIR}/vocab.bpe.32000.de"
157 |
158 | echo "All done."
159 |
--------------------------------------------------------------------------------
/nmt/attention_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Attention-based sequence-to-sequence model with dynamic RNN support."""
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import tensorflow as tf
21 |
22 | from . import model
23 | from . import model_helper
24 |
25 | __all__ = ["AttentionModel"]
26 |
27 |
28 | class AttentionModel(model.Model):
29 | """Sequence-to-sequence dynamic model with attention.
30 |
31 | This class implements a multi-layer recurrent neural network as encoder,
32 | and an attention-based decoder. This is the same as the model described in
33 | (Luong et al., EMNLP'2015) paper: https://arxiv.org/pdf/1508.04025v5.pdf.
34 |   This class also allows using GRU cells in addition to LSTM cells, with
35 |   support for dropout.
36 | """
37 |
38 | def __init__(self,
39 | hparams,
40 | mode,
41 | iterator,
42 | source_vocab_table,
43 | target_vocab_table,
44 | reverse_target_vocab_table=None,
45 | scope=None,
46 | extra_args=None):
47 | # Set attention_mechanism_fn
48 | if extra_args and extra_args.attention_mechanism_fn:
49 | self.attention_mechanism_fn = extra_args.attention_mechanism_fn
50 | else:
51 | self.attention_mechanism_fn = create_attention_mechanism
52 |
53 | super(AttentionModel, self).__init__(
54 | hparams=hparams,
55 | mode=mode,
56 | iterator=iterator,
57 | source_vocab_table=source_vocab_table,
58 | target_vocab_table=target_vocab_table,
59 | reverse_target_vocab_table=reverse_target_vocab_table,
60 | scope=scope,
61 | extra_args=extra_args)
62 |
63 | if self.mode == tf.contrib.learn.ModeKeys.INFER:
64 | self.infer_summary = self._get_infer_summary(hparams)
65 |
66 | def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
67 | source_sequence_length):
68 |     """Build an RNN cell with an attention mechanism for use by the decoder."""
69 | attention_option = hparams.attention
70 | attention_architecture = hparams.attention_architecture
71 |
72 | if attention_architecture != "standard":
73 | raise ValueError(
74 | "Unknown attention architecture %s" % attention_architecture)
75 |
76 | num_units = hparams.num_units
77 | num_layers = self.num_decoder_layers
78 | num_residual_layers = self.num_decoder_residual_layers
79 | beam_width = hparams.beam_width
80 |
81 | dtype = tf.float32
82 |
83 | # Ensure memory is batch-major
84 | if self.time_major:
85 | memory = tf.transpose(encoder_outputs, [1, 0, 2])
86 | else:
87 | memory = encoder_outputs
88 |
89 | if self.mode == tf.contrib.learn.ModeKeys.INFER and beam_width > 0:
90 | memory = tf.contrib.seq2seq.tile_batch(
91 | memory, multiplier=beam_width)
92 | source_sequence_length = tf.contrib.seq2seq.tile_batch(
93 | source_sequence_length, multiplier=beam_width)
94 | encoder_state = tf.contrib.seq2seq.tile_batch(
95 | encoder_state, multiplier=beam_width)
96 | batch_size = self.batch_size * beam_width
97 | else:
98 | batch_size = self.batch_size
99 |
100 | attention_mechanism = self.attention_mechanism_fn(
101 | attention_option, num_units, memory, source_sequence_length, self.mode)
102 |
103 | cell = model_helper.create_rnn_cell(
104 | unit_type=hparams.unit_type,
105 | num_units=num_units,
106 | num_layers=num_layers,
107 | num_residual_layers=num_residual_layers,
108 | forget_bias=hparams.forget_bias,
109 | dropout=hparams.dropout,
110 | num_gpus=self.num_gpus,
111 | mode=self.mode,
112 | single_cell_fn=self.single_cell_fn)
113 |
114 | # Only generate alignment in greedy INFER mode.
115 | alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER and
116 | beam_width == 0)
117 | cell = tf.contrib.seq2seq.AttentionWrapper(
118 | cell,
119 | attention_mechanism,
120 | attention_layer_size=num_units,
121 | alignment_history=alignment_history,
122 | output_attention=hparams.output_attention,
123 | name="attention")
124 |
125 | # TODO(thangluong): do we need num_layers, num_gpus?
126 | cell = tf.contrib.rnn.DeviceWrapper(cell,
127 | model_helper.get_device_str(
128 | num_layers - 1, self.num_gpus))
129 |
130 | if hparams.pass_hidden_state:
131 | decoder_initial_state = cell.zero_state(batch_size, dtype).clone(
132 | cell_state=encoder_state)
133 | else:
134 | decoder_initial_state = cell.zero_state(batch_size, dtype)
135 |
136 | return cell, decoder_initial_state
137 |
138 | def _get_infer_summary(self, hparams):
139 | if hparams.beam_width > 0:
140 | return tf.no_op()
141 | return _create_attention_images_summary(self.final_context_state)
142 |
143 |
144 | def create_attention_mechanism(attention_option, num_units, memory,
145 | source_sequence_length, mode):
146 | """Create attention mechanism based on the attention_option."""
147 | del mode # unused
148 |
149 | # Mechanism
150 | if attention_option == "luong":
151 | attention_mechanism = tf.contrib.seq2seq.LuongAttention(
152 | num_units, memory, memory_sequence_length=source_sequence_length)
153 | elif attention_option == "scaled_luong":
154 | attention_mechanism = tf.contrib.seq2seq.LuongAttention(
155 | num_units,
156 | memory,
157 | memory_sequence_length=source_sequence_length,
158 | scale=True)
159 | elif attention_option == "bahdanau":
160 | attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
161 | num_units, memory, memory_sequence_length=source_sequence_length)
162 | elif attention_option == "normed_bahdanau":
163 | attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
164 | num_units,
165 | memory,
166 | memory_sequence_length=source_sequence_length,
167 | normalize=True)
168 | else:
169 | raise ValueError("Unknown attention option %s" % attention_option)
170 |
171 | return attention_mechanism
172 |
173 |
174 | def _create_attention_images_summary(final_context_state):
175 |   """Create the attention image and attention summary."""
176 |   attention_images = final_context_state.alignment_history.stack()
177 |   # Reshape to (batch, src_seq_len, tgt_seq_len, 1)
178 | attention_images = tf.expand_dims(
179 | tf.transpose(attention_images, [1, 2, 0]), -1)
180 | # Scale to range [0, 255]
181 | attention_images *= 255
182 | attention_summary = tf.summary.image("attention_images", attention_images)
183 | return attention_summary
184 |
--------------------------------------------------------------------------------
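For reference, a small sketch of ours (not repo code) exercising create_attention_mechanism on dummy tensors with the TF 1.x contrib APIs; the shapes are illustrative:

    import tensorflow as tf
    from nmt.attention_model import create_attention_mechanism

    batch, src_len, num_units = 2, 7, 5
    memory = tf.zeros([batch, src_len, num_units])   # batch-major encoder outputs
    source_sequence_length = tf.constant([7, 4])
    mechanism = create_attention_mechanism(
        "scaled_luong", num_units, memory, source_sequence_length, mode=None)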
/nmt/inference_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Tests for model inference."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import os
23 | import numpy as np
24 | import tensorflow as tf
25 |
26 | from . import attention_model
27 | from . import model_helper
28 | from . import model as nmt_model
29 | from . import gnmt_model
30 | from . import inference
31 | from .utils import common_test_utils
32 |
33 | float32 = np.float32
34 | int32 = np.int32
35 | array = np.array
36 |
37 |
38 | class InferenceTest(tf.test.TestCase):
39 |
40 | def _createTestInferCheckpoint(self, hparams, out_dir):
41 | if not hparams.attention:
42 | model_creator = nmt_model.Model
43 | elif hparams.attention_architecture == "standard":
44 | model_creator = attention_model.AttentionModel
45 | elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
46 | model_creator = gnmt_model.GNMTModel
47 | else:
48 | raise ValueError("Unknown model architecture")
49 |
50 | infer_model = model_helper.create_infer_model(model_creator, hparams)
51 | with self.test_session(graph=infer_model.graph) as sess:
52 | loaded_model, global_step = model_helper.create_or_load_model(
53 | infer_model.model, out_dir, sess, "infer_name")
54 | ckpt = loaded_model.saver.save(
55 | sess, os.path.join(out_dir, "translate.ckpt"),
56 | global_step=global_step)
57 | return ckpt
58 |
59 | def testBasicModel(self):
60 | hparams = common_test_utils.create_test_hparams(
61 | encoder_type="uni",
62 | num_layers=1,
63 | attention="",
64 | attention_architecture="",
65 | use_residual=False,)
66 | vocab_prefix = "nmt/testdata/test_infer_vocab"
67 | hparams.src_vocab_file = vocab_prefix + "." + hparams.src
68 | hparams.tgt_vocab_file = vocab_prefix + "." + hparams.tgt
69 |
70 | infer_file = "nmt/testdata/test_infer_file"
71 | out_dir = os.path.join(tf.test.get_temp_dir(), "basic_infer")
72 | hparams.out_dir = out_dir
73 | os.makedirs(out_dir)
74 | output_infer = os.path.join(out_dir, "output_infer")
75 | ckpt = self._createTestInferCheckpoint(hparams, out_dir)
76 | inference.inference(ckpt, infer_file, output_infer, hparams)
77 | with open(output_infer) as f:
78 | self.assertEqual(5, len(list(f)))
79 |
80 | def testBasicModelWithMultipleTranslations(self):
81 | hparams = common_test_utils.create_test_hparams(
82 | encoder_type="uni",
83 | num_layers=1,
84 | attention="",
85 | attention_architecture="",
86 | use_residual=False,
87 | num_translations_per_input=2,
88 | beam_width=2,
89 | )
90 | vocab_prefix = "nmt/testdata/test_infer_vocab"
91 | hparams.src_vocab_file = vocab_prefix + "." + hparams.src
92 | hparams.tgt_vocab_file = vocab_prefix + "." + hparams.tgt
93 |
94 | infer_file = "nmt/testdata/test_infer_file"
95 | out_dir = os.path.join(tf.test.get_temp_dir(), "multi_basic_infer")
96 | hparams.out_dir = out_dir
97 | os.makedirs(out_dir)
98 | output_infer = os.path.join(out_dir, "output_infer")
99 | ckpt = self._createTestInferCheckpoint(hparams, out_dir)
100 | inference.inference(ckpt, infer_file, output_infer, hparams)
101 | with open(output_infer) as f:
102 | self.assertEqual(10, len(list(f)))
103 |
104 | def testAttentionModel(self):
105 | hparams = common_test_utils.create_test_hparams(
106 | encoder_type="uni",
107 | num_layers=1,
108 | attention="scaled_luong",
109 | attention_architecture="standard",
110 | use_residual=False,)
111 | vocab_prefix = "nmt/testdata/test_infer_vocab"
112 | hparams.src_vocab_file = vocab_prefix + "." + hparams.src
113 | hparams.tgt_vocab_file = vocab_prefix + "." + hparams.tgt
114 |
115 | infer_file = "nmt/testdata/test_infer_file"
116 | out_dir = os.path.join(tf.test.get_temp_dir(), "attention_infer")
117 | hparams.out_dir = out_dir
118 | os.makedirs(out_dir)
119 | output_infer = os.path.join(out_dir, "output_infer")
120 | ckpt = self._createTestInferCheckpoint(hparams, out_dir)
121 | inference.inference(ckpt, infer_file, output_infer, hparams)
122 | with open(output_infer) as f:
123 | self.assertEqual(5, len(list(f)))
124 |
125 | def testMultiWorkers(self):
126 | hparams = common_test_utils.create_test_hparams(
127 | encoder_type="uni",
128 | num_layers=2,
129 | attention="scaled_luong",
130 | attention_architecture="standard",
131 | use_residual=False,)
132 | vocab_prefix = "nmt/testdata/test_infer_vocab"
133 | hparams.src_vocab_file = vocab_prefix + "." + hparams.src
134 | hparams.tgt_vocab_file = vocab_prefix + "." + hparams.tgt
135 |
136 | infer_file = "nmt/testdata/test_infer_file"
137 | out_dir = os.path.join(tf.test.get_temp_dir(), "multi_worker_infer")
138 | hparams.out_dir = out_dir
139 | os.makedirs(out_dir)
140 | output_infer = os.path.join(out_dir, "output_infer")
141 |
142 | num_workers = 3
143 |
144 |     # There are 5 examples; batch_size=3 gives job0 3 examples, job1
145 |     # 2 examples, and job2 0 examples. This helps test some edge
146 |     # cases.
147 | hparams.batch_size = 3
148 |
149 | ckpt = self._createTestInferCheckpoint(hparams, out_dir)
150 | inference.inference(
151 | ckpt, infer_file, output_infer, hparams, num_workers, jobid=1)
152 |
153 | inference.inference(
154 | ckpt, infer_file, output_infer, hparams, num_workers, jobid=2)
155 |
156 | # Note: Need to start job 0 at the end; otherwise, it will block the testing
157 | # thread.
158 | inference.inference(
159 | ckpt, infer_file, output_infer, hparams, num_workers, jobid=0)
160 |
161 | with open(output_infer) as f:
162 | self.assertEqual(5, len(list(f)))
163 |
164 | def testBasicModelWithInferIndices(self):
165 | hparams = common_test_utils.create_test_hparams(
166 | encoder_type="uni",
167 | num_layers=1,
168 | attention="",
169 | attention_architecture="",
170 | use_residual=False,
171 | inference_indices=[0])
172 | vocab_prefix = "nmt/testdata/test_infer_vocab"
173 | hparams.src_vocab_file = vocab_prefix + "." + hparams.src
174 | hparams.tgt_vocab_file = vocab_prefix + "." + hparams.tgt
175 |
176 | infer_file = "nmt/testdata/test_infer_file"
177 | out_dir = os.path.join(tf.test.get_temp_dir(), "basic_infer_with_indices")
178 | hparams.out_dir = out_dir
179 | os.makedirs(out_dir)
180 | output_infer = os.path.join(out_dir, "output_infer")
181 | ckpt = self._createTestInferCheckpoint(hparams, out_dir)
182 | inference.inference(ckpt, infer_file, output_infer, hparams)
183 | with open(output_infer) as f:
184 | self.assertEqual(1, len(list(f)))
185 |
186 | def testAttentionModelWithInferIndices(self):
187 | hparams = common_test_utils.create_test_hparams(
188 | encoder_type="uni",
189 | num_layers=1,
190 | attention="scaled_luong",
191 | attention_architecture="standard",
192 | use_residual=False,
193 | inference_indices=[1, 2])
194 | # TODO(rzhao): Make infer indices support batch_size > 1.
195 | hparams.infer_batch_size = 1
196 | vocab_prefix = "nmt/testdata/test_infer_vocab"
197 | hparams.src_vocab_file = vocab_prefix + "." + hparams.src
198 | hparams.tgt_vocab_file = vocab_prefix + "." + hparams.tgt
199 |
200 | infer_file = "nmt/testdata/test_infer_file"
201 | out_dir = os.path.join(tf.test.get_temp_dir(),
202 | "attention_infer_with_indices")
203 | hparams.out_dir = out_dir
204 | os.makedirs(out_dir)
205 | output_infer = os.path.join(out_dir, "output_infer")
206 | ckpt = self._createTestInferCheckpoint(hparams, out_dir)
207 | inference.inference(ckpt, infer_file, output_infer, hparams)
208 | with open(output_infer) as f:
209 | self.assertEqual(2, len(list(f)))
210 | self.assertTrue(os.path.exists(output_infer+str(1)+".png"))
211 | self.assertTrue(os.path.exists(output_infer+str(2)+".png"))
212 |
213 |
214 | if __name__ == "__main__":
215 | tf.test.main()
216 |
--------------------------------------------------------------------------------
/nmt/utils/iterator_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """For loading data into NMT models."""
16 | from __future__ import print_function
17 |
18 | import collections
19 |
20 | import tensorflow as tf
21 |
22 | __all__ = ["BatchedInput", "get_iterator", "get_infer_iterator"]
23 |
24 |
25 | # NOTE(ebrevdo): When we subclass this, instances' __dict__ becomes empty.
26 | class BatchedInput(
27 | collections.namedtuple("BatchedInput",
28 | ("initializer", "source", "target_input",
29 | "target_output", "source_sequence_length",
30 | "target_sequence_length"))):
31 | pass
32 |
33 |
34 | def get_infer_iterator(src_dataset,
35 | src_vocab_table,
36 | batch_size,
37 | eos,
38 | src_max_len=None):
39 | src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32)
40 | src_dataset = src_dataset.map(lambda src: tf.string_split([src]).values)
41 |
42 | if src_max_len:
43 | src_dataset = src_dataset.map(lambda src: src[:src_max_len])
44 | # Convert the word strings to ids
45 | src_dataset = src_dataset.map(
46 | lambda src: tf.cast(src_vocab_table.lookup(src), tf.int32))
47 | # Add in the word counts.
48 | src_dataset = src_dataset.map(lambda src: (src, tf.size(src)))
49 |
50 | def batching_func(x):
51 | return x.padded_batch(
52 | batch_size,
53 |         # The first entry is the source line rows;
54 | # this has unknown-length vectors. The last entry is
55 | # the source row size; this is a scalar.
56 | padded_shapes=(
57 | tf.TensorShape([None]), # src
58 | tf.TensorShape([])), # src_len
59 | # Pad the source sequences with eos tokens.
60 | # (Though notice we don't generally need to do this since
61 |         # later on we will be masking out calculations past the true sequence.)
62 | padding_values=(
63 | src_eos_id, # src
64 | 0)) # src_len -- unused
65 |
66 | batched_dataset = batching_func(src_dataset)
67 | batched_iter = batched_dataset.make_initializable_iterator()
68 | (src_ids, src_seq_len) = batched_iter.get_next()
69 | return BatchedInput(
70 | initializer=batched_iter.initializer,
71 | source=src_ids,
72 | target_input=None,
73 | target_output=None,
74 | source_sequence_length=src_seq_len,
75 | target_sequence_length=None)
76 |
77 |
78 | def get_iterator(src_dataset,
79 | tgt_dataset,
80 | src_vocab_table,
81 | tgt_vocab_table,
82 | batch_size,
83 | sos,
84 | eos,
85 | random_seed,
86 | num_buckets,
87 | src_max_len=None,
88 | tgt_max_len=None,
89 | num_parallel_calls=4,
90 | output_buffer_size=None,
91 | skip_count=None,
92 | num_shards=1,
93 | shard_index=0,
94 | reshuffle_each_iteration=True):
95 | if not output_buffer_size:
96 | output_buffer_size = batch_size * 1000
97 | src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32)
98 | tgt_sos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(sos)), tf.int32)
99 | tgt_eos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(eos)), tf.int32)
100 |
101 | src_tgt_dataset = tf.data.Dataset.zip((src_dataset, tgt_dataset))
102 |
103 | src_tgt_dataset = src_tgt_dataset.shard(num_shards, shard_index)
104 | if skip_count is not None:
105 | src_tgt_dataset = src_tgt_dataset.skip(skip_count)
106 |
107 | src_tgt_dataset = src_tgt_dataset.shuffle(
108 | output_buffer_size, random_seed, reshuffle_each_iteration)
109 |
110 | src_tgt_dataset = src_tgt_dataset.map(
111 | lambda src, tgt: (
112 | tf.string_split([src]).values, tf.string_split([tgt]).values),
113 | num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)
114 |
115 | # Filter zero length input sequences.
116 | src_tgt_dataset = src_tgt_dataset.filter(
117 | lambda src, tgt: tf.logical_and(tf.size(src) > 0, tf.size(tgt) > 0))
118 |
119 | if src_max_len:
120 | src_tgt_dataset = src_tgt_dataset.map(
121 | lambda src, tgt: (src[:src_max_len], tgt),
122 | num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)
123 | if tgt_max_len:
124 | src_tgt_dataset = src_tgt_dataset.map(
125 | lambda src, tgt: (src, tgt[:tgt_max_len]),
126 | num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)
127 | # Convert the word strings to ids. Word strings that are not in the
128 | # vocab get the lookup table's default_value integer.
129 | src_tgt_dataset = src_tgt_dataset.map(
130 | lambda src, tgt: (tf.cast(src_vocab_table.lookup(src), tf.int32),
131 | tf.cast(tgt_vocab_table.lookup(tgt), tf.int32)),
132 | num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)
133 |   # Create a tgt_input prefixed with <sos> and a tgt_output suffixed with <eos>.
134 | src_tgt_dataset = src_tgt_dataset.map(
135 | lambda src, tgt: (src,
136 | tf.concat(([tgt_sos_id], tgt), 0),
137 | tf.concat((tgt, [tgt_eos_id]), 0)),
138 | num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)
139 | # Add in sequence lengths.
140 | src_tgt_dataset = src_tgt_dataset.map(
141 | lambda src, tgt_in, tgt_out: (
142 | src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_in)),
143 | num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)
144 |
145 | # Bucket by source sequence length (buckets for lengths 0-9, 10-19, ...)
146 | def batching_func(x):
147 | return x.padded_batch(
148 | batch_size,
149 | # The first three entries are the source and target line rows;
150 | # these have unknown-length vectors. The last two entries are
151 | # the source and target row sizes; these are scalars.
152 | padded_shapes=(
153 | tf.TensorShape([None]), # src
154 | tf.TensorShape([None]), # tgt_input
155 | tf.TensorShape([None]), # tgt_output
156 | tf.TensorShape([]), # src_len
157 | tf.TensorShape([])), # tgt_len
158 | # Pad the source and target sequences with eos tokens.
159 | # (Though notice we don't generally need to do this since
160 |         # later on we will be masking out calculations past the true sequence.)
161 | padding_values=(
162 | src_eos_id, # src
163 | tgt_eos_id, # tgt_input
164 | tgt_eos_id, # tgt_output
165 | 0, # src_len -- unused
166 | 0)) # tgt_len -- unused
167 |
168 | if num_buckets > 1:
169 |
170 | def key_func(unused_1, unused_2, unused_3, src_len, tgt_len):
171 | # Calculate bucket_width by maximum source sequence length.
172 | # Pairs with length [0, bucket_width) go to bucket 0, length
173 | # [bucket_width, 2 * bucket_width) go to bucket 1, etc. Pairs with length
174 |       # over ((num_buckets-1) * bucket_width) words all go into the last bucket.
175 | if src_max_len:
176 | bucket_width = (src_max_len + num_buckets - 1) // num_buckets
177 | else:
178 | bucket_width = 10
179 |
180 | # Bucket sentence pairs by the length of their source sentence and target
181 | # sentence.
182 | bucket_id = tf.maximum(src_len // bucket_width, tgt_len // bucket_width)
183 | return tf.to_int64(tf.minimum(num_buckets, bucket_id))
184 |
185 | def reduce_func(unused_key, windowed_data):
186 | return batching_func(windowed_data)
187 |
188 | batched_dataset = src_tgt_dataset.apply(
189 | tf.contrib.data.group_by_window(
190 | key_func=key_func, reduce_func=reduce_func, window_size=batch_size))
191 |
192 | else:
193 | batched_dataset = batching_func(src_tgt_dataset)
194 | batched_iter = batched_dataset.make_initializable_iterator()
195 | (src_ids, tgt_input_ids, tgt_output_ids, src_seq_len,
196 | tgt_seq_len) = (batched_iter.get_next())
197 | return BatchedInput(
198 | initializer=batched_iter.initializer,
199 | source=src_ids,
200 | target_input=tgt_input_ids,
201 | target_output=tgt_output_ids,
202 | source_sequence_length=src_seq_len,
203 | target_sequence_length=tgt_seq_len)
204 |
--------------------------------------------------------------------------------
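To make the bucketing arithmetic in key_func concrete, a worked example with illustrative values:

    # With src_max_len=50 and num_buckets=5:
    src_max_len, num_buckets = 50, 5
    bucket_width = (src_max_len + num_buckets - 1) // num_buckets   # 10
    # A pair with src_len=23 and tgt_len=38 lands in bucket 3:
    bucket_id = max(23 // bucket_width, 38 // bucket_width)         # 3
    bucket_id = min(num_buckets, bucket_id)  # overlong pairs share the last bucket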
/nmt/inference.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """To perform inference on test set given a trained model."""
17 | from __future__ import print_function
18 |
19 | import codecs
20 | import time
21 |
22 | import tensorflow as tf
23 |
24 | from . import attention_model
25 | from . import gnmt_model
26 | from . import model as nmt_model
27 | from . import model_helper
28 | from .utils import misc_utils as utils
29 | from .utils import nmt_utils
30 |
31 | __all__ = ["load_data", "inference",
32 | "single_worker_inference", "multi_worker_inference"]
33 |
34 |
35 | def _decode_inference_indices(model, sess, output_infer,
36 | output_infer_summary_prefix,
37 | inference_indices,
38 | tgt_eos,
39 | subword_option):
40 | """Decoding only a specific set of sentences."""
41 | utils.print_out(" decoding to output %s , num sents %d." %
42 | (output_infer, len(inference_indices)))
43 | start_time = time.time()
44 | with codecs.getwriter("utf-8")(
45 | tf.gfile.GFile(output_infer, mode="wb")) as trans_f:
46 | trans_f.write("") # Write empty string to ensure file is created.
47 | for decode_id in inference_indices:
48 | nmt_outputs, infer_summary = model.decode(sess)
49 |
50 | # get text translation
51 | assert nmt_outputs.shape[0] == 1
52 | translation = nmt_utils.get_translation(
53 | nmt_outputs,
54 | sent_id=0,
55 | tgt_eos=tgt_eos,
56 | subword_option=subword_option)
57 |
58 | if infer_summary is not None: # Attention models
59 | image_file = output_infer_summary_prefix + str(decode_id) + ".png"
60 | utils.print_out(" save attention image to %s*" % image_file)
61 | image_summ = tf.Summary()
62 | image_summ.ParseFromString(infer_summary)
63 | with tf.gfile.GFile(image_file, mode="w") as img_f:
64 | img_f.write(image_summ.value[0].image.encoded_image_string)
65 |
66 | trans_f.write("%s\n" % translation)
67 | utils.print_out(translation + b"\n")
68 | utils.print_time(" done", start_time)
69 |
70 |
71 | def load_data(inference_input_file, hparams=None):
72 | """Load inference data."""
73 | with codecs.getreader("utf-8")(
74 | tf.gfile.GFile(inference_input_file, mode="rb")) as f:
75 | inference_data = f.read().splitlines()
76 |
77 | if hparams and hparams.inference_indices:
78 | inference_data = [inference_data[i] for i in hparams.inference_indices]
79 |
80 | return inference_data
81 |
82 |
83 | def inference(ckpt,
84 | inference_input_file,
85 | inference_output_file,
86 | hparams,
87 | num_workers=1,
88 | jobid=0,
89 | scope=None):
90 | """Perform translation."""
91 | if hparams.inference_indices:
92 | assert num_workers == 1
93 |
94 | if not hparams.attention:
95 | model_creator = nmt_model.Model
96 | elif hparams.attention_architecture == "standard":
97 | model_creator = attention_model.AttentionModel
98 | elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
99 | model_creator = gnmt_model.GNMTModel
100 | else:
101 | raise ValueError("Unknown model architecture")
102 | infer_model = model_helper.create_infer_model(model_creator, hparams, scope)
103 |
104 | if num_workers == 1:
105 | single_worker_inference(
106 | infer_model,
107 | ckpt,
108 | inference_input_file,
109 | inference_output_file,
110 | hparams)
111 | else:
112 | multi_worker_inference(
113 | infer_model,
114 | ckpt,
115 | inference_input_file,
116 | inference_output_file,
117 | hparams,
118 | num_workers=num_workers,
119 | jobid=jobid)
120 |
121 |
122 | def single_worker_inference(infer_model,
123 | ckpt,
124 | inference_input_file,
125 | inference_output_file,
126 | hparams):
127 | """Inference with a single worker."""
128 | output_infer = inference_output_file
129 |
130 | # Read data
131 | infer_data = load_data(inference_input_file, hparams)
132 |
133 | with tf.Session(
134 | graph=infer_model.graph, config=utils.get_config_proto()) as sess:
135 | loaded_infer_model = model_helper.load_model(
136 | infer_model.model, ckpt, sess, "infer")
137 | sess.run(
138 | infer_model.iterator.initializer,
139 | feed_dict={
140 | infer_model.src_placeholder: infer_data,
141 | infer_model.batch_size_placeholder: hparams.infer_batch_size
142 | })
143 | # Decode
144 | utils.print_out("# Start decoding")
145 | if hparams.inference_indices:
146 | _decode_inference_indices(
147 | loaded_infer_model,
148 | sess,
149 | output_infer=output_infer,
150 | output_infer_summary_prefix=output_infer,
151 | inference_indices=hparams.inference_indices,
152 | tgt_eos=hparams.eos,
153 | subword_option=hparams.subword_option)
154 | else:
155 | nmt_utils.decode_and_evaluate(
156 | "infer",
157 | loaded_infer_model,
158 | sess,
159 | output_infer,
160 | ref_file=None,
161 | metrics=hparams.metrics,
162 | subword_option=hparams.subword_option,
163 | beam_width=hparams.beam_width,
164 | tgt_eos=hparams.eos,
165 | num_translations_per_input=hparams.num_translations_per_input)
166 |
167 |
168 | def multi_worker_inference(infer_model,
169 | ckpt,
170 | inference_input_file,
171 | inference_output_file,
172 | hparams,
173 | num_workers,
174 | jobid):
175 | """Inference using multiple workers."""
176 | assert num_workers > 1
177 |
178 | final_output_infer = inference_output_file
179 | output_infer = "%s_%d" % (inference_output_file, jobid)
180 | output_infer_done = "%s_done_%d" % (inference_output_file, jobid)
181 |
182 | # Read data
183 | infer_data = load_data(inference_input_file, hparams)
184 |
185 | # Split data to multiple workers
186 | total_load = len(infer_data)
187 | load_per_worker = int((total_load - 1) / num_workers) + 1
188 | start_position = jobid * load_per_worker
189 | end_position = min(start_position + load_per_worker, total_load)
190 | infer_data = infer_data[start_position:end_position]
191 |
192 | with tf.Session(
193 | graph=infer_model.graph, config=utils.get_config_proto()) as sess:
194 | loaded_infer_model = model_helper.load_model(
195 | infer_model.model, ckpt, sess, "infer")
196 | sess.run(infer_model.iterator.initializer,
197 | {
198 | infer_model.src_placeholder: infer_data,
199 | infer_model.batch_size_placeholder: hparams.infer_batch_size
200 | })
201 | # Decode
202 | utils.print_out("# Start decoding")
203 | nmt_utils.decode_and_evaluate(
204 | "infer",
205 | loaded_infer_model,
206 | sess,
207 | output_infer,
208 | ref_file=None,
209 | metrics=hparams.metrics,
210 | subword_option=hparams.subword_option,
211 | beam_width=hparams.beam_width,
212 | tgt_eos=hparams.eos,
213 | num_translations_per_input=hparams.num_translations_per_input)
214 |
215 | # Change file name to indicate the file writing is completed.
216 | tf.gfile.Rename(output_infer, output_infer_done, overwrite=True)
217 |
218 | # Job 0 is responsible for the clean up.
219 | if jobid != 0: return
220 |
221 | # Now write all translations
222 | with codecs.getwriter("utf-8")(
223 | tf.gfile.GFile(final_output_infer, mode="wb")) as final_f:
224 | for worker_id in range(num_workers):
225 | worker_infer_done = "%s_done_%d" % (inference_output_file, worker_id)
226 | while not tf.gfile.Exists(worker_infer_done):
227 | utils.print_out(" waitting job %d to complete." % worker_id)
228 | time.sleep(10)
229 |
230 | with codecs.getreader("utf-8")(
231 | tf.gfile.GFile(worker_infer_done, mode="rb")) as f:
232 | for translation in f:
233 | final_f.write("%s" % translation)
234 |
235 | for worker_id in range(num_workers):
236 | worker_infer_done = "%s_done_%d" % (inference_output_file, worker_id)
237 | tf.gfile.Remove(worker_infer_done)
238 |
--------------------------------------------------------------------------------
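The shard arithmetic in multi_worker_inference gives each job a contiguous slice of ceil(total / num_workers) lines, with the last job taking whatever remains. A quick sketch with made-up sizes:

    # Pure-Python sketch of the worker split in multi_worker_inference.
    def shard(total_load, num_workers, jobid):
        load_per_worker = (total_load - 1) // num_workers + 1  # ceiling division
        start = jobid * load_per_worker
        end = min(start + load_per_worker, total_load)
        return start, end

    assert shard(10, 3, 0) == (0, 4)
    assert shard(10, 3, 1) == (4, 8)
    assert shard(10, 3, 2) == (8, 10)  # last worker gets the short tail
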
/texar/baseline_seq2seq_attn_main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The Texar Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Attentional Seq2seq.
16 | Same as examples/seq2seq_attn except that ROUGE is also supported here.
17 | """
18 | from __future__ import absolute_import
19 | from __future__ import print_function
20 | from __future__ import division
21 | from __future__ import unicode_literals
22 |
23 | # pylint: disable=invalid-name, too-many-arguments, too-many-locals
24 |
25 | from io import open
26 | import importlib
27 | import tensorflow as tf
28 | import texar as tx
29 | from rouge import Rouge
30 | import pdb
31 | import OT
32 |
33 | flags = tf.flags
34 |
35 | flags.DEFINE_string("config_model", "configs.config_model", "The model config.")
36 | flags.DEFINE_string("config_data", "configs.config_iwslt14",
37 | "The dataset config.")
38 |
39 | flags.DEFINE_string('output_dir', '.', 'where to keep training logs')
40 |
41 | FLAGS = flags.FLAGS
42 |
43 | config_model = importlib.import_module(FLAGS.config_model)
44 | config_data = importlib.import_module(FLAGS.config_data)
45 |
46 | if not FLAGS.output_dir.endswith('/'):
47 | FLAGS.output_dir += '/'
48 | log_dir = FLAGS.output_dir + 'training_log_baseline/'
49 | tx.utils.maybe_create_dir(log_dir)
50 |
51 |
52 | def build_model(batch, train_data):
53 | """Assembles the seq2seq model.
54 | """
55 | source_embedder = tx.modules.WordEmbedder(
56 | vocab_size=train_data.source_vocab.size, hparams=config_model.embedder)
57 |
58 | encoder = tx.modules.BidirectionalRNNEncoder(
59 | hparams=config_model.encoder)
60 |
61 | enc_outputs, _ = encoder(source_embedder(batch['source_text_ids']))
62 |
63 | target_embedder = tx.modules.WordEmbedder(
64 | vocab_size=train_data.target_vocab.size, hparams=config_model.embedder)
65 |
66 | decoder = tx.modules.AttentionRNNDecoder(
67 | memory=tf.concat(enc_outputs, axis=2),
68 | memory_sequence_length=batch['source_length'],
69 | vocab_size=train_data.target_vocab.size,
70 | hparams=config_model.decoder)
71 |
72 | training_outputs, _, _ = decoder(
73 | decoding_strategy='train_greedy',
74 | inputs=target_embedder(batch['target_text_ids'][:, :-1]),
75 | sequence_length=batch['target_length'] - 1)
76 |
77 | MLE_loss = tx.losses.sequence_sparse_softmax_cross_entropy(
78 | labels=batch['target_text_ids'][:, 1:],
79 | logits=training_outputs.logits,
80 | sequence_length=batch['target_length'] - 1)
81 |
82 | train_op = tx.core.get_train_op(
83 | MLE_loss,
84 | hparams=config_model.opt)
85 |
86 |
87 | start_tokens = tf.ones_like(batch['target_length']) *\
88 | train_data.target_vocab.bos_token_id
89 | beam_search_outputs, _, _ = \
90 | tx.modules.beam_search_decode(
91 | decoder_or_cell=decoder,
92 | embedding=target_embedder,
93 | start_tokens=start_tokens,
94 | end_token=train_data.target_vocab.eos_token_id,
95 | beam_width=config_model.beam_width,
96 | max_decoding_length=60)
97 |
98 | return train_op, beam_search_outputs
99 |
100 |
101 | def print_stdout_and_file(content, file):
102 | print(content)
103 | print(content, file=file)
104 |
105 |
106 | def main():
107 | """Entrypoint.
108 | """
109 | train_data = tx.data.PairedTextData(hparams=config_data.train)
110 | val_data = tx.data.PairedTextData(hparams=config_data.val)
111 | test_data = tx.data.PairedTextData(hparams=config_data.test)
112 |
113 | data_iterator = tx.data.TrainTestDataIterator(
114 | train=train_data, val=val_data, test=test_data)
115 |
116 | batch = data_iterator.get_next()
117 |
118 | train_op, infer_outputs = build_model(batch, train_data)
119 |
120 | def _train_epoch(sess, epoch_no):
121 | data_iterator.switch_to_train_data(sess)
122 | training_log_file = \
123 | open(log_dir + 'training_log' + str(epoch_no) + '.txt', 'w',
124 | encoding='utf-8')
125 |
126 | step = 0
127 | while True:
128 | try:
129 | loss = sess.run(train_op)
130 | print("step={}, loss={:.4f}".format(step, loss),
131 | file=training_log_file)
132 | if step % config_data.observe_steps == 0:
133 | print("step={}, loss={:.4f}".format(step, loss))
134 | training_log_file.flush()
135 | step += 1
136 | except tf.errors.OutOfRangeError:
137 | break
138 |
139 | def _eval_epoch(sess, mode, epoch_no):
140 | if mode == 'val':
141 | data_iterator.switch_to_val_data(sess)
142 | else:
143 | data_iterator.switch_to_test_data(sess)
144 |
145 | refs, hypos = [], []
146 | while True:
147 | try:
148 | fetches = [
149 | batch['target_text'][:, 1:],
150 | infer_outputs.predicted_ids[:, :, 0]
151 | ]
152 | feed_dict = {
153 | tx.global_mode(): tf.estimator.ModeKeys.EVAL
154 | }
155 | target_texts_ori, output_ids = \
156 | sess.run(fetches, feed_dict=feed_dict)
157 |
158 | target_texts = tx.utils.strip_special_tokens(
159 | target_texts_ori.tolist(), is_token_list=True)
160 | target_texts = tx.utils.str_join(target_texts)
161 | output_texts = tx.utils.map_ids_to_strs(
162 | ids=output_ids, vocab=val_data.target_vocab)
163 |
164 | tx.utils.write_paired_text(
165 | target_texts, output_texts,
166 | log_dir + mode + '_results' + str(epoch_no) + '.txt',
167 | append=True, mode='h', sep=' ||| ')
168 |
169 | for hypo, ref in zip(output_texts, target_texts):
170 | if config_data.eval_metric == 'bleu':
171 | hypos.append(hypo)
172 | refs.append([ref])
173 | elif config_data.eval_metric == 'rouge':
174 | hypos.append(tx.utils.compat_as_text(hypo))
175 | refs.append(tx.utils.compat_as_text(ref))
176 | except tf.errors.OutOfRangeError:
177 | break
178 |
179 | if config_data.eval_metric == 'bleu':
180 | return tx.evals.corpus_bleu_moses(
181 | list_of_references=refs, hypotheses=hypos)
182 | elif config_data.eval_metric == 'rouge':
183 | rouge = Rouge()
184 | return rouge.get_scores(hyps=hypos, refs=refs, avg=True)
185 |
186 | def _calc_reward(score):
187 | """
188 | Return the bleu score or the sum of (Rouge-1, Rouge-2, Rouge-L).
189 | """
190 | if config_data.eval_metric == 'bleu':
191 | return score
192 | elif config_data.eval_metric == 'rouge':
193 | return sum([value['f'] for key, value in score.items()])
194 |
195 | with tf.Session() as sess:
196 | sess.run(tf.global_variables_initializer())
197 | sess.run(tf.local_variables_initializer())
198 | sess.run(tf.tables_initializer())
199 |
200 | best_val_score = -1.
201 | scores_file = open(log_dir + 'scores.txt', 'w', encoding='utf-8')
202 | for i in range(config_data.num_epochs):
203 | _train_epoch(sess, i)
204 |
205 | val_score = _eval_epoch(sess, 'val', i)
206 | test_score = _eval_epoch(sess, 'test', i)
207 |
208 | best_val_score = max(best_val_score, _calc_reward(val_score))
209 |
210 | if config_data.eval_metric == 'bleu':
211 | print_stdout_and_file(
212 | 'val epoch={}, BLEU={:.4f}; best-ever={:.4f}'.format(
213 | i, val_score, best_val_score), file=scores_file)
214 |
215 | print_stdout_and_file(
216 | 'test epoch={}, BLEU={:.4f}'.format(i, test_score),
217 | file=scores_file)
218 | print_stdout_and_file('=' * 50, file=scores_file)
219 |
220 | elif config_data.eval_metric == 'rouge':
221 | print_stdout_and_file(
222 | 'valid epoch {}:'.format(i), file=scores_file)
223 | for key, value in val_score.items():
224 | print_stdout_and_file(
225 | '{}: {}'.format(key, value), file=scores_file)
226 | print_stdout_and_file('fsum: {}; best_val_fsum: {}'.format(
227 | _calc_reward(val_score), best_val_score), file=scores_file)
228 |
229 | print_stdout_and_file(
230 | 'test epoch {}:'.format(i), file=scores_file)
231 | for key, value in test_score.items():
232 | print_stdout_and_file(
233 | '{}: {}'.format(key, value), file=scores_file)
234 | print_stdout_and_file('=' * 110, file=scores_file)
235 |
236 | scores_file.flush()
237 |
238 |
239 | if __name__ == '__main__':
240 | main()
241 |
--------------------------------------------------------------------------------
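_calc_reward above passes a BLEU score through unchanged but reduces a ROUGE result to the sum of its three F-measures. The `rouge` package's get_scores(..., avg=True) returns nested dicts of f/p/r values; with the hypothetical scores below, the reward works out as:

    # Illustration of _calc_reward for the 'rouge' metric (numbers made up).
    score = {
        'rouge-1': {'f': 0.41, 'p': 0.44, 'r': 0.39},
        'rouge-2': {'f': 0.18, 'p': 0.20, 'r': 0.17},
        'rouge-l': {'f': 0.38, 'p': 0.40, 'r': 0.36},
    }
    reward = sum(value['f'] for value in score.values())  # 0.41 + 0.18 + 0.38
    assert abs(reward - 0.97) < 1e-9
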
/texar/baseline_seq2seq_attn_ot.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The Texar Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Attentional Seq2seq.
16 | Same as examples/seq2seq_attn, except that ROUGE is supported and an OT loss is added.
17 | """
18 | from __future__ import absolute_import
19 | from __future__ import print_function
20 | from __future__ import division
21 | from __future__ import unicode_literals
22 |
23 | # pylint: disable=invalid-name, too-many-arguments, too-many-locals
24 | import os
25 | from io import open
26 | import importlib
27 | import tensorflow as tf
28 | import texar as tx
29 | from rouge import Rouge
30 | import OT
31 | import pdb
32 |
33 | GPUID = 0
34 | os.environ["CUDA_VISIBLE_DEVICES"] = str(GPUID)
35 |
36 | flags = tf.flags
37 |
38 | flags.DEFINE_string("config_model", "configs.config_model", "The model config.")
39 | flags.DEFINE_string("config_data", "configs.config_iwslt14",
40 | "The dataset config.")
41 |
42 | flags.DEFINE_string('output_dir', '.', 'where to keep training logs')
43 |
44 | FLAGS = flags.FLAGS
45 |
46 | config_model = importlib.import_module(FLAGS.config_model)
47 | config_data = importlib.import_module(FLAGS.config_data)
48 |
49 | if not FLAGS.output_dir.endswith('/'):
50 | FLAGS.output_dir += '/'
51 | log_dir = FLAGS.output_dir + 'training_log_baseline/'
52 | tx.utils.maybe_create_dir(log_dir)
53 |
54 |
55 | def build_model(batch, train_data):
56 | """Assembles the seq2seq model.
57 | """
58 | source_embedder = tx.modules.WordEmbedder(
59 | vocab_size=train_data.source_vocab.size, hparams=config_model.embedder)
60 |
61 | encoder = tx.modules.BidirectionalRNNEncoder(
62 | hparams=config_model.encoder)
63 |
64 | enc_outputs, _ = encoder(source_embedder(batch['source_text_ids']))
65 |
66 | target_embedder = tx.modules.WordEmbedder(
67 | vocab_size=train_data.target_vocab.size, hparams=config_model.embedder)
68 |
69 | decoder = tx.modules.AttentionRNNDecoder(
70 | memory=tf.concat(enc_outputs, axis=2),
71 | memory_sequence_length=batch['source_length'],
72 | vocab_size=train_data.target_vocab.size,
73 | hparams=config_model.decoder)
74 |
75 | training_outputs, _, _ = decoder(
76 | decoding_strategy='train_greedy',
77 | inputs=target_embedder(batch['target_text_ids'][:, :-1]),
78 | sequence_length=batch['target_length'] - 1)
79 |
80 | # Modify loss
81 | MLE_loss = tx.losses.sequence_sparse_softmax_cross_entropy(
82 | labels=batch['target_text_ids'][:, 1:],
83 | logits=training_outputs.logits,
84 | sequence_length=batch['target_length'] - 1)
85 |
86 | # Keyword matching loss: OT between source and (soft) target embeddings
87 | tgt_logits = training_outputs.logits
88 | tgt_words = target_embedder(soft_ids=tgt_logits)
89 | src_words = source_embedder(ids=batch['source_text_ids'])
90 | src_words = tf.nn.l2_normalize(src_words, 2, epsilon=1e-12)
91 | tgt_words = tf.nn.l2_normalize(tgt_words, 2, epsilon=1e-12)
92 |
93 | cosine_cost = 1 - tf.einsum(
94 | 'aij,ajk->aik', src_words, tf.transpose(tgt_words, [0,2,1]))
95 | # pdb.set_trace()
96 | OT_loss = tf.reduce_mean(OT.IPOT_distance2(cosine_cost))
97 |
98 | Total_loss = MLE_loss + 0.1 * OT_loss
99 |
100 | train_op = tx.core.get_train_op(
101 | Total_loss,
102 | hparams=config_model.opt)
103 |
104 |
105 | start_tokens = tf.ones_like(batch['target_length']) *\
106 | train_data.target_vocab.bos_token_id
107 | beam_search_outputs, _, _ = \
108 | tx.modules.beam_search_decode(
109 | decoder_or_cell=decoder,
110 | embedding=target_embedder,
111 | start_tokens=start_tokens,
112 | end_token=train_data.target_vocab.eos_token_id,
113 | beam_width=config_model.beam_width,
114 | max_decoding_length=60)
115 |
116 | return train_op, beam_search_outputs
117 |
118 |
119 | def print_stdout_and_file(content, file):
120 | print(content)
121 | print(content, file=file)
122 |
123 |
124 | def main():
125 | """Entrypoint.
126 | """
127 | train_data = tx.data.PairedTextData(hparams=config_data.train)
128 | val_data = tx.data.PairedTextData(hparams=config_data.val)
129 | test_data = tx.data.PairedTextData(hparams=config_data.test)
130 | # pdb.set_trace()
131 | data_iterator = tx.data.TrainTestDataIterator(
132 | train=train_data, val=val_data, test=test_data)
133 |
134 | batch = data_iterator.get_next()
135 |
136 | train_op, infer_outputs = build_model(batch, train_data)
137 |
138 | def _train_epoch(sess, epoch_no):
139 | data_iterator.switch_to_train_data(sess)
140 | training_log_file = \
141 | open(log_dir + 'training_log' + str(epoch_no) + '.txt', 'w',
142 | encoding='utf-8')
143 |
144 | step = 0
145 | while True:
146 | try:
147 | loss = sess.run(train_op)
148 | print("step={}, loss={:.4f}".format(step, loss),
149 | file=training_log_file)
150 | if step % config_data.observe_steps == 0:
151 | print("step={}, loss={:.4f}".format(step, loss))
152 | training_log_file.flush()
153 | step += 1
154 | except tf.errors.OutOfRangeError:
155 | break
156 |
157 | def _eval_epoch(sess, mode, epoch_no):
158 | if mode == 'val':
159 | data_iterator.switch_to_val_data(sess)
160 | else:
161 | data_iterator.switch_to_test_data(sess)
162 |
163 | refs, hypos = [], []
164 | while True:
165 | try:
166 | fetches = [
167 | batch['target_text'][:, 1:],
168 | infer_outputs.predicted_ids[:, :, 0]
169 | ]
170 | feed_dict = {
171 | tx.global_mode(): tf.estimator.ModeKeys.EVAL
172 | }
173 | target_texts_ori, output_ids = \
174 | sess.run(fetches, feed_dict=feed_dict)
175 |
176 | target_texts = tx.utils.strip_special_tokens(
177 | target_texts_ori.tolist(), is_token_list=True)
178 | target_texts = tx.utils.str_join(target_texts)
179 | output_texts = tx.utils.map_ids_to_strs(
180 | ids=output_ids, vocab=val_data.target_vocab)
181 |
182 | tx.utils.write_paired_text(
183 | target_texts, output_texts,
184 | log_dir + mode + '_results' + str(epoch_no) + '.txt',
185 | append=True, mode='h', sep=' ||| ')
186 |
187 | for hypo, ref in zip(output_texts, target_texts):
188 | if config_data.eval_metric == 'bleu':
189 | hypos.append(hypo)
190 | refs.append([ref])
191 | elif config_data.eval_metric == 'rouge':
192 | hypos.append(tx.utils.compat_as_text(hypo))
193 | refs.append(tx.utils.compat_as_text(ref))
194 | except tf.errors.OutOfRangeError:
195 | break
196 |
197 | if config_data.eval_metric == 'bleu':
198 | return tx.evals.corpus_bleu_moses(
199 | list_of_references=refs, hypotheses=hypos)
200 | elif config_data.eval_metric == 'rouge':
201 | rouge = Rouge()
202 | return rouge.get_scores(hyps=hypos, refs=refs, avg=True)
203 |
204 | def _calc_reward(score):
205 | """
206 | Return the bleu score or the sum of (Rouge-1, Rouge-2, Rouge-L).
207 | """
208 | if config_data.eval_metric == 'bleu':
209 | return score
210 | elif config_data.eval_metric == 'rouge':
211 | return sum([value['f'] for key, value in score.items()])
212 |
213 | config = tf.ConfigProto()
214 | config.gpu_options.per_process_gpu_memory_fraction = 0.4
215 | with tf.Session(config=config) as sess:
216 | sess.run(tf.global_variables_initializer())
217 | sess.run(tf.local_variables_initializer())
218 | sess.run(tf.tables_initializer())
219 |
220 | best_val_score = -1.
221 | scores_file = open(log_dir + 'scores.txt', 'w', encoding='utf-8')
222 | for i in range(config_data.num_epochs):
223 | _train_epoch(sess, i)
224 |
225 | val_score = _eval_epoch(sess, 'val', i)
226 | test_score = _eval_epoch(sess, 'test', i)
227 |
228 | best_val_score = max(best_val_score, _calc_reward(val_score))
229 |
230 | if config_data.eval_metric == 'bleu':
231 | print_stdout_and_file(
232 | 'val epoch={}, BLEU={:.4f}; best-ever={:.4f}'.format(
233 | i, val_score, best_val_score), file=scores_file)
234 |
235 | print_stdout_and_file(
236 | 'test epoch={}, BLEU={:.4f}'.format(i, test_score),
237 | file=scores_file)
238 | print_stdout_and_file('=' * 50, file=scores_file)
239 |
240 | elif config_data.eval_metric == 'rouge':
241 | print_stdout_and_file(
242 | 'valid epoch {}:'.format(i), file=scores_file)
243 | for key, value in val_score.items():
244 | print_stdout_and_file(
245 | '{}: {}'.format(key, value), file=scores_file)
246 | print_stdout_and_file('fsum: {}; best_val_fsum: {}'.format(
247 | _calc_reward(val_score), best_val_score), file=scores_file)
248 |
249 | print_stdout_and_file(
250 | 'test epoch {}:'.format(i), file=scores_file)
251 | for key, value in test_score.items():
252 | print_stdout_and_file(
253 | '{}: {}'.format(key, value), file=scores_file)
254 | print_stdout_and_file('=' * 110, file=scores_file)
255 |
256 | scores_file.flush()
257 |
258 |
259 | if __name__ == '__main__':
260 | main()
261 |
--------------------------------------------------------------------------------
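The OT branch above builds a batched cosine-cost matrix between the L2-normalized source embeddings and the expected (soft) target embeddings, then hands it to OT.IPOT_distance2 from OT.py (not shown in this section). Below is a NumPy sketch of an unbatched IPOT distance following the proximal-point scheme of the IPOT paper; the step size, iteration count, and exact update order in OT.py may differ:

    import numpy as np

    def cosine_cost(src, tgt):
        # src: [m, d], tgt: [n, d]; rows assumed L2-normalized,
        # mirroring the tf.einsum above.
        return 1.0 - src @ tgt.T

    def ipot_distance(C, beta=1.0, n_iter=50):
        # Inexact proximal point iteration for OT with uniform marginals.
        m, n = C.shape
        a, b = np.full(m, 1.0 / m), np.full(n, 1.0 / n)
        T = np.outer(a, b)          # transport plan initialization
        A = np.exp(-C / beta)       # proximal kernel
        sigma = np.full(n, 1.0 / n)
        for _ in range(n_iter):
            Q = A * T
            delta = a / (Q @ sigma)       # one Sinkhorn-style inner step
            sigma = b / (Q.T @ delta)
            T = delta[:, None] * Q * sigma[None, :]
        return float(np.sum(T * C))
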
/nmt/gnmt_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """GNMT attention sequence-to-sequence model with dynamic RNN support."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 | # TODO(rzhao): Use tf.contrib.framework.nest once 1.3 is out.
24 | from tensorflow.python.util import nest
25 |
26 | from . import attention_model
27 | from . import model_helper
28 | from .utils import misc_utils as utils
29 |
30 | __all__ = ["GNMTModel"]
31 |
32 |
33 | class GNMTModel(attention_model.AttentionModel):
34 | """Sequence-to-sequence dynamic model with GNMT attention architecture.
35 | """
36 |
37 | def __init__(self,
38 | hparams,
39 | mode,
40 | iterator,
41 | source_vocab_table,
42 | target_vocab_table,
43 | reverse_target_vocab_table=None,
44 | scope=None,
45 | extra_args=None):
46 | super(GNMTModel, self).__init__(
47 | hparams=hparams,
48 | mode=mode,
49 | iterator=iterator,
50 | source_vocab_table=source_vocab_table,
51 | target_vocab_table=target_vocab_table,
52 | reverse_target_vocab_table=reverse_target_vocab_table,
53 | scope=scope,
54 | extra_args=extra_args)
55 |
56 | def _build_encoder(self, hparams):
57 | """Build a GNMT encoder."""
58 | if hparams.encoder_type == "uni" or hparams.encoder_type == "bi":
59 | return super(GNMTModel, self)._build_encoder(hparams)
60 |
61 | if hparams.encoder_type != "gnmt":
62 | raise ValueError("Unknown encoder_type %s" % hparams.encoder_type)
63 |
64 | # Build GNMT encoder.
65 | num_bi_layers = 1
66 | num_uni_layers = self.num_encoder_layers - num_bi_layers
67 | utils.print_out(" num_bi_layers = %d" % num_bi_layers)
68 | utils.print_out(" num_uni_layers = %d" % num_uni_layers)
69 |
70 | iterator = self.iterator
71 | source = iterator.source
72 | if self.time_major:
73 | source = tf.transpose(source)
74 |
75 | with tf.variable_scope("encoder") as scope:
76 | dtype = scope.dtype
77 |
78 | # Look up embedding, emp_inp: [max_time, batch_size, num_units]
79 | # when time_major = True
80 | encoder_emb_inp = tf.nn.embedding_lookup(self.embedding_encoder,
81 | source)
82 |
83 | # Execute _build_bidirectional_rnn from Model class
84 | bi_encoder_outputs, bi_encoder_state = self._build_bidirectional_rnn(
85 | inputs=encoder_emb_inp,
86 | sequence_length=iterator.source_sequence_length,
87 | dtype=dtype,
88 | hparams=hparams,
89 | num_bi_layers=num_bi_layers,
90 | num_bi_residual_layers=0, # no residual connection
91 | )
92 |
93 | uni_cell = model_helper.create_rnn_cell(
94 | unit_type=hparams.unit_type,
95 | num_units=hparams.num_units,
96 | num_layers=num_uni_layers,
97 | num_residual_layers=self.num_encoder_residual_layers,
98 | forget_bias=hparams.forget_bias,
99 | dropout=hparams.dropout,
100 | num_gpus=self.num_gpus,
101 | base_gpu=1,
102 | mode=self.mode,
103 | single_cell_fn=self.single_cell_fn)
104 |
105 | # encoder_outputs: size [max_time, batch_size, num_units]
106 | # when time_major = True
107 | encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
108 | uni_cell,
109 | bi_encoder_outputs,
110 | dtype=dtype,
111 | sequence_length=iterator.source_sequence_length,
112 | time_major=self.time_major)
113 |
114 | # Pass all encoder state except the first bi-directional layer's state to
115 | # decoder.
116 | encoder_state = (bi_encoder_state[1],) + (
117 | (encoder_state,) if num_uni_layers == 1 else encoder_state)
118 |
119 | return encoder_outputs, encoder_state, encoder_emb_inp
120 |
121 | def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
122 | source_sequence_length):
123 | """Build a RNN cell with GNMT attention architecture."""
124 | # Standard attention
125 | if hparams.attention_architecture == "standard":
126 | return super(GNMTModel, self)._build_decoder_cell(
127 | hparams, encoder_outputs, encoder_state, source_sequence_length)
128 |
129 | # GNMT attention
130 | attention_option = hparams.attention
131 | attention_architecture = hparams.attention_architecture
132 | num_units = hparams.num_units
133 | beam_width = hparams.beam_width
134 |
135 | dtype = tf.float32
136 |
137 | if self.time_major:
138 | memory = tf.transpose(encoder_outputs, [1, 0, 2])
139 | else:
140 | memory = encoder_outputs
141 |
142 | if self.mode == tf.contrib.learn.ModeKeys.INFER and beam_width > 0:
143 | memory = tf.contrib.seq2seq.tile_batch(
144 | memory, multiplier=beam_width)
145 | source_sequence_length = tf.contrib.seq2seq.tile_batch(
146 | source_sequence_length, multiplier=beam_width)
147 | encoder_state = tf.contrib.seq2seq.tile_batch(
148 | encoder_state, multiplier=beam_width)
149 | batch_size = self.batch_size * beam_width
150 | else:
151 | batch_size = self.batch_size
152 |
153 | attention_mechanism = self.attention_mechanism_fn(
154 | attention_option, num_units, memory, source_sequence_length, self.mode)
155 |
156 | cell_list = model_helper._cell_list( # pylint: disable=protected-access
157 | unit_type=hparams.unit_type,
158 | num_units=num_units,
159 | num_layers=self.num_decoder_layers,
160 | num_residual_layers=self.num_decoder_residual_layers,
161 | forget_bias=hparams.forget_bias,
162 | dropout=hparams.dropout,
163 | num_gpus=self.num_gpus,
164 | mode=self.mode,
165 | single_cell_fn=self.single_cell_fn,
166 | residual_fn=gnmt_residual_fn
167 | )
168 |
169 | # Only wrap the bottom layer with the attention mechanism.
170 | attention_cell = cell_list.pop(0)
171 |
172 | # Only generate alignment in greedy INFER mode.
173 | alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER and
174 | beam_width == 0)
175 | attention_cell = tf.contrib.seq2seq.AttentionWrapper(
176 | attention_cell,
177 | attention_mechanism,
178 | attention_layer_size=None, # don't use attention layer.
179 | output_attention=False,
180 | alignment_history=alignment_history,
181 | name="attention")
182 |
183 | if attention_architecture == "gnmt":
184 | cell = GNMTAttentionMultiCell(
185 | attention_cell, cell_list)
186 | elif attention_architecture == "gnmt_v2":
187 | cell = GNMTAttentionMultiCell(
188 | attention_cell, cell_list, use_new_attention=True)
189 | else:
190 | raise ValueError(
191 | "Unknown attention_architecture %s" % attention_architecture)
192 |
193 | if hparams.pass_hidden_state:
194 | decoder_initial_state = tuple(
195 | zs.clone(cell_state=es)
196 | if isinstance(zs, tf.contrib.seq2seq.AttentionWrapperState) else es
197 | for zs, es in zip(
198 | cell.zero_state(batch_size, dtype), encoder_state))
199 | else:
200 | decoder_initial_state = cell.zero_state(batch_size, dtype)
201 |
202 | return cell, decoder_initial_state
203 |
204 | def _get_infer_summary(self, hparams):
205 | # Standard attention
206 | if hparams.attention_architecture == "standard":
207 | return super(GNMTModel, self)._get_infer_summary(hparams)
208 |
209 | # GNMT attention
210 | if hparams.beam_width > 0:
211 | return tf.no_op()
212 | return attention_model._create_attention_images_summary(
213 | self.final_context_state[0])
214 |
215 |
216 | class GNMTAttentionMultiCell(tf.nn.rnn_cell.MultiRNNCell):
217 | """A MultiCell with GNMT attention style."""
218 |
219 | def __init__(self, attention_cell, cells, use_new_attention=False):
220 | """Creates a GNMTAttentionMultiCell.
221 |
222 | Args:
223 | attention_cell: An instance of AttentionWrapper.
224 | cells: A list of RNNCell wrapped with AttentionInputWrapper.
225 | use_new_attention: Whether to use the attention generated from the
226 | current step's bottom layer output. Default is False.
227 | """
228 | cells = [attention_cell] + cells
229 | self.use_new_attention = use_new_attention
230 | super(GNMTAttentionMultiCell, self).__init__(cells, state_is_tuple=True)
231 |
232 | def __call__(self, inputs, state, scope=None):
233 | """Run the cell with bottom layer's attention copied to all upper layers."""
234 | if not nest.is_sequence(state):
235 | raise ValueError(
236 | "Expected state to be a tuple of length %d, but received: %s"
237 | % (len(self.state_size), state))
238 |
239 | with tf.variable_scope(scope or "multi_rnn_cell"):
240 | new_states = []
241 |
242 | with tf.variable_scope("cell_0_attention"):
243 | attention_cell = self._cells[0]
244 | attention_state = state[0]
245 | cur_inp, new_attention_state = attention_cell(inputs, attention_state)
246 | new_states.append(new_attention_state)
247 |
248 | for i in range(1, len(self._cells)):
249 | with tf.variable_scope("cell_%d" % i):
250 |
251 | cell = self._cells[i]
252 | cur_state = state[i]
253 |
254 | if self.use_new_attention:
255 | cur_inp = tf.concat([cur_inp, new_attention_state.attention], -1)
256 | else:
257 | cur_inp = tf.concat([cur_inp, attention_state.attention], -1)
258 |
259 | cur_inp, new_state = cell(cur_inp, cur_state)
260 | new_states.append(new_state)
261 |
262 | return cur_inp, tuple(new_states)
263 |
264 |
265 | def gnmt_residual_fn(inputs, outputs):
266 | """Residual function that handles different inputs and outputs inner dims.
267 |
268 | Args:
269 | inputs: cell inputs; the actual input concatenated with the attention
270 | vector.
271 | outputs: cell outputs
272 |
273 | Returns:
274 | outputs + actual inputs
275 | """
276 | def split_input(inp, out):
277 | out_dim = out.get_shape().as_list()[-1]
278 | inp_dim = inp.get_shape().as_list()[-1]
279 | return tf.split(inp, [out_dim, inp_dim - out_dim], axis=-1)
280 | actual_inputs, _ = nest.map_structure(split_input, inputs, outputs)
281 | def assert_shape_match(inp, out):
282 | inp.get_shape().assert_is_compatible_with(out.get_shape())
283 | nest.assert_same_structure(actual_inputs, outputs)
284 | nest.map_structure(assert_shape_match, actual_inputs, outputs)
285 | return nest.map_structure(lambda inp, out: inp + out, actual_inputs, outputs)
286 |
--------------------------------------------------------------------------------
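gnmt_residual_fn exists because each upper decoder cell receives its true input concatenated with the attention vector, so input and output widths differ; only the leading slice that matches the output width takes part in the skip connection. A NumPy stand-in for the shape logic:

    import numpy as np

    num_units, attn_size = 4, 3
    inputs = np.arange(float(num_units + attn_size))  # true input ++ attention
    outputs = np.ones(num_units)                      # cell output

    actual_inputs = inputs[:num_units]  # drop the attention slice
    residual_out = actual_inputs + outputs
    assert residual_out.tolist() == [1.0, 2.0, 3.0, 4.0]
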
/nmt/scripts/rouge.py:
--------------------------------------------------------------------------------
1 | """ROUGE metric implementation.
2 |
3 | Copy from tf_seq2seq/seq2seq/metrics/rouge.py.
4 | This is a modified and slightly extended verison of
5 | https://github.com/miso-belica/sumy/blob/dev/sumy/evaluation/rouge.py.
6 | """
7 |
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 | from __future__ import unicode_literals
12 |
13 | import itertools
14 | import numpy as np
15 |
16 | #pylint: disable=C0103
17 |
18 |
19 | def _get_ngrams(n, text):
20 | """Calcualtes n-grams.
21 |
22 | Args:
23 | n: which n-grams to calculate
24 | text: An array of tokens
25 |
26 | Returns:
27 | A set of n-grams
28 | """
29 | ngram_set = set()
30 | text_length = len(text)
31 | max_index_ngram_start = text_length - n
32 | for i in range(max_index_ngram_start + 1):
33 | ngram_set.add(tuple(text[i:i + n]))
34 | return ngram_set
35 |
36 |
37 | def _split_into_words(sentences):
38 | """Splits multiple sentences into words and flattens the result"""
39 | return list(itertools.chain(*[_.split(" ") for _ in sentences]))
40 |
41 |
42 | def _get_word_ngrams(n, sentences):
43 | """Calculates word n-grams for multiple sentences.
44 | """
45 | assert len(sentences) > 0
46 | assert n > 0
47 |
48 | words = _split_into_words(sentences)
49 | return _get_ngrams(n, words)
50 |
51 |
52 | def _len_lcs(x, y):
53 | """
54 | Returns the length of the Longest Common Subsequence between sequences x
55 | and y.
56 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence
57 |
58 | Args:
59 | x: sequence of words
60 | y: sequence of words
61 |
62 | Returns
63 | integer: Length of LCS between x and y
64 | """
65 | table = _lcs(x, y)
66 | n, m = len(x), len(y)
67 | return table[n, m]
68 |
69 |
70 | def _lcs(x, y):
71 | """
72 | Computes the length of the longest common subsequence (lcs) between two
73 | strings. The implementation below uses a dynamic programming algorithm and runs
74 | in O(nm) time where n = len(x) and m = len(y).
75 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence
76 |
77 | Args:
78 | x: collection of words
79 | y: collection of words
80 |
81 | Returns:
82 | Table of dictionary of coord and len lcs
83 | """
84 | n, m = len(x), len(y)
85 | table = dict()
86 | for i in range(n + 1):
87 | for j in range(m + 1):
88 | if i == 0 or j == 0:
89 | table[i, j] = 0
90 | elif x[i - 1] == y[j - 1]:
91 | table[i, j] = table[i - 1, j - 1] + 1
92 | else:
93 | table[i, j] = max(table[i - 1, j], table[i, j - 1])
94 | return table
95 |
96 |
97 | def _recon_lcs(x, y):
98 | """
99 | Returns the Longest Common Subsequence between x and y.
100 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence
101 |
102 | Args:
103 | x: sequence of words
104 | y: sequence of words
105 |
106 | Returns:
107 | sequence: LCS of x and y
108 | """
109 | i, j = len(x), len(y)
110 | table = _lcs(x, y)
111 |
112 | def _recon(i, j):
113 | """private recon calculation"""
114 | if i == 0 or j == 0:
115 | return []
116 | elif x[i - 1] == y[j - 1]:
117 | return _recon(i - 1, j - 1) + [(x[i - 1], i)]
118 | elif table[i - 1, j] > table[i, j - 1]:
119 | return _recon(i - 1, j)
120 | else:
121 | return _recon(i, j - 1)
122 |
123 | recon_tuple = tuple(map(lambda x: x[0], _recon(i, j)))
124 | return recon_tuple
125 |
126 |
127 | def rouge_n(evaluated_sentences, reference_sentences, n=2):
128 | """
129 | Computes ROUGE-N of two text collections of sentences.
130 | Source: http://research.microsoft.com/en-us/um/people/cyl/download/
131 | papers/rouge-working-note-v1.3.1.pdf
132 |
133 | Args:
134 | evaluated_sentences: The sentences that have been picked by the summarizer
135 | reference_sentences: The sentences from the reference set
136 | n: Size of ngram. Defaults to 2.
137 |
138 | Returns:
139 | A tuple (f1, precision, recall) for ROUGE-N
140 |
141 | Raises:
142 | ValueError: raises exception if a param has len <= 0
143 | """
144 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0:
145 | raise ValueError("Collections must contain at least 1 sentence.")
146 |
147 | evaluated_ngrams = _get_word_ngrams(n, evaluated_sentences)
148 | reference_ngrams = _get_word_ngrams(n, reference_sentences)
149 | reference_count = len(reference_ngrams)
150 | evaluated_count = len(evaluated_ngrams)
151 |
152 | # Gets the overlapping ngrams between evaluated and reference
153 | overlapping_ngrams = evaluated_ngrams.intersection(reference_ngrams)
154 | overlapping_count = len(overlapping_ngrams)
155 |
156 | # Handle edge case. This isn't mathematically correct, but it's good enough
157 | if evaluated_count == 0:
158 | precision = 0.0
159 | else:
160 | precision = overlapping_count / evaluated_count
161 |
162 | if reference_count == 0:
163 | recall = 0.0
164 | else:
165 | recall = overlapping_count / reference_count
166 |
167 | f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8))
168 |
169 | # return overlapping_count / reference_count
170 | return f1_score, precision, recall
171 |
172 |
173 | def _f_p_r_lcs(llcs, m, n):
174 | """
175 | Computes the LCS-based F-measure score
176 | Source: http://research.microsoft.com/en-us/um/people/cyl/download/papers/
177 | rouge-working-note-v1.3.1.pdf
178 |
179 | Args:
180 | llcs: Length of LCS
181 | m: number of words in reference summary
182 | n: number of words in candidate summary
183 |
184 | Returns:
185 | Float. LCS-based F-measure score
186 | """
187 | r_lcs = llcs / m
188 | p_lcs = llcs / n
189 | beta = p_lcs / (r_lcs + 1e-12)
190 | num = (1 + (beta**2)) * r_lcs * p_lcs
191 | denom = r_lcs + ((beta**2) * p_lcs)
192 | f_lcs = num / (denom + 1e-12)
193 | return f_lcs, p_lcs, r_lcs
194 |
195 |
196 | def rouge_l_sentence_level(evaluated_sentences, reference_sentences):
197 | """
198 | Computes ROUGE-L (sentence level) of two text collections of sentences.
199 | http://research.microsoft.com/en-us/um/people/cyl/download/papers/
200 | rouge-working-note-v1.3.1.pdf
201 |
202 | Calculated according to:
203 | R_lcs = LCS(X,Y)/m
204 | P_lcs = LCS(X,Y)/n
205 | F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs)
206 |
207 | where:
208 | X = reference summary
209 | Y = Candidate summary
210 | m = length of reference summary
211 | n = length of candidate summary
212 |
213 | Args:
214 | evaluated_sentences: The sentences that have been picked by the summarizer
215 | reference_sentences: The sentences from the reference set
216 |
217 | Returns:
218 | A float: F_lcs
219 |
220 | Raises:
221 | ValueError: raises exception if a param has len <= 0
222 | """
223 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0:
224 | raise ValueError("Collections must contain at least 1 sentence.")
225 | reference_words = _split_into_words(reference_sentences)
226 | evaluated_words = _split_into_words(evaluated_sentences)
227 | m = len(reference_words)
228 | n = len(evaluated_words)
229 | lcs = _len_lcs(evaluated_words, reference_words)
230 | return _f_p_r_lcs(lcs, m, n)
231 |
232 |
233 | def _union_lcs(evaluated_sentences, reference_sentence):
234 | """
235 | Returns LCS_u(r_i, C) which is the LCS score of the union longest common
236 | subsequence between reference sentence ri and candidate summary C. For example
237 | if r_i= w1 w2 w3 w4 w5, and C contains two sentences: c1 = w1 w2 w6 w7 w8 and
238 | c2 = w1 w3 w8 w9 w5, then the longest common subsequence of r_i and c1 is
239 | "w1 w2" and the longest common subsequence of r_i and c2 is "w1 w3 w5". The
240 | union longest common subsequence of r_i, c1, and c2 is "w1 w2 w3 w5" and
241 | LCS_u(r_i, C) = 4/5.
242 |
243 | Args:
244 | evaluated_sentences: The sentences that have been picked by the summarizer
245 | reference_sentence: One of the sentences in the reference summaries
246 |
247 | Returns:
248 | float: LCS_u(r_i, C)
249 |
250 | ValueError:
251 | Raises exception if a param has len <= 0
252 | """
253 | if len(evaluated_sentences) <= 0:
254 | raise ValueError("Collections must contain at least 1 sentence.")
255 |
256 | lcs_union = set()
257 | reference_words = _split_into_words([reference_sentence])
258 | combined_lcs_length = 0
259 | for eval_s in evaluated_sentences:
260 | evaluated_words = _split_into_words([eval_s])
261 | lcs = set(_recon_lcs(reference_words, evaluated_words))
262 | combined_lcs_length += len(lcs)
263 | lcs_union = lcs_union.union(lcs)
264 |
265 | union_lcs_count = len(lcs_union)
266 | union_lcs_value = union_lcs_count / combined_lcs_length
267 | return union_lcs_value
268 |
269 |
270 | def rouge_l_summary_level(evaluated_sentences, reference_sentences):
271 | """
272 | Computes ROUGE-L (summary level) of two text collections of sentences.
273 | http://research.microsoft.com/en-us/um/people/cyl/download/papers/
274 | rouge-working-note-v1.3.1.pdf
275 |
276 | Calculated according to:
277 | R_lcs = SUM(1, u)[LCS(r_i,C)]/m
278 | P_lcs = SUM(1, u)[LCS(r_i,C)]/n
279 | F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs)
280 |
281 | where:
282 | SUM(i,u) = SUM from i through u
283 | u = number of sentences in reference summary
284 | C = Candidate summary made up of v sentences
285 | m = number of words in reference summary
286 | n = number of words in candidate summary
287 |
288 | Args:
289 | evaluated_sentences: The sentences that have been picked by the summarizer
290 | reference_sentences: The sentences from the reference set
291 |
292 | Returns:
293 | A float: F_lcs
294 |
295 | Raises:
296 | ValueError: raises exception if a param has len <= 0
297 | """
298 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0:
299 | raise ValueError("Collections must contain at least 1 sentence.")
300 |
301 | # total number of words in reference sentences
302 | m = len(_split_into_words(reference_sentences))
303 |
304 | # total number of words in evaluated sentences
305 | n = len(_split_into_words(evaluated_sentences))
306 |
307 | union_lcs_sum_across_all_references = 0
308 | for ref_s in reference_sentences:
309 | union_lcs_sum_across_all_references += _union_lcs(evaluated_sentences,
310 | ref_s)
311 | return _f_p_r_lcs(union_lcs_sum_across_all_references, m, n)
312 |
313 |
314 | def rouge(hypotheses, references):
315 | """Calculates average rouge scores for a list of hypotheses and
316 | references"""
317 |
318 | # Filter out hyps that are of 0 length
319 | # hyps_and_refs = zip(hypotheses, references)
320 | # hyps_and_refs = [_ for _ in hyps_and_refs if len(_[0]) > 0]
321 | # hypotheses, references = zip(*hyps_and_refs)
322 |
323 | # Calculate ROUGE-1 F1, precision, recall scores
324 | rouge_1 = [
325 | rouge_n([hyp], [ref], 1) for hyp, ref in zip(hypotheses, references)
326 | ]
327 | rouge_1_f, rouge_1_p, rouge_1_r = map(np.mean, zip(*rouge_1))
328 |
329 | # Calculate ROUGE-2 F1, precision, recall scores
330 | rouge_2 = [
331 | rouge_n([hyp], [ref], 2) for hyp, ref in zip(hypotheses, references)
332 | ]
333 | rouge_2_f, rouge_2_p, rouge_2_r = map(np.mean, zip(*rouge_2))
334 |
335 | # Calculate ROUGE-L F1, precision, recall scores
336 | rouge_l = [
337 | rouge_l_sentence_level([hyp], [ref])
338 | for hyp, ref in zip(hypotheses, references)
339 | ]
340 | rouge_l_f, rouge_l_p, rouge_l_r = map(np.mean, zip(*rouge_l))
341 |
342 | return {
343 | "rouge_1/f_score": rouge_1_f,
344 | "rouge_1/r_score": rouge_1_r,
345 | "rouge_1/p_score": rouge_1_p,
346 | "rouge_2/f_score": rouge_2_f,
347 | "rouge_2/r_score": rouge_2_r,
348 | "rouge_2/p_score": rouge_2_p,
349 | "rouge_l/f_score": rouge_l_f,
350 | "rouge_l/r_score": rouge_l_r,
351 | "rouge_l/p_score": rouge_l_p,
352 | }
353 |
--------------------------------------------------------------------------------
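The ROUGE-L functions all reduce to the O(nm) LCS table built by _lcs. A worked miniature of the same dict-based recurrence:

    # LCS of x = [a, b, c, d] and y = [a, c, d] is [a, c, d], length 3.
    def lcs_len(x, y):
        n, m = len(x), len(y)
        table = {(i, j): 0 for i in range(n + 1) for j in range(m + 1)}
        for i in range(1, n + 1):
            for j in range(1, m + 1):
                if x[i - 1] == y[j - 1]:
                    table[i, j] = table[i - 1, j - 1] + 1
                else:
                    table[i, j] = max(table[i - 1, j], table[i, j - 1])
        return table[n, m]

    assert lcs_len("a b c d".split(), "a c d".split()) == 3
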
/texar/utils/raml_samples_generation/process_samples.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from nltk.translate.bleu_score import sentence_bleu
3 | from nltk.translate.bleu_score import SmoothingFunction
4 | import sys
5 | import re
6 | import argparse
7 | import torch
8 | from util import read_corpus
9 | import numpy as np
10 | from scipy.special import comb  # scipy.misc.comb was removed in SciPy 1.3
11 | from vocab import Vocab, VocabEntry
12 | import math
13 | from rouge import Rouge
14 |
15 |
16 | def is_valid_sample(sent):
17 | tokens = sent.split(' ')
18 | return len(tokens) >= 1 and len(tokens) < 50
19 |
20 |
21 | def sample_from_model(args):
22 | para_data = args.parallel_data
23 | sample_file = args.sample_file
24 | output = args.output
25 |
26 | tgt_sent_pattern = re.compile(r'^\[(\d+)\] (.*?)$')
27 | para_data = [l.strip().split(' ||| ') for l in open(para_data)]
28 |
29 | f_out = open(output, 'w')
30 | f = open(sample_file)
31 | f.readline()
32 | for src_sent, tgt_sent in para_data:
33 | line = f.readline().strip()
34 | assert line.startswith('****')
35 | line = f.readline().strip()
36 | print(line)
37 | assert line.startswith('target:')
38 |
39 | tgt_sent2 = line[len('target:'):]
40 | assert tgt_sent == tgt_sent2
41 |
42 | line = f.readline().strip() # samples
43 |
44 | tgt_sent = ' '.join(tgt_sent.split(' ')[1:-1])
45 | tgt_samples = set()
46 | for i in range(1, 101):
47 | line = f.readline().rstrip('\n')
48 | m = tgt_sent_pattern.match(line)
49 |
50 | assert m, line
51 | assert int(m.group(1)) == i
52 |
53 | sampled_tgt_sent = m.group(2).strip()
54 |
55 | if is_valid_sample(sampled_tgt_sent):
56 | tgt_samples.add(sampled_tgt_sent)
57 |
58 | line = f.readline().strip()
59 | assert line.startswith('****')
60 |
61 | tgt_samples.add(tgt_sent)
62 | tgt_samples = list(tgt_samples)
63 |
64 | assert len(tgt_samples) > 0
65 |
66 | tgt_ref_tokens = tgt_sent.split(' ')
67 | bleu_scores = []
68 | for tgt_sample in tgt_samples:
69 | bleu_score = sentence_bleu([tgt_ref_tokens], tgt_sample.split(' '))
70 | bleu_scores.append(bleu_score)
71 |
72 | tgt_ranks = sorted(range(len(tgt_samples)), key=lambda i: bleu_scores[i], reverse=True)
73 |
74 | print('%d samples' % len(tgt_samples))
75 |
76 | print('*' * 50, file=f_out)
77 | print('source: ' + src_sent, file=f_out)
78 | print('%d samples' % len(tgt_samples), file=f_out)
79 | for i in tgt_ranks:
80 | print('%s ||| %f' % (tgt_samples[i], bleu_scores[i]), file=f_out)
81 | print('*' * 50, file=f_out)
82 |
83 | f_out.close()
84 |
85 |
86 | def get_new_ngram(ngram, n, vocab):
87 | """
88 | replace ngram `ngram` with a newly sampled ngram of the same length
89 | """
90 |
91 | new_ngram_wids = [np.random.randint(3, len(vocab)) for i in range(n)]
92 | new_ngram = [vocab.id2word[wid] for wid in new_ngram_wids]
93 |
94 | return new_ngram
95 |
96 |
97 | def sample_ngram(args):
98 | src_sents = read_corpus(args.src, 'src')
99 | tgt_sents = read_corpus(args.tgt, 'src')  # do not read in <s> and </s>
100 | f_out = open(args.output, 'w')
101 |
102 | vocab = torch.load(args.vocab)
103 | tgt_vocab = vocab.tgt
104 |
105 | smooth_bleu = args.smooth_bleu
106 | sm_func = None
107 | if smooth_bleu:
108 | sm_func = SmoothingFunction().method3
109 |
110 | for src_sent, tgt_sent in zip(src_sents, tgt_sents):
111 | src_sent = ' '.join(src_sent)
112 |
113 | tgt_len = len(tgt_sent)
114 | tgt_samples = []
115 | tgt_samples_distort_rates = [] # how many unigrams are replaced
116 |
117 | # generate 100 samples
118 |
119 | # append itself
120 | tgt_samples.append(tgt_sent)
121 | tgt_samples_distort_rates.append(0)
122 |
123 | for sid in range(args.sample_size - 1):
124 | n = np.random.randint(1, min(tgt_len, args.max_ngram_size + 1)) # we do not replace the last token: it must be a period!
125 |
126 | idx = np.random.randint(tgt_len - n)
127 | ngram = tgt_sent[idx: idx+n]
128 | new_ngram = get_new_ngram(ngram, n, tgt_vocab)
129 |
130 | sampled_tgt_sent = list(tgt_sent)
131 | sampled_tgt_sent[idx: idx+n] = new_ngram
132 |
133 | # compute the probability of this sample
134 | # prob = 1. / args.max_ngram_size * 1. / (tgt_len - 1 + n) * 1 / (len(tgt_vocab) ** n)
135 |
136 | tgt_samples.append(sampled_tgt_sent)
137 | tgt_samples_distort_rates.append(n)
138 |
139 | # compute bleu scores or edit distances and rank the samples by bleu scores
140 | rewards = []
141 | for tgt_sample, tgt_sample_distort_rate in zip(tgt_samples, tgt_samples_distort_rates):
142 | if args.reward == 'bleu':
143 | reward = sentence_bleu([tgt_sent], tgt_sample, smoothing_function=sm_func)
144 | elif args.reward == 'rouge':
145 | rouge = Rouge()
146 | scores = rouge.get_scores(hyps=[' '.join(tgt_sample).decode('utf-8')], refs=[' '.join(tgt_sent).decode('utf-8')], avg=True)
147 | reward = sum([value['f'] for key, value in scores.items()])
148 | else:
149 | reward = -tgt_sample_distort_rate
150 |
151 | rewards.append(reward)
152 |
153 | tgt_ranks = sorted(range(len(tgt_samples)), key=lambda i: rewards[i], reverse=True)
154 | # convert list of tokens into a string
155 | tgt_samples = [' '.join(tgt_sample) for tgt_sample in tgt_samples]
156 |
157 | print('*' * 50, file=f_out)
158 | print('source: ' + src_sent, file=f_out)
159 | print('%d samples' % len(tgt_samples), file=f_out)
160 | for i in tgt_ranks:
161 | print('%s ||| %f' % (tgt_samples[i], rewards[i]), file=f_out)
162 | print('*' * 50, file=f_out)
163 |
164 | f_out.close()
165 |
166 |
167 | def sample_ngram_adapt(args):
168 | src_sents = read_corpus(args.src, 'src')
169 | tgt_sents = read_corpus(args.tgt, 'src')  # do not read in <s> and </s>
170 | f_out = open(args.output, 'w')
171 |
172 | vocab = torch.load(args.vocab)
173 | tgt_vocab = vocab.tgt
174 |
175 | max_len = max([len(tgt_sent) for tgt_sent in tgt_sents]) + 1
176 |
177 | for src_sent, tgt_sent in zip(src_sents, tgt_sents):
178 | src_sent = ' '.join(src_sent)
179 |
180 | tgt_len = len(tgt_sent)
181 | tgt_samples = []
182 |
183 | # generate 100 samples
184 |
185 | # append itself
186 | tgt_samples.append(tgt_sent)
187 |
188 | for sid in range(args.sample_size - 1):
189 | max_n = min(tgt_len - 1, 4)
190 | bias_n = int(max_n * tgt_len / max_len) + 1
191 | assert 1 <= bias_n <= 4, 'bias_n={}, not in [1,4], max_n={}, tgt_len={}, max_len={}'.format(bias_n, max_n, tgt_len, max_len)
192 |
193 | p = [1.0/(max_n + 5)] * max_n
194 | p[bias_n - 1] = 1 - p[0] * (max_n - 1)
195 | assert abs(sum(p) - 1) < 1e-10, 'sum(p) != 1'
196 |
197 | n = np.random.choice(np.arange(1, int(max_n + 1)), p=p) # we do not replace the last token: it must be a period!
198 | assert n < tgt_len, 'n={}, tgt_len={}'.format(n, tgt_len)
199 |
200 | idx = np.random.randint(tgt_len - n)
201 | ngram = tgt_sent[idx: idx+n]
202 | new_ngram = get_new_ngram(ngram, n, tgt_vocab)
203 |
204 | sampled_tgt_sent = list(tgt_sent)
205 | sampled_tgt_sent[idx: idx+n] = new_ngram
206 |
207 | tgt_samples.append(sampled_tgt_sent)
208 |
209 | # compute bleu scores and rank the samples by bleu scores
210 | bleu_scores = []
211 | for tgt_sample in tgt_samples:
212 | bleu_score = sentence_bleu([tgt_sent], tgt_sample)
213 | bleu_scores.append(bleu_score)
214 |
215 | tgt_ranks = sorted(range(len(tgt_samples)), key=lambda i: bleu_scores[i], reverse=True)
216 | # convert list of tokens into a string
217 | tgt_samples = [' '.join(tgt_sample) for tgt_sample in tgt_samples]
218 |
219 | print('*' * 50, file=f_out)
220 | print('source: ' + src_sent, file=f_out)
221 | print('%d samples' % len(tgt_samples), file=f_out)
222 | for i in tgt_ranks:
223 | print('%s ||| %f' % (tgt_samples[i], bleu_scores[i]), file=f_out)
224 | print('*' * 50, file=f_out)
225 |
226 | f_out.close()
227 |
228 |
229 | def sample_from_hamming_distance_payoff_distribution(args):
230 | src_sents = read_corpus(args.src, 'src')
231 | tgt_sents = read_corpus(args.tgt, 'src')  # do not read in <s> and </s>
232 | f_out = open(args.output, 'w')
233 |
234 | vocab = torch.load(args.vocab)
235 | tgt_vocab = vocab.tgt
236 |
237 | payoff_prob, Z_qs = generate_hamming_distance_payoff_distribution(max(len(sent) for sent in tgt_sents),
238 | vocab_size=len(vocab.tgt),
239 | tau=args.temp)
240 |
241 | for src_sent, tgt_sent in zip(src_sents, tgt_sents):
242 | tgt_samples = [] # make sure the ground truth y* is in the samples
243 | tgt_sent_len = len(tgt_sent) - 3  # remove <s>, </s> and the ending period
244 | tgt_ref_tokens = tgt_sent[1:-1]
245 | bleu_scores = []
246 |
247 | # sample an edit distances
248 | e_samples = np.random.choice(range(tgt_sent_len + 1), p=payoff_prob[tgt_sent_len], size=args.sample_size,
249 | replace=True)
250 |
251 | for i, e in enumerate(e_samples):
252 | if e > 0:
253 | # sample a new tgt_sent $y$
254 | old_word_pos = np.random.choice(range(1, tgt_sent_len + 1), size=e, replace=False)
255 | new_words = [vocab.tgt.id2word[wid] for wid in np.random.randint(3, len(vocab.tgt), size=e)]
256 | new_tgt_sent = list(tgt_sent)
257 | for pos, word in zip(old_word_pos, new_words):
258 | new_tgt_sent[pos] = word
259 |
260 | bleu_score = sentence_bleu([tgt_ref_tokens], new_tgt_sent[1:-1])
261 | bleu_scores.append(bleu_score)
262 | else:
263 | new_tgt_sent = list(tgt_sent)
264 | bleu_scores.append(1.)
265 |
266 | # print('y: %s' % ' '.join(new_tgt_sent))
267 |             tgt_samples.append(new_tgt_sent)  # note: samples are collected but never written to f_out; this helper is not reachable from any --mode below
268 |
269 |
270 | def generate_hamming_distance_payoff_distribution(max_sent_len, vocab_size, tau=1.):
271 | """compute the q distribution for Hamming Distance (substitution only) as in the RAML paper"""
272 | probs = dict()
273 | Z_qs = dict()
274 | for sent_len in range(1, max_sent_len + 1):
275 | counts = [1.] # e = 0, count = 1
276 | for e in range(1, sent_len + 1):
277 | # apply the rescaling trick as in https://gist.github.com/norouzi/8c4d244922fa052fa8ec18d8af52d366
278 | count = comb(sent_len, e) * math.exp(-e / tau) * ((vocab_size - 1) ** (e - e / tau))
279 | counts.append(count)
280 |
281 | Z_qs[sent_len] = Z_q = sum(counts)
282 | prob = [count / Z_q for count in counts]
283 | probs[sent_len] = prob
284 |
285 | # print('sent_len=%d, %s' % (sent_len, prob))
286 |
287 | return probs, Z_qs
288 |
289 |
290 | if __name__ == '__main__':
291 | parser = argparse.ArgumentParser()
292 | parser.add_argument('--mode', choices=['sample_from_model', 'sample_ngram_adapt', 'sample_ngram'], required=True)
293 | parser.add_argument('--vocab', type=str)
294 | parser.add_argument('--src', type=str)
295 | parser.add_argument('--tgt', type=str)
296 | parser.add_argument('--parallel_data', type=str)
297 | parser.add_argument('--sample_file', type=str)
298 | parser.add_argument('--output', type=str, required=True)
299 | parser.add_argument('--sample_size', type=int, default=100)
300 | parser.add_argument('--reward', choices=['bleu', 'edit_dist', 'rouge'], default='bleu')
301 | parser.add_argument('--max_ngram_size', type=int, default=4)
302 | parser.add_argument('--temp', type=float, default=0.5)
303 | parser.add_argument('--smooth_bleu', action='store_true', default=False)
304 |
305 | args = parser.parse_args()
306 |
307 | if args.mode == 'sample_ngram':
308 | sample_ngram(args)
309 | elif args.mode == 'sample_from_model':
310 | sample_from_model(args)
311 | elif args.mode == 'sample_ngram_adapt':
312 | sample_ngram_adapt(args)
313 |
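314 | # A minimal sanity-check sketch (hypothetical helper; it is not wired to any
315 | # --mode above). By construction, each distribution returned by
316 | # generate_hamming_distance_payoff_distribution satisfies
317 | #     q(e | L) proportional to comb(L, e) * exp(-e / tau) * (V - 1)**(e - e / tau),
318 | # normalized over e = 0..L, so every probs[L] sums to 1.
319 | def _demo_payoff_distribution(max_sent_len=5, vocab_size=10, tau=0.5):
320 |     probs, Z_qs = generate_hamming_distance_payoff_distribution(
321 |         max_sent_len, vocab_size=vocab_size, tau=tau)
322 |     for sent_len, prob in probs.items():
323 |         assert abs(sum(prob) - 1.0) < 1e-10  # each q(. | sent_len) is normalized
324 |     # a smaller tau concentrates mass on e = 0, i.e. on the reference itself
325 |     return probs, Z_qs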
--------------------------------------------------------------------------------
/nmt/utils/iterator_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Tests for iterator_utils.py"""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import tensorflow as tf
23 |
24 | from tensorflow.python.ops import lookup_ops
25 |
26 | from ..utils import iterator_utils
27 |
28 |
29 | class IteratorUtilsTest(tf.test.TestCase):
30 |
31 | def testGetIterator(self):
32 | tf.set_random_seed(1)
33 | tgt_vocab_table = src_vocab_table = lookup_ops.index_table_from_tensor(
34 | tf.constant(["a", "b", "c", "eos", "sos"]))
35 | src_dataset = tf.data.Dataset.from_tensor_slices(
36 | tf.constant(["f e a g", "c c a", "d", "c a"]))
37 | tgt_dataset = tf.data.Dataset.from_tensor_slices(
38 | tf.constant(["c c", "a b", "", "b c"]))
39 | hparams = tf.contrib.training.HParams(
40 | random_seed=3,
41 | num_buckets=5,
42 | eos="eos",
43 | sos="sos")
44 | batch_size = 2
45 | src_max_len = 3
46 | iterator = iterator_utils.get_iterator(
47 | src_dataset=src_dataset,
48 | tgt_dataset=tgt_dataset,
49 | src_vocab_table=src_vocab_table,
50 | tgt_vocab_table=tgt_vocab_table,
51 | batch_size=batch_size,
52 | sos=hparams.sos,
53 | eos=hparams.eos,
54 | random_seed=hparams.random_seed,
55 | num_buckets=hparams.num_buckets,
56 | src_max_len=src_max_len,
57 | reshuffle_each_iteration=False)
58 | table_initializer = tf.tables_initializer()
59 | source = iterator.source
60 | target_input = iterator.target_input
61 | target_output = iterator.target_output
62 | src_seq_len = iterator.source_sequence_length
63 | tgt_seq_len = iterator.target_sequence_length
64 | self.assertEqual([None, None], source.shape.as_list())
65 | self.assertEqual([None, None], target_input.shape.as_list())
66 | self.assertEqual([None, None], target_output.shape.as_list())
67 | self.assertEqual([None], src_seq_len.shape.as_list())
68 | self.assertEqual([None], tgt_seq_len.shape.as_list())
69 | with self.test_session() as sess:
70 | sess.run(table_initializer)
71 | sess.run(iterator.initializer)
72 |
73 | (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = (
74 | sess.run((source, src_seq_len, target_input, target_output,
75 | tgt_seq_len)))
76 | self.assertAllEqual(
77 | [[-1, -1, 0], # "f" == unknown, "e" == unknown, a
78 | [2, 0, 3]], # c a eos -- eos is padding
79 | source_v)
80 | self.assertAllEqual([3, 2], src_len_v)
81 | self.assertAllEqual(
82 | [[4, 2, 2], # sos c c
83 | [4, 1, 2]], # sos b c
84 | target_input_v)
85 | self.assertAllEqual(
86 | [[2, 2, 3], # c c eos
87 | [1, 2, 3]], # b c eos
88 | target_output_v)
89 | self.assertAllEqual([3, 3], tgt_len_v)
90 |
91 | (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = (
92 | sess.run((source, src_seq_len, target_input, target_output,
93 | tgt_seq_len)))
94 | self.assertAllEqual(
95 | [[2, 2, 0]], # c c a
96 | source_v)
97 | self.assertAllEqual([3], src_len_v)
98 | self.assertAllEqual(
99 | [[4, 0, 1]], # sos a b
100 | target_input_v)
101 | self.assertAllEqual(
102 | [[0, 1, 3]], # a b eos
103 | target_output_v)
104 | self.assertAllEqual([3], tgt_len_v)
105 |
106 | with self.assertRaisesOpError("End of sequence"):
107 | sess.run(source)
108 |
109 | def testGetIteratorWithShard(self):
110 | tf.set_random_seed(1)
111 | tgt_vocab_table = src_vocab_table = lookup_ops.index_table_from_tensor(
112 | tf.constant(["a", "b", "c", "eos", "sos"]))
113 | src_dataset = tf.data.Dataset.from_tensor_slices(
114 | tf.constant(["c c a", "f e a g", "d", "c a"]))
115 | tgt_dataset = tf.data.Dataset.from_tensor_slices(
116 | tf.constant(["a b", "c c", "", "b c"]))
117 | hparams = tf.contrib.training.HParams(
118 | random_seed=3,
119 | num_buckets=5,
120 | eos="eos",
121 | sos="sos")
122 | batch_size = 2
123 | src_max_len = 3
124 | iterator = iterator_utils.get_iterator(
125 | src_dataset=src_dataset,
126 | tgt_dataset=tgt_dataset,
127 | src_vocab_table=src_vocab_table,
128 | tgt_vocab_table=tgt_vocab_table,
129 | batch_size=batch_size,
130 | sos=hparams.sos,
131 | eos=hparams.eos,
132 | random_seed=hparams.random_seed,
133 | num_buckets=hparams.num_buckets,
134 | src_max_len=src_max_len,
135 | num_shards=2,
136 | shard_index=1,
137 | reshuffle_each_iteration=False)
138 | table_initializer = tf.tables_initializer()
139 | source = iterator.source
140 | target_input = iterator.target_input
141 | target_output = iterator.target_output
142 | src_seq_len = iterator.source_sequence_length
143 | tgt_seq_len = iterator.target_sequence_length
144 | self.assertEqual([None, None], source.shape.as_list())
145 | self.assertEqual([None, None], target_input.shape.as_list())
146 | self.assertEqual([None, None], target_output.shape.as_list())
147 | self.assertEqual([None], src_seq_len.shape.as_list())
148 | self.assertEqual([None], tgt_seq_len.shape.as_list())
149 | with self.test_session() as sess:
150 | sess.run(table_initializer)
151 | sess.run(iterator.initializer)
152 |
153 | (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = (
154 | sess.run((source, src_seq_len, target_input, target_output,
155 | tgt_seq_len)))
156 | self.assertAllEqual(
157 | [[-1, -1, 0], # "f" == unknown, "e" == unknown, a
158 | [2, 0, 3]], # c a eos -- eos is padding
159 | source_v)
160 | self.assertAllEqual([3, 2], src_len_v)
161 | self.assertAllEqual(
162 | [[4, 2, 2], # sos c c
163 | [4, 1, 2]], # sos b c
164 | target_input_v)
165 | self.assertAllEqual(
166 | [[2, 2, 3], # c c eos
167 | [1, 2, 3]], # b c eos
168 | target_output_v)
169 | self.assertAllEqual([3, 3], tgt_len_v)
170 |
171 | with self.assertRaisesOpError("End of sequence"):
172 | sess.run(source)
173 |
174 | def testGetIteratorWithSkipCount(self):
175 | tf.set_random_seed(1)
176 | tgt_vocab_table = src_vocab_table = lookup_ops.index_table_from_tensor(
177 | tf.constant(["a", "b", "c", "eos", "sos"]))
178 | src_dataset = tf.data.Dataset.from_tensor_slices(
179 | tf.constant(["c a", "c c a", "d", "f e a g"]))
180 | tgt_dataset = tf.data.Dataset.from_tensor_slices(
181 | tf.constant(["b c", "a b", "", "c c"]))
182 | hparams = tf.contrib.training.HParams(
183 | random_seed=3,
184 | num_buckets=5,
185 | eos="eos",
186 | sos="sos")
187 | batch_size = 2
188 | src_max_len = 3
189 | skip_count = tf.placeholder(shape=(), dtype=tf.int64)
190 | iterator = iterator_utils.get_iterator(
191 | src_dataset=src_dataset,
192 | tgt_dataset=tgt_dataset,
193 | src_vocab_table=src_vocab_table,
194 | tgt_vocab_table=tgt_vocab_table,
195 | batch_size=batch_size,
196 | sos=hparams.sos,
197 | eos=hparams.eos,
198 | random_seed=hparams.random_seed,
199 | num_buckets=hparams.num_buckets,
200 | src_max_len=src_max_len,
201 | skip_count=skip_count,
202 | reshuffle_each_iteration=False)
203 | table_initializer = tf.tables_initializer()
204 | source = iterator.source
205 | target_input = iterator.target_input
206 | target_output = iterator.target_output
207 | src_seq_len = iterator.source_sequence_length
208 | tgt_seq_len = iterator.target_sequence_length
209 | self.assertEqual([None, None], source.shape.as_list())
210 | self.assertEqual([None, None], target_input.shape.as_list())
211 | self.assertEqual([None, None], target_output.shape.as_list())
212 | self.assertEqual([None], src_seq_len.shape.as_list())
213 | self.assertEqual([None], tgt_seq_len.shape.as_list())
214 | with self.test_session() as sess:
215 | sess.run(table_initializer)
216 | sess.run(iterator.initializer, feed_dict={skip_count: 3})
217 |
218 | (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = (
219 | sess.run((source, src_seq_len, target_input, target_output,
220 | tgt_seq_len)))
221 | self.assertAllEqual(
222 | [[-1, -1, 0]], # "f" == unknown, "e" == unknown, a
223 | source_v)
224 | self.assertAllEqual([3], src_len_v)
225 | self.assertAllEqual(
226 | [[4, 2, 2]], # sos c c
227 | target_input_v)
228 | self.assertAllEqual(
229 | [[2, 2, 3]], # c c eos
230 | target_output_v)
231 | self.assertAllEqual([3], tgt_len_v)
232 |
233 | with self.assertRaisesOpError("End of sequence"):
234 | sess.run(source)
235 |
236 | # Re-init iterator with skip_count=0.
237 | sess.run(iterator.initializer, feed_dict={skip_count: 0})
238 |
239 | (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = (
240 | sess.run((source, src_seq_len, target_input, target_output,
241 | tgt_seq_len)))
242 | self.assertAllEqual(
243 | [[2, 0, 3], # c a eos -- eos is padding
244 | [-1, -1, 0]], # "f" == unknown, "e" == unknown, a
245 | source_v)
246 | self.assertAllEqual([2, 3], src_len_v)
247 | self.assertAllEqual(
248 | [[4, 1, 2], # sos b c
249 | [4, 2, 2]], # sos c c
250 | target_input_v)
251 | self.assertAllEqual(
252 | [[1, 2, 3], # b c eos
253 | [2, 2, 3]], # c c eos
254 | target_output_v)
255 | self.assertAllEqual([3, 3], tgt_len_v)
256 |
257 | (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = (
258 | sess.run((source, src_seq_len, target_input, target_output,
259 | tgt_seq_len)))
260 | self.assertAllEqual(
261 | [[2, 2, 0]], # c c a
262 | source_v)
263 | self.assertAllEqual([3], src_len_v)
264 | self.assertAllEqual(
265 | [[4, 0, 1]], # sos a b
266 | target_input_v)
267 | self.assertAllEqual(
268 | [[0, 1, 3]], # a b eos
269 | target_output_v)
270 | self.assertAllEqual([3], tgt_len_v)
271 |
272 | with self.assertRaisesOpError("End of sequence"):
273 | sess.run(source)
274 |
275 |
276 | def testGetInferIterator(self):
277 | src_vocab_table = lookup_ops.index_table_from_tensor(
278 | tf.constant(["a", "b", "c", "eos", "sos"]))
279 | src_dataset = tf.data.Dataset.from_tensor_slices(
280 | tf.constant(["c c a", "c a", "d", "f e a g"]))
281 | hparams = tf.contrib.training.HParams(
282 | random_seed=3,
283 | eos="eos",
284 | sos="sos")
285 | batch_size = 2
286 | src_max_len = 3
287 | iterator = iterator_utils.get_infer_iterator(
288 | src_dataset=src_dataset,
289 | src_vocab_table=src_vocab_table,
290 | batch_size=batch_size,
291 | eos=hparams.eos,
292 | src_max_len=src_max_len)
293 | table_initializer = tf.tables_initializer()
294 | source = iterator.source
295 | seq_len = iterator.source_sequence_length
296 | self.assertEqual([None, None], source.shape.as_list())
297 | self.assertEqual([None], seq_len.shape.as_list())
298 | with self.test_session() as sess:
299 | sess.run(table_initializer)
300 | sess.run(iterator.initializer)
301 |
302 | (source_v, seq_len_v) = sess.run((source, seq_len))
303 | self.assertAllEqual(
304 | [[2, 2, 0], # c c a
305 | [2, 0, 3]], # c a eos
306 | source_v)
307 | self.assertAllEqual([3, 2], seq_len_v)
308 |
309 | (source_v, seq_len_v) = sess.run((source, seq_len))
310 | self.assertAllEqual(
311 | [[-1, 3, 3], # "d" == unknown, eos eos
312 | [-1, -1, 0]], # "f" == unknown, "e" == unknown, a
313 | source_v)
314 | self.assertAllEqual([1, 3], seq_len_v)
315 |
316 | with self.assertRaisesOpError("End of sequence"):
317 | sess.run((source, seq_len))
318 |
319 |
320 | if __name__ == "__main__":
321 | tf.test.main()
322 |
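323 | # Note on the ids asserted in the tests above: lookup_ops.index_table_from_tensor
324 | # assigns ids by position in the vocab tensor (a=0, b=1, c=2, eos=3, sos=4) and
325 | # maps out-of-vocabulary tokens such as "d", "e", "f", "g" to its default value -1.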
--------------------------------------------------------------------------------
/nmt/model_helper.py:
--------------------------------------------------------------------------------
1 | """Utility functions for building models."""
2 | from __future__ import print_function
3 |
4 | import collections
5 | import six
6 | import os
7 | import time
8 |
9 | import numpy as np
10 | import tensorflow as tf
11 |
12 | from tensorflow.python.ops import lookup_ops
13 |
14 | from .utils import iterator_utils
15 | from .utils import misc_utils as utils
16 | from .utils import vocab_utils
17 |
18 |
19 | __all__ = [
20 | "get_initializer", "get_device_str", "create_train_model",
21 | "create_eval_model", "create_infer_model",
22 | "create_emb_for_encoder_and_decoder", "create_rnn_cell", "gradient_clip",
23 | "create_or_load_model", "load_model", "avg_checkpoints",
24 | "compute_perplexity"
25 | ]
26 |
27 | # If a vocab size is greater than this value, put the embedding on cpu instead
28 | VOCAB_SIZE_THRESHOLD_CPU = 50000
29 |
30 |
31 | def get_initializer(init_op, seed=None, init_weight=None):
32 | """Create an initializer. init_weight is only for uniform."""
33 | if init_op == "uniform":
34 | assert init_weight
35 | return tf.random_uniform_initializer(
36 | -init_weight, init_weight, seed=seed)
37 | elif init_op == "glorot_normal":
38 | return tf.keras.initializers.glorot_normal(
39 | seed=seed)
40 | elif init_op == "glorot_uniform":
41 | return tf.keras.initializers.glorot_uniform(
42 | seed=seed)
43 | else:
44 | raise ValueError("Unknown init_op %s" % init_op)
45 |
46 |
47 | def get_device_str(device_id, num_gpus):
48 | """Return a device string for multi-GPU setup."""
49 | if num_gpus == 0:
50 | return "/cpu:0"
51 | device_str_output = "/gpu:%d" % (device_id % num_gpus)
52 | return device_str_output
53 |
54 |
55 | class ExtraArgs(collections.namedtuple(
56 | "ExtraArgs", ("single_cell_fn", "model_device_fn",
57 | "attention_mechanism_fn"))):
58 | pass
59 |
60 |
61 | class TrainModel(
62 | collections.namedtuple("TrainModel", ("graph", "model", "iterator",
63 | "skip_count_placeholder"))):
64 | pass
65 |
66 |
67 | def create_train_model(
68 | model_creator, hparams, scope=None, num_workers=1, jobid=0,
69 | extra_args=None):
70 | """Create train graph, model, and iterator."""
71 | src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
72 | tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
73 | src_vocab_file = hparams.src_vocab_file
74 | tgt_vocab_file = hparams.tgt_vocab_file
75 |
76 | graph = tf.Graph()
77 |
78 | with graph.as_default(), tf.container(scope or "train"):
79 | src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
80 | src_vocab_file, tgt_vocab_file, hparams.share_vocab)
81 |
82 | src_dataset = tf.data.TextLineDataset(src_file)
83 | tgt_dataset = tf.data.TextLineDataset(tgt_file)
84 | skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)
85 |
86 | iterator = iterator_utils.get_iterator(
87 | src_dataset,
88 | tgt_dataset,
89 | src_vocab_table,
90 | tgt_vocab_table,
91 | batch_size=hparams.batch_size,
92 | sos=hparams.sos,
93 | eos=hparams.eos,
94 | random_seed=hparams.random_seed,
95 | num_buckets=hparams.num_buckets,
96 | src_max_len=hparams.src_max_len,
97 | tgt_max_len=hparams.tgt_max_len,
98 | skip_count=skip_count_placeholder,
99 | num_shards=num_workers,
100 | shard_index=jobid)
101 |
102 | # Note: One can set model_device_fn to
103 | # `tf.train.replica_device_setter(ps_tasks)` for distributed training.
104 | model_device_fn = None
105 | if extra_args: model_device_fn = extra_args.model_device_fn
106 | with tf.device(model_device_fn):
107 | model = model_creator(
108 | hparams,
109 | iterator=iterator,
110 | mode=tf.contrib.learn.ModeKeys.TRAIN,
111 | source_vocab_table=src_vocab_table,
112 | target_vocab_table=tgt_vocab_table,
113 | scope=scope,
114 | extra_args=extra_args)
115 |
116 | return TrainModel(
117 | graph=graph,
118 | model=model,
119 | iterator=iterator,
120 | skip_count_placeholder=skip_count_placeholder)
121 |
122 |
123 | class EvalModel(
124 | collections.namedtuple("EvalModel",
125 | ("graph", "model", "src_file_placeholder",
126 | "tgt_file_placeholder", "iterator"))):
127 | pass
128 |
129 |
130 | def create_eval_model(model_creator, hparams, scope=None, extra_args=None):
131 | """Create train graph, model, src/tgt file holders, and iterator."""
132 | src_vocab_file = hparams.src_vocab_file
133 | tgt_vocab_file = hparams.tgt_vocab_file
134 | graph = tf.Graph()
135 |
136 | with graph.as_default(), tf.container(scope or "eval"):
137 | src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
138 | src_vocab_file, tgt_vocab_file, hparams.share_vocab)
139 | src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
140 | tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
141 | src_dataset = tf.data.TextLineDataset(src_file_placeholder)
142 | tgt_dataset = tf.data.TextLineDataset(tgt_file_placeholder)
143 | iterator = iterator_utils.get_iterator(
144 | src_dataset,
145 | tgt_dataset,
146 | src_vocab_table,
147 | tgt_vocab_table,
148 | hparams.batch_size,
149 | sos=hparams.sos,
150 | eos=hparams.eos,
151 | random_seed=hparams.random_seed,
152 | num_buckets=hparams.num_buckets,
153 | src_max_len=hparams.src_max_len_infer,
154 | tgt_max_len=hparams.tgt_max_len_infer)
155 | model = model_creator(
156 | hparams,
157 | iterator=iterator,
158 | mode=tf.contrib.learn.ModeKeys.EVAL,
159 | source_vocab_table=src_vocab_table,
160 | target_vocab_table=tgt_vocab_table,
161 | scope=scope,
162 | extra_args=extra_args)
163 | return EvalModel(
164 | graph=graph,
165 | model=model,
166 | src_file_placeholder=src_file_placeholder,
167 | tgt_file_placeholder=tgt_file_placeholder,
168 | iterator=iterator)
169 |
170 |
171 | class InferModel(
172 | collections.namedtuple("InferModel",
173 | ("graph", "model", "src_placeholder",
174 | "batch_size_placeholder", "iterator"))):
175 | pass
176 |
177 |
178 | def create_infer_model(model_creator, hparams, scope=None, extra_args=None):
179 | """Create inference model."""
180 | graph = tf.Graph()
181 | src_vocab_file = hparams.src_vocab_file
182 | tgt_vocab_file = hparams.tgt_vocab_file
183 |
184 | with graph.as_default(), tf.container(scope or "infer"):
185 | src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
186 | src_vocab_file, tgt_vocab_file, hparams.share_vocab)
187 | reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
188 | tgt_vocab_file, default_value=vocab_utils.UNK)
189 |
190 | src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
191 | batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)
192 |
193 | src_dataset = tf.data.Dataset.from_tensor_slices(
194 | src_placeholder)
195 | iterator = iterator_utils.get_infer_iterator(
196 | src_dataset,
197 | src_vocab_table,
198 | batch_size=batch_size_placeholder,
199 | eos=hparams.eos,
200 | src_max_len=hparams.src_max_len_infer)
201 | model = model_creator(
202 | hparams,
203 | iterator=iterator,
204 | mode=tf.contrib.learn.ModeKeys.INFER,
205 | source_vocab_table=src_vocab_table,
206 | target_vocab_table=tgt_vocab_table,
207 | reverse_target_vocab_table=reverse_tgt_vocab_table,
208 | scope=scope,
209 | extra_args=extra_args)
210 | return InferModel(
211 | graph=graph,
212 | model=model,
213 | src_placeholder=src_placeholder,
214 | batch_size_placeholder=batch_size_placeholder,
215 | iterator=iterator)
216 |
217 |
218 | def _get_embed_device(vocab_size):
219 | """Decide on which device to place an embed matrix given its vocab size."""
220 | if vocab_size > VOCAB_SIZE_THRESHOLD_CPU:
221 | return "/cpu:0"
222 | else:
223 | return "/gpu:0"
224 |
225 |
226 | def _create_pretrained_emb_from_txt(
227 | vocab_file, embed_file, num_trainable_tokens=3, dtype=tf.float32,
228 | scope=None):
229 | """Load pretrain embeding from embed_file, and return an embedding matrix.
230 |
231 | Args:
232 | embed_file: Path to a Glove formated embedding txt file.
233 | num_trainable_tokens: Make the first n tokens in the vocab file as trainable
234 | variables. Default is 3, which is "", "" and "".
235 | """
236 | vocab, _ = vocab_utils.load_vocab(vocab_file)
237 | trainable_tokens = vocab[:num_trainable_tokens]
238 |
239 | utils.print_out("# Using pretrained embedding: %s." % embed_file)
240 | utils.print_out(" with trainable tokens: ")
241 |
242 | emb_dict, emb_size = vocab_utils.load_embed_txt(embed_file)
243 | for token in trainable_tokens:
244 | utils.print_out(" %s" % token)
245 | if token not in emb_dict:
246 | emb_dict[token] = [0.0] * emb_size
247 |
248 | emb_mat = np.array(
249 | [emb_dict[token] for token in vocab], dtype=dtype.as_numpy_dtype())
250 | emb_mat = tf.constant(emb_mat)
251 | emb_mat_const = tf.slice(emb_mat, [num_trainable_tokens, 0], [-1, -1])
252 | with tf.variable_scope(scope or "pretrain_embeddings", dtype=dtype) as scope:
253 | with tf.device(_get_embed_device(num_trainable_tokens)):
254 | emb_mat_var = tf.get_variable(
255 | "emb_mat_var", [num_trainable_tokens, emb_size])
256 | return tf.concat([emb_mat_var, emb_mat_const], 0)
257 |
258 |
259 | def _create_or_load_embed(embed_name, vocab_file, embed_file,
260 | vocab_size, embed_size, dtype):
261 | """Create a new or load an existing embedding matrix."""
262 | if vocab_file and embed_file:
263 | embedding = _create_pretrained_emb_from_txt(vocab_file, embed_file)
264 | else:
265 | with tf.device(_get_embed_device(vocab_size)):
266 | embedding = tf.get_variable(
267 | embed_name, [vocab_size, embed_size], dtype)
268 | return embedding
269 |
270 |
271 | def create_emb_for_encoder_and_decoder(share_vocab,
272 | src_vocab_size,
273 | tgt_vocab_size,
274 | src_embed_size,
275 | tgt_embed_size,
276 | dtype=tf.float32,
277 | num_partitions=0,
278 | src_vocab_file=None,
279 | tgt_vocab_file=None,
280 | src_embed_file=None,
281 | tgt_embed_file=None,
282 | scope=None):
283 | """Create embedding matrix for both encoder and decoder.
284 |
285 | Args:
286 | share_vocab: A boolean. Whether to share embedding matrix for both
287 | encoder and decoder.
288 | src_vocab_size: An integer. The source vocab size.
289 | tgt_vocab_size: An integer. The target vocab size.
290 | src_embed_size: An integer. The embedding dimension for the encoder's
291 | embedding.
292 | tgt_embed_size: An integer. The embedding dimension for the decoder's
293 | embedding.
294 | dtype: dtype of the embedding matrix. Default to float32.
295 | num_partitions: number of partitions used for the embedding vars.
296 | scope: VariableScope for the created subgraph. Default to "embedding".
297 |
298 | Returns:
299 | embedding_encoder: Encoder's embedding matrix.
300 | embedding_decoder: Decoder's embedding matrix.
301 |
302 | Raises:
303 |     ValueError: if share_vocab is set but source and target have different
304 |       vocab sizes.
305 | """
306 |
307 | if num_partitions <= 1:
308 | partitioner = None
309 | else:
310 |     # Note: num_partitions > 1 is required for distributed training because
311 |     # embedding_lookup tries to colocate a single-partition embedding variable
312 |     # with the lookup ops, which may cause embedding variables to be placed
313 |     # on worker jobs.
314 | partitioner = tf.fixed_size_partitioner(num_partitions)
315 |
316 | if (src_embed_file or tgt_embed_file) and partitioner:
317 | raise ValueError(
318 | "Can't set num_partitions > 1 when using pretrained embedding")
319 |
320 | with tf.variable_scope(
321 | scope or "embeddings", dtype=dtype, partitioner=partitioner) as scope:
322 | # Share embedding
323 | if share_vocab:
324 | if src_vocab_size != tgt_vocab_size:
325 | raise ValueError("Share embedding but different src/tgt vocab sizes"
326 | " %d vs. %d" % (src_vocab_size, tgt_vocab_size))
327 | assert src_embed_size == tgt_embed_size
328 | utils.print_out("# Use the same embedding for source and target")
329 | vocab_file = src_vocab_file or tgt_vocab_file
330 | embed_file = src_embed_file or tgt_embed_file
331 |
332 | embedding_encoder = _create_or_load_embed(
333 | "embedding_share", vocab_file, embed_file,
334 | src_vocab_size, src_embed_size, dtype)
335 | embedding_decoder = embedding_encoder
336 | else:
337 | with tf.variable_scope("encoder", partitioner=partitioner):
338 | embedding_encoder = _create_or_load_embed(
339 | "embedding_encoder", src_vocab_file, src_embed_file,
340 | src_vocab_size, src_embed_size, dtype)
341 |
342 | with tf.variable_scope("decoder", partitioner=partitioner):
343 | embedding_decoder = _create_or_load_embed(
344 | "embedding_decoder", tgt_vocab_file, tgt_embed_file,
345 | tgt_vocab_size, tgt_embed_size, dtype)
346 |
347 | return embedding_encoder, embedding_decoder
348 |
349 |
350 | def _single_cell(unit_type, num_units, forget_bias, dropout, mode,
351 | residual_connection=False, device_str=None, residual_fn=None):
352 | """Create an instance of a single RNN cell."""
353 | # dropout (= 1 - keep_prob) is set to 0 during eval and infer
354 | dropout = dropout if mode == tf.contrib.learn.ModeKeys.TRAIN else 0.0
355 |
356 | # Cell Type
357 | if unit_type == "lstm":
358 | utils.print_out(" LSTM, forget_bias=%g" % forget_bias, new_line=False)
359 | single_cell = tf.contrib.rnn.BasicLSTMCell(
360 | num_units,
361 | forget_bias=forget_bias)
362 | elif unit_type == "gru":
363 | utils.print_out(" GRU", new_line=False)
364 | single_cell = tf.contrib.rnn.GRUCell(num_units)
365 | elif unit_type == "layer_norm_lstm":
366 | utils.print_out(" Layer Normalized LSTM, forget_bias=%g" % forget_bias,
367 | new_line=False)
368 | single_cell = tf.contrib.rnn.LayerNormBasicLSTMCell(
369 | num_units,
370 | forget_bias=forget_bias,
371 | layer_norm=True)
372 | elif unit_type == "nas":
373 | utils.print_out(" NASCell", new_line=False)
374 | single_cell = tf.contrib.rnn.NASCell(num_units)
375 | else:
376 | raise ValueError("Unknown unit type %s!" % unit_type)
377 |
378 | # Dropout (= 1 - keep_prob)
379 | if dropout > 0.0:
380 | single_cell = tf.contrib.rnn.DropoutWrapper(
381 | cell=single_cell, input_keep_prob=(1.0 - dropout))
382 | utils.print_out(" %s, dropout=%g " %(type(single_cell).__name__, dropout),
383 | new_line=False)
384 |
385 | # Residual
386 | if residual_connection:
387 | single_cell = tf.contrib.rnn.ResidualWrapper(
388 | single_cell, residual_fn=residual_fn)
389 | utils.print_out(" %s" % type(single_cell).__name__, new_line=False)
390 |
391 | # Device Wrapper
392 | if device_str:
393 | single_cell = tf.contrib.rnn.DeviceWrapper(single_cell, device_str)
394 | utils.print_out(" %s, device=%s" %
395 | (type(single_cell).__name__, device_str), new_line=False)
396 |
397 | return single_cell
398 |
399 |
400 | def _cell_list(unit_type, num_units, num_layers, num_residual_layers,
401 | forget_bias, dropout, mode, num_gpus, base_gpu=0,
402 | single_cell_fn=None, residual_fn=None):
403 | """Create a list of RNN cells."""
404 | if not single_cell_fn:
405 | single_cell_fn = _single_cell
406 |
407 | # Multi-GPU
408 | cell_list = []
409 | for i in range(num_layers):
410 | utils.print_out(" cell %d" % i, new_line=False)
411 | single_cell = single_cell_fn(
412 | unit_type=unit_type,
413 | num_units=num_units,
414 | forget_bias=forget_bias,
415 | dropout=dropout,
416 | mode=mode,
417 | residual_connection=(i >= num_layers - num_residual_layers),
418 | device_str=get_device_str(i + base_gpu, num_gpus),
419 | residual_fn=residual_fn
420 | )
421 | utils.print_out("")
422 | cell_list.append(single_cell)
423 |
424 | return cell_list
425 |
426 |
427 | def create_rnn_cell(unit_type, num_units, num_layers, num_residual_layers,
428 | forget_bias, dropout, mode, num_gpus, base_gpu=0,
429 | single_cell_fn=None):
430 | """Create multi-layer RNN cell.
431 |
432 | Args:
433 | unit_type: string representing the unit type, i.e. "lstm".
434 | num_units: the depth of each unit.
435 | num_layers: number of cells.
436 | num_residual_layers: Number of residual layers from top to bottom. For
437 | example, if `num_layers=4` and `num_residual_layers=2`, the last 2 RNN
438 | cells in the returned list will be wrapped with `ResidualWrapper`.
439 | forget_bias: the initial forget bias of the RNNCell(s).
440 |     dropout: floating point value between 0.0 and 1.0: the probability of
441 |       dropout. This is ignored if `mode != TRAIN`.
442 |     mode: either tf.contrib.learn.ModeKeys.TRAIN/EVAL/INFER
443 | num_gpus: The number of gpus to use when performing round-robin
444 | placement of layers.
445 | base_gpu: The gpu device id to use for the first RNN cell in the
446 | returned list. The i-th RNN cell will use `(base_gpu + i) % num_gpus`
447 | as its device id.
448 | single_cell_fn: allow for adding customized cell.
449 | When not specified, we default to model_helper._single_cell
450 | Returns:
451 | An `RNNCell` instance.
452 | """
453 | cell_list = _cell_list(unit_type=unit_type,
454 | num_units=num_units,
455 | num_layers=num_layers,
456 | num_residual_layers=num_residual_layers,
457 | forget_bias=forget_bias,
458 | dropout=dropout,
459 | mode=mode,
460 | num_gpus=num_gpus,
461 | base_gpu=base_gpu,
462 | single_cell_fn=single_cell_fn)
463 |
464 | if len(cell_list) == 1: # Single layer.
465 | return cell_list[0]
466 | else: # Multi layers
467 | return tf.contrib.rnn.MultiRNNCell(cell_list)
468 |
469 |
470 | def gradient_clip(gradients, max_gradient_norm):
471 | """Clipping gradients of a model."""
472 | clipped_gradients, gradient_norm = tf.clip_by_global_norm(
473 | gradients, max_gradient_norm)
474 | gradient_norm_summary = [tf.summary.scalar("grad_norm", gradient_norm)]
475 | gradient_norm_summary.append(
476 | tf.summary.scalar("clipped_gradient", tf.global_norm(clipped_gradients)))
477 |
478 | return clipped_gradients, gradient_norm_summary, gradient_norm
479 |
480 |
481 | def load_model(model, ckpt, session, name):
482 | start_time = time.time()
483 | model.saver.restore(session, ckpt)
484 | session.run(tf.tables_initializer())
485 | utils.print_out(
486 | " loaded %s model parameters from %s, time %.2fs" %
487 | (name, ckpt, time.time() - start_time))
488 | return model
489 |
490 |
491 | def avg_checkpoints(model_dir, num_last_checkpoints, global_step,
492 | global_step_name):
493 | """Average the last N checkpoints in the model_dir."""
494 | checkpoint_state = tf.train.get_checkpoint_state(model_dir)
495 | if not checkpoint_state:
496 | utils.print_out("# No checkpoint file found in directory: %s" % model_dir)
497 | return None
498 |
499 | # Checkpoints are ordered from oldest to newest.
500 | checkpoints = (
501 | checkpoint_state.all_model_checkpoint_paths[-num_last_checkpoints:])
502 |
503 | if len(checkpoints) < num_last_checkpoints:
504 | utils.print_out(
505 | "# Skipping averaging checkpoints because not enough checkpoints is "
506 | "avaliable."
507 | )
508 | return None
509 |
510 | avg_model_dir = os.path.join(model_dir, "avg_checkpoints")
511 | if not tf.gfile.Exists(avg_model_dir):
512 | utils.print_out(
513 | "# Creating new directory %s for saving averaged checkpoints." %
514 | avg_model_dir)
515 | tf.gfile.MakeDirs(avg_model_dir)
516 |
517 | utils.print_out("# Reading and averaging variables in checkpoints:")
518 | var_list = tf.contrib.framework.list_variables(checkpoints[0])
519 | var_values, var_dtypes = {}, {}
520 | for (name, shape) in var_list:
521 | if name != global_step_name:
522 | var_values[name] = np.zeros(shape)
523 |
524 | for checkpoint in checkpoints:
525 | utils.print_out(" %s" % checkpoint)
526 | reader = tf.contrib.framework.load_checkpoint(checkpoint)
527 | for name in var_values:
528 | tensor = reader.get_tensor(name)
529 | var_dtypes[name] = tensor.dtype
530 | var_values[name] += tensor
531 |
532 | for name in var_values:
533 | var_values[name] /= len(checkpoints)
534 |
535 | # Build a graph with same variables in the checkpoints, and save the averaged
536 | # variables into the avg_model_dir.
537 | with tf.Graph().as_default():
538 | tf_vars = [
539 |         tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[v])
540 | for v in var_values
541 | ]
542 |
543 | placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
544 | assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
545 | global_step_var = tf.Variable(
546 | global_step, name=global_step_name, trainable=False)
547 |     saver = tf.train.Saver(tf.global_variables())
548 |
549 | with tf.Session() as sess:
550 |       sess.run(tf.global_variables_initializer())
551 | for p, assign_op, (name, value) in zip(placeholders, assign_ops,
552 | six.iteritems(var_values)):
553 | sess.run(assign_op, {p: value})
554 |
555 | # Use the built saver to save the averaged checkpoint. Only keep 1
556 | # checkpoint and the best checkpoint will be moved to avg_best_metric_dir.
557 | saver.save(
558 | sess,
559 | os.path.join(avg_model_dir, "translate.ckpt"))
560 |
561 | return avg_model_dir
562 |
563 |
564 | def create_or_load_model(model, model_dir, session, name):
565 | """Create translation model and initialize or load parameters in session."""
566 | latest_ckpt = tf.train.latest_checkpoint(model_dir)
567 | if latest_ckpt:
568 | model = load_model(model, latest_ckpt, session, name)
569 | else:
570 | start_time = time.time()
571 | session.run(tf.global_variables_initializer())
572 | session.run(tf.tables_initializer())
573 | utils.print_out(" created %s model with fresh parameters, time %.2fs" %
574 | (name, time.time() - start_time))
575 |
576 | global_step = model.global_step.eval(session=session)
577 | return model, global_step
578 |
579 |
580 | def compute_perplexity(model, sess, name):
581 | """Compute perplexity of the output of the model.
582 |
583 | Args:
584 |     model: model for computing perplexity.
585 | sess: tensorflow session to use.
586 | name: name of the batch.
587 |
588 | Returns:
589 | The perplexity of the eval outputs.
590 | """
591 | total_loss = 0
592 | total_predict_count = 0
593 | start_time = time.time()
594 |
595 | while True:
596 | try:
597 | loss, predict_count, batch_size = model.eval(sess)
598 | total_loss += loss * batch_size
599 | total_predict_count += predict_count
600 | except tf.errors.OutOfRangeError:
601 | break
602 |
603 | perplexity = utils.safe_exp(total_loss / total_predict_count)
604 | utils.print_time(" eval %s: perplexity %.2f" % (name, perplexity),
605 | start_time)
606 | return perplexity
607 |
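608 | # A minimal usage sketch (hypothetical values; nothing in this module calls it
609 | # this way): build a 4-layer LSTM cell whose top 2 layers are residual, with
610 | # layers placed round-robin across 2 GPUs:
611 | #
612 | #   cell = create_rnn_cell(
613 | #       unit_type="lstm", num_units=512, num_layers=4, num_residual_layers=2,
614 | #       forget_bias=1.0, dropout=0.2, mode=tf.contrib.learn.ModeKeys.TRAIN,
615 | #       num_gpus=2)
616 | #
617 | # Relatedly, compute_perplexity above returns exp(total_loss / total_predict_count),
618 | # i.e. the exponentiated average per-token cross-entropy over the eval set.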
--------------------------------------------------------------------------------