├── .gitignore ├── xmunmt ├── __init__.py ├── utils │ ├── utils.py │ ├── __init__.py │ ├── bleu.py │ ├── parallel.py │ ├── hooks.py │ └── search.py ├── data │ ├── __init__.py │ ├── vocab.py │ ├── record.py │ └── dataset.py ├── layers │ ├── __init__.py │ ├── rnn_cell.py │ ├── attention.py │ └── nn.py ├── interface │ ├── __init__.py │ └── model.py ├── models │ ├── __init__.py │ └── rnnsearch.py ├── scripts │ ├── shuffle_corpus.py │ ├── char_utils.py │ ├── build_vocab.py │ └── checkpoint_averaging.py └── bin │ ├── translator.py │ └── trainer.py ├── LICENSE └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *~ -------------------------------------------------------------------------------- /xmunmt/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Natural Language Processing Lab of Xiamen University 3 | # Author: Zhixing Tan 4 | # Contact: playinf@stu.xmu.edu.cn 5 | -------------------------------------------------------------------------------- /xmunmt/utils/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Natural Language Processing Lab of Xiamen University 3 | # Author: Zhixing Tan 4 | # Contact: playinf@stu.xmu.edu.cn 5 | -------------------------------------------------------------------------------- /xmunmt/data/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Natural Language Processing Lab of Xiamen University 3 | # Author: Zhixing Tan 4 | # Contact: playinf@stu.xmu.edu.cn 5 | -------------------------------------------------------------------------------- /xmunmt/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Natural Language Processing Lab of Xiamen University 3 | # Author: Zhixing Tan 4 | # Contact: playinf@stu.xmu.edu.cn 5 | -------------------------------------------------------------------------------- /xmunmt/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Natural Language Processing Lab of Xiamen University 3 | # Author: Zhixing Tan 4 | # Contact: playinf@stu.xmu.edu.cn 5 | 6 | import tensorflow as tf 7 | import xmunmt.layers.attention 8 | import xmunmt.layers.nn 9 | import xmunmt.layers.rnn_cell 10 | -------------------------------------------------------------------------------- /xmunmt/interface/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Natural Language Processing Lab of Xiamen University 3 | # Author: Zhixing Tan 4 | # Contact: playinf@stu.xmu.edu.cn 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | from xmunmt.interface.model import NMTModel 11 | -------------------------------------------------------------------------------- /xmunmt/models/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Natural Language Processing Lab of Xiamen University 3 | # Author: Zhixing Tan 4 | # Contact: playinf@stu.xmu.edu.cn 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import 
print_function 9 | 10 | import xmunmt.models.rnnsearch 11 | 12 | 13 | def get_model(name): 14 | name = name.lower() 15 | 16 | if name == "rnnsearch": 17 | return xmunmt.models.rnnsearch.RNNsearch 18 | else: 19 | raise LookupError("Unknown model %s" % name) 20 | -------------------------------------------------------------------------------- /xmunmt/data/vocab.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Natural Language Processing Lab of Xiamen University 3 | # Author: Zhixing Tan 4 | # Contact: playinf@stu.xmu.edu.cn 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import tensorflow as tf 11 | 12 | 13 | def load_vocabulary(filename): 14 | vocab = [] 15 | with tf.gfile.GFile(filename) as fd: 16 | for line in fd: 17 | word = line.strip() 18 | vocab.append(word) 19 | 20 | return vocab 21 | 22 | 23 | def process_vocabulary(vocab, params): 24 | if params.append_eos: 25 | vocab.append(params.eos) 26 | 27 | return vocab 28 | 29 | 30 | def get_control_mapping(vocab, symbols): 31 | mapping = {} 32 | 33 | for i, token in enumerate(vocab): 34 | for symbol in symbols: 35 | if symbol.decode("utf-8") == token.decode("utf-8"): 36 | mapping[symbol] = i 37 | 38 | return mapping 39 | -------------------------------------------------------------------------------- /xmunmt/interface/model.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Natural Language Processing Lab of Xiamen University 3 | # Author: Zhixing Tan 4 | # Contact: playinf@stu.xmu.edu.cn 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | 11 | class NMTModel(object): 12 | 13 | def __init__(self, params, scope): 14 | self._scope = scope 15 | self._params = params 16 | 17 | def get_training_func(self, initializer): 18 | raise NotImplementedError("Not implemented") 19 | 20 | def get_evaluation_func(self): 21 | raise NotImplementedError("Not implemented") 22 | 23 | def get_inference_func(self): 24 | raise NotImplementedError("Not implemented") 25 | 26 | @staticmethod 27 | def get_name(): 28 | raise NotImplementedError("Not implemented") 29 | 30 | @staticmethod 31 | def get_parameters(): 32 | raise NotImplementedError("Not implemented") 33 | 34 | @property 35 | def parameters(self): 36 | return self._params 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018, Natural Language Processing Lab at Xiamen University 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, 5 | are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, this 11 | list of conditions and the following disclaimer in the documentation and/or 12 | other materials provided with the distribution. 13 | 14 | * Neither the name of the copyright holder nor the names of its 15 | contributors may be used to endorse or promote products derived from this 16 | software without specific prior written permission. 
17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 22 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 23 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 24 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 25 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 27 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /xmunmt/scripts/shuffle_corpus.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Natural Language Processing Lab of Xiamen University 3 | # Author: Zhixing Tan 4 | # Contact: playinf@stu.xmu.edu.cn 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import argparse 11 | 12 | import numpy 13 | 14 | 15 | def parseargs(): 16 | parser = argparse.ArgumentParser(description="Shuffle corpus") 17 | 18 | parser.add_argument("--corpus", nargs="+", required=True, 19 | help="input corpora") 20 | parser.add_argument("--suffix", type=str, default="shuf", 21 | help="Suffix of output files") 22 | parser.add_argument("--seed", type=int, help="Random seed") 23 | 24 | return parser.parse_args() 25 | 26 | 27 | def main(args): 28 | name = args.corpus 29 | suffix = "." + args.suffix 30 | stream = [open(item, "r") for item in name] 31 | data = [fd.readlines() for fd in stream] 32 | minlen = min([len(lines) for lines in data]) 33 | 34 | if args.seed: 35 | numpy.random.seed(args.seed) 36 | 37 | indices = numpy.arange(minlen) 38 | numpy.random.shuffle(indices) 39 | 40 | newstream = [open(item + suffix, "w") for item in name] 41 | 42 | for idx in indices.tolist(): 43 | lines = [item[idx] for item in data] 44 | 45 | for line, fd in zip(lines, newstream): 46 | fd.write(line) 47 | 48 | for fdr, fdw in zip(stream, newstream): 49 | fdr.close() 50 | fdw.close() 51 | 52 | 53 | if __name__ == "__main__": 54 | main(parseargs()) 55 | -------------------------------------------------------------------------------- /xmunmt/layers/rnn_cell.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Natural Language Processing Lab of Xiamen University 3 | # Author: Zhixing Tan 4 | # Contact: playinf@stu.xmu.edu.cn 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import tensorflow as tf 11 | from xmunmt.layers.nn import linear 12 | 13 | 14 | class LegacyGRUCell(tf.nn.rnn_cell.RNNCell): 15 | """ Groundhog's implementation of GRUCell 16 | 17 | Args: 18 | num_units: int, The number of units in the RNN cell. 19 | reuse: (optional) Python boolean describing whether to reuse 20 | variables in an existing scope. If not `True`, and the existing 21 | scope already has the given variables, an error is raised. 
22 | """ 23 | 24 | def __init__(self, num_units, reuse=None): 25 | super(LegacyGRUCell, self).__init__(_reuse=reuse) 26 | self._num_units = num_units 27 | 28 | def __call__(self, inputs, state, scope=None): 29 | with tf.variable_scope(scope, default_name="gru_cell", 30 | values=[inputs, state]): 31 | if not isinstance(inputs, (list, tuple)): 32 | inputs = [inputs] 33 | 34 | all_inputs = list(inputs) + [state] 35 | r = tf.nn.sigmoid(linear(all_inputs, self._num_units, False, False, 36 | scope="reset_gate")) 37 | u = tf.nn.sigmoid(linear(all_inputs, self._num_units, False, False, 38 | scope="update_gate")) 39 | all_inputs = list(inputs) + [r * state] 40 | c = linear(all_inputs, self._num_units, True, False, 41 | scope="candidate") 42 | 43 | new_state = (1.0 - u) * state + u * tf.tanh(c) 44 | 45 | return new_state, new_state 46 | 47 | @property 48 | def state_size(self): 49 | return self._num_units 50 | 51 | @property 52 | def output_size(self): 53 | return self._num_units 54 | -------------------------------------------------------------------------------- /xmunmt/scripts/char_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Natural Language Processing Lab of Xiamen University 3 | # Author: Zhixing Tan 4 | # Contact: playinf@stu.xmu.edu.cn 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import sys 11 | 12 | 13 | def encode(): 14 | for line in sys.stdin: 15 | line = line.decode("utf-8") 16 | token = [] 17 | token_list = [] 18 | 19 | for char in line: 20 | if char == " ": 21 | if token: 22 | token_list.append(token) 23 | token_list.append(["_"]) 24 | token = [] 25 | elif char == "_": 26 | if token: 27 | token_list.append(token) 28 | token_list.append(["@_@"]) 29 | token = [] 30 | elif char == "-": 31 | if token: 32 | token_list.append(token) 33 | token_list.append(["@-@"]) 34 | token = [] 35 | elif ord(char) < 256: 36 | token.append(char.encode("utf-8")) 37 | else: 38 | if token: 39 | token_list.append(token) 40 | token_list.append([char.encode("utf-8")]) 41 | token = [] 42 | 43 | if token is not None: 44 | token_list.append(token) 45 | 46 | tokens = ["".join(item) for item in token_list] 47 | encoded = " ".join(tokens) 48 | sys.stdout.write(encoded) 49 | 50 | 51 | def decode(): 52 | for line in sys.stdin: 53 | tokens = line.strip().split() 54 | token_list = [] 55 | 56 | for token in tokens: 57 | if token == "@_@": 58 | token_list.append("_") 59 | elif token == "_": 60 | token_list.append(" ") 61 | elif token == "@-@": 62 | token_list.append("-") 63 | else: 64 | token_list.append(token) 65 | 66 | decoded = "".join(token_list) 67 | sys.stdout.write(decoded + "\n") 68 | 69 | 70 | def usage(): 71 | print("usage: char_utils.py encode < input > output\n" 72 | " char_utils.py decode < input > output") 73 | 74 | 75 | if __name__ == "__main__": 76 | if len(sys.argv) != 2: 77 | usage() 78 | 79 | if sys.argv[1] == "encode": 80 | encode() 81 | elif sys.argv[1] == "decode": 82 | decode() 83 | else: 84 | usage() 85 | -------------------------------------------------------------------------------- /xmunmt/layers/attention.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Natural Language Processing Lab of Xiamen University 3 | # Author: Zhixing Tan 4 | # Contact: playinf@stu.xmu.edu.cn 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 
| from __future__ import print_function 9 | 10 | import math 11 | 12 | import tensorflow as tf 13 | from xmunmt.layers.nn import linear 14 | 15 | 16 | def attention_bias(inputs, inf=-1e9, name=None): 17 | """ A bias tensor used in attention mechanism 18 | """ 19 | 20 | with tf.name_scope(name, default_name="attention_bias", values=[inputs]): 21 | mask = inputs 22 | bias = (1.0 - mask) * inf 23 | return bias 24 | 25 | 26 | def attention(query, memories, bias, hidden_size, cache=None, reuse=None, 27 | dtype=None, scope=None): 28 | """ Standard attention layer 29 | 30 | Args: 31 | query: A tensor with shape [batch, key_size] 32 | memories: A tensor with shape [batch, memory_size, key_size] 33 | bias: A tensor with shape [batch, memory_size] 34 | hidden_size: An integer 35 | cache: A dictionary of precomputed value 36 | reuse: A boolean value, whether to reuse the scope 37 | dtype: An optional instance of tf.DType 38 | scope: An optional string, the scope of this layer 39 | 40 | Return: 41 | A tensor with shape [batch, value_size] and a Tensor with 42 | shape [batch, memory_size] 43 | """ 44 | 45 | with tf.variable_scope(scope or "attention", reuse=reuse, 46 | values=[query, memories, bias], dtype=dtype): 47 | mem_shape = tf.shape(memories) 48 | key_size = memories.get_shape().as_list()[-1] 49 | 50 | if cache is None: 51 | k = tf.reshape(memories, [-1, key_size]) 52 | k = linear(k, hidden_size, False, False, scope="k_transform") 53 | 54 | if query is None: 55 | return {"key": k} 56 | else: 57 | k = cache["key"] 58 | 59 | q = linear(query, hidden_size, False, False, scope="q_transform") 60 | k = tf.reshape(k, [mem_shape[0], mem_shape[1], hidden_size]) 61 | 62 | hidden = tf.tanh(q[:, None, :] + k) 63 | hidden = tf.reshape(hidden, [-1, hidden_size]) 64 | 65 | logits = linear(hidden, 1, False, False, scope="logits") 66 | logits = tf.reshape(logits, [-1, mem_shape[1]]) 67 | 68 | if bias is not None: 69 | logits = logits + bias 70 | 71 | alpha = tf.nn.softmax(logits) 72 | 73 | outputs = { 74 | "value": tf.reduce_sum(alpha[:, :, None] * memories, axis=1), 75 | "weight": alpha 76 | } 77 | 78 | return outputs 79 | -------------------------------------------------------------------------------- /xmunmt/scripts/build_vocab.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Natural Language Processing Lab of Xiamen University 3 | # Author: Zhixing Tan 4 | # Contact: playinf@stu.xmu.edu.cn 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import argparse 11 | import collections 12 | 13 | 14 | def count_words(filename): 15 | counter = collections.Counter() 16 | 17 | with open(filename, "r") as fd: 18 | for line in fd: 19 | words = line.strip().split() 20 | counter.update(words) 21 | 22 | count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0])) 23 | words, counts = list(zip(*count_pairs)) 24 | 25 | return words, counts 26 | 27 | 28 | def control_symbols(string): 29 | if not string: 30 | return [] 31 | else: 32 | return string.strip().split(",") 33 | 34 | 35 | def save_vocab(name, vocab): 36 | if name.split(".")[-1] != "txt": 37 | name = name + ".txt" 38 | 39 | pairs = sorted(vocab.items(), key=lambda x: (x[1], x[0])) 40 | words, ids = list(zip(*pairs)) 41 | 42 | with open(name, "w") as f: 43 | for word in words: 44 | f.write(word + "\n") 45 | 46 | 47 | def parse_args(): 48 | parser = argparse.ArgumentParser(description="Create vocabulary") 49 
| 50 | parser.add_argument("corpus", help="input corpus") 51 | parser.add_argument("output", default="vocab.txt", 52 | help="Output vocabulary name") 53 | parser.add_argument("--limit", default=0, type=int, help="Vocabulary size") 54 | parser.add_argument("--control", type=str, default=",UNK", 55 | help="Add control symbols to vocabulary. " 56 | "Control symbols are separated by comma.") 57 | 58 | return parser.parse_args() 59 | 60 | 61 | def main(args): 62 | vocab = {} 63 | limit = args.limit 64 | count = 0 65 | 66 | words, counts = count_words(args.corpus) 67 | ctrl_symbols = control_symbols(args.control) 68 | 69 | for sym in ctrl_symbols: 70 | vocab[sym] = len(vocab) 71 | 72 | for word, freq in zip(words, counts): 73 | if limit and len(vocab) >= limit: 74 | break 75 | 76 | if word in vocab: 77 | print("Warning: found duplicate token %s, ignored" % word) 78 | continue 79 | 80 | vocab[word] = len(vocab) 81 | count += freq 82 | 83 | save_vocab(args.output, vocab) 84 | 85 | print("Total words: %d" % sum(counts)) 86 | print("Unique words: %d" % len(words)) 87 | print("Vocabulary coverage: %4.2f%%" % (100.0 * count / sum(counts))) 88 | 89 | 90 | if __name__ == "__main__": 91 | main(parse_args()) 92 | -------------------------------------------------------------------------------- /xmunmt/utils/bleu.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Natural Language Processing Lab of Xiamen University 3 | # Author: Zhixing Tan 4 | # Contact: playinf@stu.xmu.edu.cn 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import math 11 | from collections import Counter 12 | 13 | 14 | def closest_length(candidate, references): 15 | clen = len(candidate) 16 | closest_diff = 9999 17 | closest_len = 9999 18 | 19 | for reference in references: 20 | rlen = len(reference) 21 | diff = abs(rlen - clen) 22 | 23 | if diff < closest_diff: 24 | closest_diff = diff 25 | closest_len = rlen 26 | elif diff == closest_diff: 27 | closest_len = rlen if rlen < closest_len else closest_len 28 | 29 | return closest_len 30 | 31 | 32 | def shortest_length(references): 33 | return min([len(ref) for ref in references]) 34 | 35 | 36 | def modified_precision(candidate, references, n): 37 | tngrams = len(candidate) + 1 - n 38 | counts = Counter([tuple(candidate[i:i + n]) for i in range(tngrams)]) 39 | 40 | if len(counts) == 0: 41 | return 0, 0 42 | 43 | max_counts = {} 44 | for reference in references: 45 | rngrams = len(reference) + 1 - n 46 | ngrams = [tuple(reference[i:i + n]) for i in range(rngrams)] 47 | ref_counts = Counter(ngrams) 48 | for ngram in counts: 49 | mcount = 0 if ngram not in max_counts else max_counts[ngram] 50 | rcount = 0 if ngram not in ref_counts else ref_counts[ngram] 51 | max_counts[ngram] = max(mcount, rcount) 52 | 53 | clipped_counts = {} 54 | 55 | for ngram, count in counts.items(): 56 | clipped_counts[ngram] = min(count, max_counts[ngram]) 57 | 58 | return float(sum(clipped_counts.values())), float(sum(counts.values())) 59 | 60 | 61 | def brevity_penalty(trans, refs, mode="closest"): 62 | bp_c = 0.0 63 | bp_r = 0.0 64 | 65 | for candidate, references in zip(trans, refs): 66 | bp_c += len(candidate) 67 | 68 | if mode == "shortest": 69 | bp_r += shortest_length(references) 70 | else: 71 | bp_r += closest_length(candidate, references) 72 | 73 | # Prevent zero divide 74 | bp_c = bp_c or 1.0 75 | 76 | return math.exp(min(0, 1.0 - bp_r / bp_c)) 77 | 78 
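# Illustrative example of the helpers above (toy values, inputs assumed to be
# already tokenized): for a 9-token candidate whose closest reference has 10
# tokens, brevity_penalty returns exp(min(0, 1 - 10/9)) ~= 0.89, while a
# candidate longer than its reference is not penalized at all.
#
#   cand = "the the the the the the the".split()
#   refs = ["the cat is on the mat".split()]
#   modified_precision(cand, refs, 1)  # -> (2.0, 7.0): "the" clipped to 2
#   brevity_penalty([cand], [refs])    # -> 1.0 (candidate 7 > reference 6)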
| 79 | def bleu(trans, refs, bp="closest", smooth=False, n=4, weights=None): 80 | p_norm = [0 for _ in range(n)] 81 | p_denorm = [0 for _ in range(n)] 82 | 83 | for candidate, references in zip(trans, refs): 84 | for i in range(n): 85 | ccount, tcount = modified_precision(candidate, references, i + 1) 86 | p_norm[i] += ccount 87 | p_denorm[i] += tcount 88 | 89 | bleu_n = [0 for _ in range(n)] 90 | 91 | for i in range(n): 92 | # Add-one smoothing 93 | if smooth and i > 0: 94 | p_norm[i] += 1 95 | p_denorm[i] += 1 96 | 97 | if p_norm[i] == 0 or p_denorm[i] == 0: 98 | bleu_n[i] = -9999 99 | else: 100 | bleu_n[i] = math.log(float(p_norm[i]) / float(p_denorm[i])) 101 | 102 | if weights: 103 | if len(weights) != n: 104 | raise ValueError("len(weights) != n: invalid weight number") 105 | log_precision = sum([bleu_n[i] * weights[i] for i in range(n)]) 106 | else: 107 | log_precision = sum(bleu_n) / float(n) 108 | 109 | bp = brevity_penalty(trans, refs, bp) 110 | 111 | score = bp * math.exp(log_precision) 112 | 113 | return score 114 | -------------------------------------------------------------------------------- /xmunmt/utils/parallel.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Natural Language Processing Lab of Xiamen University 3 | # Author: Zhixing Tan 4 | # Contact: playinf@stu.xmu.edu.cn 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import operator 11 | 12 | import tensorflow as tf 13 | 14 | 15 | class GPUParamServerDeviceSetter(object): 16 | 17 | def __init__(self, worker_device, ps_devices): 18 | self.ps_devices = ps_devices 19 | self.worker_device = worker_device 20 | self.ps_sizes = [0] * len(self.ps_devices) 21 | 22 | def __call__(self, op): 23 | if op.device: 24 | return op.device 25 | if op.type not in ["Variable", "VariableV2", "VarHandleOp"]: 26 | return self.worker_device 27 | 28 | # Gets the least loaded ps_device 29 | device_index, _ = min(enumerate(self.ps_sizes), 30 | key=operator.itemgetter(1)) 31 | device_name = self.ps_devices[device_index] 32 | var_size = op.outputs[0].get_shape().num_elements() 33 | self.ps_sizes[device_index] += var_size 34 | 35 | return device_name 36 | 37 | 38 | def _maybe_repeat(x, n): 39 | if isinstance(x, list): 40 | assert len(x) == n 41 | return x 42 | else: 43 | return [x] * n 44 | 45 | 46 | def _create_device_setter(is_cpu_ps, worker, num_gpus): 47 | if is_cpu_ps: 48 | # tf.train.replica_device_setter supports placing variables on the CPU, 49 | # all on one GPU, or on ps_servers defined in a cluster_spec. 50 | return tf.train.replica_device_setter( 51 | worker_device=worker, ps_device="/cpu:0", ps_tasks=1) 52 | else: 53 | gpus = ["/gpu:%d" % i for i in range(num_gpus)] 54 | return GPUParamServerDeviceSetter(worker, gpus) 55 | 56 | 57 | # Data-level parallelism 58 | def data_parallelism(devices, fn, *args, **kwargs): 59 | num_worker = len(devices) 60 | 61 | # Replicate args and kwargs 62 | if args: 63 | new_args = [_maybe_repeat(arg, num_worker) for arg in args] 64 | # Transpose 65 | new_args = [list(x) for x in zip(*new_args)] 66 | else: 67 | new_args = [[] for _ in range(num_worker)] 68 | 69 | new_kwargs = [{} for _ in range(num_worker)] 70 | 71 | for k, v in kwargs.iteritems(): 72 | vals = _maybe_repeat(v, num_worker) 73 | 74 | for i in range(num_worker): 75 | new_kwargs[i][k] = vals[i] 76 | 77 | fns = _maybe_repeat(fn, num_worker) 78 | 79 | # Now make the parallel call. 
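    # Each replica i below runs fns[i] on its own shard of the arguments under
    # /gpu:i. Only the first tower creates variables (reuse=(i != 0)); the
    # GPUParamServerDeviceSetter then places each variable on the currently
    # least-loaded GPU, spreading parameters across devices. If fn returns a
    # tuple, the per-worker results are transposed into a tuple of lists at
    # the end.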
80 | outputs = [] 81 | 82 | for i in range(num_worker): 83 | worker = "/gpu:%d" % i 84 | device_setter = _create_device_setter(False, worker, len(devices)) 85 | with tf.variable_scope(tf.get_variable_scope(), reuse=(i != 0)): 86 | with tf.name_scope("parallel_%d" % i): 87 | with tf.device(device_setter): 88 | outputs.append(fns[i](*new_args[i], **new_kwargs[i])) 89 | 90 | if isinstance(outputs[0], tuple): 91 | outputs = list(zip(*outputs)) 92 | outputs = tuple([list(o) for o in outputs]) 93 | 94 | return outputs 95 | 96 | 97 | def shard_features(features, device_list): 98 | num_datashards = len(device_list) 99 | 100 | sharded_features = {} 101 | 102 | for k, v in features.iteritems(): 103 | v = tf.convert_to_tensor(v) 104 | if not v.shape.as_list(): 105 | v = tf.expand_dims(v, axis=-1) 106 | v = tf.tile(v, [num_datashards]) 107 | with tf.device(v.device): 108 | sharded_features[k] = tf.split(v, num_datashards, 0) 109 | 110 | datashard_to_features = [] 111 | 112 | for d in range(num_datashards): 113 | feat = { 114 | k: v[d] for k, v in sharded_features.iteritems() 115 | } 116 | datashard_to_features.append(feat) 117 | 118 | return datashard_to_features 119 | 120 | 121 | def parallel_model(model_fn, features, devices, use_cpu=False): 122 | devices = ["gpu:%d" % d for d in devices] 123 | 124 | if use_cpu: 125 | devices += ["cpu:0"] 126 | 127 | if len(devices) == 1: 128 | outputs = [model_fn(features)] 129 | if isinstance(outputs[0], (list, tuple)): 130 | outputs = list(zip(*outputs)) 131 | else: 132 | features = shard_features(features, devices) 133 | outputs = data_parallelism(devices, model_fn, features) 134 | 135 | return outputs 136 | -------------------------------------------------------------------------------- /xmunmt/scripts/checkpoint_averaging.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Natural Language Processing Lab of Xiamen University 3 | # Author: Zhixing Tan 4 | # Contact: playinf@stu.xmu.edu.cn 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import os 11 | import argparse 12 | import operator 13 | import numpy as np 14 | import tensorflow as tf 15 | 16 | 17 | def parseargs(): 18 | msg = "Average checkpoints" 19 | usage = "average.py [] [-h | --help]" 20 | parser = argparse.ArgumentParser(description=msg, usage=usage) 21 | 22 | parser.add_argument("--path", type=str, required=True, 23 | help="checkpoint dir") 24 | parser.add_argument("--checkpoints", type=int, required=True, 25 | help="number of checkpoints to use") 26 | parser.add_argument("--output", type=str, help="output path") 27 | 28 | return parser.parse_args() 29 | 30 | 31 | def get_checkpoints(path): 32 | if not tf.gfile.Exists(os.path.join(path, "checkpoint")): 33 | raise ValueError("Cannot find checkpoints in %s" % path) 34 | 35 | checkpoint_names = [] 36 | 37 | with tf.gfile.GFile(os.path.join(path, "checkpoint")) as fd: 38 | # Skip the first line 39 | fd.readline() 40 | for line in fd: 41 | name = line.strip().split(":")[-1].strip()[1:-1] 42 | key = int(name.split("-")[-1]) 43 | checkpoint_names.append((key, os.path.join(path, name))) 44 | 45 | sorted_names = sorted(checkpoint_names, key=operator.itemgetter(0), 46 | reverse=True) 47 | 48 | return [item[-1] for item in sorted_names] 49 | 50 | 51 | def checkpoint_exists(path): 52 | return (tf.gfile.Exists(path) or tf.gfile.Exists(path + ".meta") or 53 | tf.gfile.Exists(path + ".index")) 54 | 55 | 56 | 
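# Typical invocation (illustrative; "train" and "average" are placeholder
# directories, the flags are the ones defined in parseargs above):
#
#   python xmunmt/scripts/checkpoint_averaging.py \
#       --path train --checkpoints 5 --output average
#
# main() takes the newest --checkpoints checkpoints found under --path, sums
# every variable except global_step across them, divides by the number of
# checkpoints, saves the element-wise average as a new checkpoint in --output,
# and copies the *.json parameter files alongside it.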
def main(_): 57 | tf.logging.set_verbosity(tf.logging.INFO) 58 | checkpoints = get_checkpoints(FLAGS.path) 59 | checkpoints = checkpoints[:FLAGS.checkpoints] 60 | 61 | if not checkpoints: 62 | raise ValueError("No checkpoints provided for averaging.") 63 | 64 | checkpoints = [c for c in checkpoints if checkpoint_exists(c)] 65 | 66 | if not checkpoints: 67 | raise ValueError( 68 | "None of the provided checkpoints exist. %s" % FLAGS.checkpoints 69 | ) 70 | 71 | var_list = tf.contrib.framework.list_variables(checkpoints[0]) 72 | var_values, var_dtypes = {}, {} 73 | 74 | for (name, shape) in var_list: 75 | if not name.startswith("global_step"): 76 | var_values[name] = np.zeros(shape) 77 | 78 | for checkpoint in checkpoints: 79 | reader = tf.contrib.framework.load_checkpoint(checkpoint) 80 | for name in var_values: 81 | tensor = reader.get_tensor(name) 82 | var_dtypes[name] = tensor.dtype 83 | var_values[name] += tensor 84 | tf.logging.info("Read from checkpoint %s", checkpoint) 85 | 86 | # Average checkpoints 87 | for name in var_values: 88 | var_values[name] /= len(checkpoints) 89 | 90 | tf_vars = [ 91 | tf.get_variable(name, shape=var_values[name].shape, 92 | dtype=var_dtypes[name]) for name in var_values 93 | ] 94 | placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars] 95 | assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)] 96 | global_step = tf.Variable(0, name="global_step", trainable=False, 97 | dtype=tf.int64) 98 | saver = tf.train.Saver(tf.global_variables()) 99 | 100 | with tf.Session() as sess: 101 | sess.run(tf.global_variables_initializer()) 102 | for p, assign_op, (name, value) in zip(placeholders, assign_ops, 103 | var_values.iteritems()): 104 | sess.run(assign_op, {p: value}) 105 | saved_name = os.path.join(FLAGS.output, "average") 106 | saver.save(sess, saved_name, global_step=global_step) 107 | 108 | tf.logging.info("Averaged checkpoints saved in %s", saved_name) 109 | 110 | params_pattern = os.path.join(FLAGS.path, "*.json") 111 | params_files = tf.gfile.Glob(params_pattern) 112 | 113 | for name in params_files: 114 | new_name = name.replace(FLAGS.path.rstrip("/"), 115 | FLAGS.output.rstrip("/")) 116 | tf.gfile.Copy(name, new_name, overwrite=True) 117 | 118 | 119 | if __name__ == "__main__": 120 | FLAGS = parseargs() 121 | tf.app.run() 122 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # XMUNMT 2 | An open source Neural Machine Translation toolkit developed by the NLPLAB of Xiamen University. 3 | 4 | ## Features 5 | * Multi-GPU support 6 | * Builtin validation functionality 7 | 8 | 9 | ## Tutorial 10 | This tutorial describes how to train an NMT model on WMT17's EN-DE data using this repository. 11 | 12 | ### Prerequisite 13 | You must install TensorFlow (>=1.4.0) first to use this library. 14 | 15 | ### Download Data 16 | The preprocessed data can be found at 17 | [here](http://data.statmt.org/wmt17/translation-task/preprocessed/de-en/). 18 | 19 | ### Data Preprocessing 20 | 1. Byte Pair Encoding 21 | * The most common approach to achieve open vocabulary is to use Byte Pair Encoding (BPE). The codes of BPE can be found at [here](https://github.com/rsennrich/subword-nmt). 22 | * To encode the training corpora using BPE, you need to generate BPE operations first. 
The following command will create a file named "bpe32k", which contains 32k BPE operations along with two dictionaries named "vocab.en" and "vocab.de". 23 | ``` 24 | python subword-nmt/learn_joint_bpe_and_vocab.py --input corpus.tc.en corpus.tc.de -s 32000 -o bpe32k --write-vocabulary vocab.en vocab.de 25 | ``` 26 | * You still need to encode the training corpora, validation set and test set using the generated BPE operations and dictionaries. 27 | ``` 28 | python subword-nmt/apply_bpe.py -c bpe32k --vocabulary vocab.en --vocabulary-threshold 50 < corpus.tc.en > corpus.bpe32k.en 29 | python subword-nmt/apply_bpe.py -c bpe32k --vocabulary vocab.de --vocabulary-threshold 50 < corpus.tc.de > corpus.bpe32k.de 30 | python subword-nmt/apply_bpe.py -c bpe32k --vocabulary vocab.en --vocabulary-threshold 50 < newstest2016.tc.en > newstest2016.bpe32k.en 31 | python subword-nmt/apply_bpe.py -c bpe32k --vocabulary vocab.de --vocabulary-threshold 50 < newstest2016.tc.de > newstest2016.bpe32k.de 32 | python subword-nmt/apply_bpe.py -c bpe32k --vocabulary vocab.en --vocabulary-threshold 50 < newstest2017.tc.en > newstest2017.bpe32k.en 33 | ``` 34 | 35 | 2. Environment Variables 36 | * Before using XMUNMT, you need to add the path of XMUNMT to PYTHONPATH environment variable. Typically, this can be done by adding the following line to the .bashrc file in your home directory. 37 | ``` 38 | PYTHONPATH=/PATH/TO/XMUNMT:$PYTHONPATH 39 | ``` 40 | 41 | 2. Build vocabulary 42 | * To train an NMT, you need to build vocabularies first. To build a shared source and target vocabulary, you can use the following script: 43 | ``` 44 | cat corpus.bpe32k.en corpus.bpe32k.de > corpus.bpe32k.all 45 | python XMUNMT/xmunmt/scripts/build_vocab.py corpus.bpe32k.all vocab.shared32k.txt 46 | ``` 47 | 3. Shuffle corpus 48 | * It is beneficial to shuffle the training corpora before training. 49 | ``` 50 | python XMUNMT/xmunmt/scripts/shuffle_corpus.py --corpus corpus.bpe32k.en corpus.bpe32k.de --seed 1234 51 | ``` 52 | * The above command will create two new files named "corpus.bpe32k.en.shuf" and "corpus.bpe32k.de.shuf". 53 | 54 | ### Training 55 | * Finally, we can start the training stage. The recommended hyper-parameters are described below. 56 | ``` 57 | python XMUNMT/xmunmt/bin/trainer.py 58 | --model rnnsearch 59 | --output train 60 | --input corpus.bpe32k.en.shuf corpus.bpe32k.de.shuf 61 | --vocabulary vocab.shared32k.txt vocab.shared32k.txt 62 | --validation newstest2016.bpe32k.en 63 | --references newstest2016.bpe32k.de 64 | --parameters=device_list=[0],eval_steps=5000,train_steps=75000, 65 | learning_rate_decay=piecewise_constant, 66 | learning_rate_values=[5e-4,25e-5,125e-6], 67 | learning_rate_boundaries=[25000,50000] 68 | ``` 69 | * Change the argument of "device_list" to select GPU or use multiple GPUs. The above command will create a directory named "train". 70 | The best model can be found at "train/eval" 71 | 72 | ### Decoding 73 | * The decoding command is quite simple. 74 | ``` 75 | python XMUNMT/xmunmt/bin/translator.py 76 | --models rnnsearch 77 | --checkpoints train/eval 78 | --input newstest2017.bpe32k.en 79 | --output test.txt 80 | --vocabulary vocab.shared32k.txt vocab.shared32k.txt 81 | ``` 82 | 83 | ## Benchmark 84 | The benchmark is performed on 1 GTX 1080Ti GPU with default parameters. 85 | 86 | | Dataset | BLEU | BLEU (cased) | 87 | | :-----------: | :--------: | :-----------: | 88 | | WMT17 En-De | 22.81 | 22.30 | 89 | | WMT17 De-En | 29.01 | 27.69 | 90 | 91 | * More benchmarks will be added soon. 
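* For a quick sanity check of scores like the ones above, the bundled `xmunmt/utils/bleu.py` module can compute a tokenized BLEU score in Python. The snippet below is only an illustrative sketch: "test.txt" is the output of the decoding step, "newstest2017.tok.de" stands for a tokenized reference file (with BPE markers removed) that you have to prepare yourself, and the result is a rough tokenized score rather than the official WMT evaluation.
```
from xmunmt.utils.bleu import bleu

# One token list per hypothesis, one list of token lists per reference set.
with open("test.txt") as f:
    hypotheses = [line.split() for line in f]
with open("newstest2017.tok.de") as f:
    references = [[line.split()] for line in f]

print("BLEU: %.2f" % (100 * bleu(hypotheses, references)))
```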
92 | 93 | 94 | ## Contact 95 | This code is written by Zhixing Tan. If you have any problems, feel free to send an email. 96 | 97 | ## LICENSE 98 | BSD 99 | -------------------------------------------------------------------------------- /xmunmt/data/record.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Natural Language Processing Lab of Xiamen University 3 | # Author: Zhixing Tan 4 | # Contact: playinf@stu.xmu.edu.cn 5 | # Disclaimer: Part of this code is modified from the Tensor2Tensor library 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import math 12 | 13 | import numpy as np 14 | import six 15 | import tensorflow as tf 16 | from tensorflow.contrib.slim import parallel_reader, tfexample_decoder 17 | 18 | 19 | def input_pipeline(file_pattern, mode, capacity=64): 20 | keys_to_features = { 21 | "inputs": tf.VarLenFeature(tf.int64), 22 | "targets": tf.VarLenFeature(tf.int64) 23 | } 24 | 25 | items_to_handlers = { 26 | "inputs": tfexample_decoder.Tensor("inputs"), 27 | "targets": tfexample_decoder.Tensor("targets") 28 | } 29 | 30 | # Now the non-trivial case construction. 31 | with tf.name_scope("examples_queue"): 32 | training = (mode == "train") 33 | # Read serialized examples using slim parallel_reader. 34 | num_epochs = None if training else 1 35 | data_files = parallel_reader.get_data_files(file_pattern) 36 | num_readers = min(4 if training else 1, len(data_files)) 37 | _, examples = parallel_reader.parallel_read([file_pattern], 38 | tf.TFRecordReader, 39 | num_epochs=num_epochs, 40 | shuffle=training, 41 | capacity=2 * capacity, 42 | min_after_dequeue=capacity, 43 | num_readers=num_readers) 44 | 45 | decoder = tfexample_decoder.TFExampleDecoder(keys_to_features, 46 | items_to_handlers) 47 | 48 | decoded = decoder.decode(examples, items=list(items_to_handlers)) 49 | examples = {} 50 | 51 | for (field, tensor) in zip(keys_to_features, decoded): 52 | examples[field] = tensor 53 | 54 | # We do not want int64s as they do are not supported on GPUs. 55 | return {k: tf.to_int32(v) for (k, v) in six.iteritems(examples)} 56 | 57 | 58 | def batch_examples(examples, batch_size, max_length, mantissa_bits, 59 | shard_multiplier=1, length_multiplier=1, scheme="token", 60 | drop_long_sequences=True): 61 | with tf.name_scope("batch_examples"): 62 | max_length = max_length or batch_size 63 | min_length = 8 64 | mantissa_bits = mantissa_bits 65 | 66 | # compute boundaries 67 | x = min_length 68 | boundaries = [] 69 | 70 | while x < max_length: 71 | boundaries.append(x) 72 | x += 2 ** max(0, int(math.log(x, 2)) - mantissa_bits) 73 | 74 | if scheme is "token": 75 | batch_sizes = [max(1, batch_size // length) 76 | for length in boundaries + [max_length]] 77 | batch_sizes = [b * shard_multiplier for b in batch_sizes] 78 | bucket_capacities = [2 * b for b in batch_sizes] 79 | else: 80 | batch_sizes = batch_size * shard_multiplier 81 | bucket_capacities = [2 * n for n in boundaries + [max_length]] 82 | 83 | max_length *= length_multiplier 84 | boundaries = [boundary * length_multiplier for boundary in boundaries] 85 | max_length = max_length if drop_long_sequences else 10 ** 9 86 | 87 | # The queue to bucket on will be chosen based on maximum length. 
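        # The boundaries computed above grow roughly geometrically (each step
        # is about boundary / 2**mantissa_bits). Under the "token" scheme a
        # bucket holds about batch_size // boundary sentences (scaled by
        # shard_multiplier), so every batch carries roughly batch_size tokens
        # regardless of sentence length. bucket_by_sequence_length routes each
        # example to the bucket matching its longest feature, pads the batch
        # dynamically, and keep_input drops examples longer than max_length
        # when drop_long_sequences is set.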
88 | max_example_length = 0 89 | for v in examples.values(): 90 | seq_length = tf.shape(v)[0] 91 | max_example_length = tf.maximum(max_example_length, seq_length) 92 | 93 | (_, outputs) = tf.contrib.training.bucket_by_sequence_length( 94 | max_example_length, 95 | examples, 96 | batch_sizes, 97 | [b + 1 for b in boundaries], 98 | capacity=2, 99 | bucket_capacities=bucket_capacities, 100 | dynamic_pad=True, 101 | keep_input=(max_example_length <= max_length) 102 | ) 103 | 104 | return outputs 105 | 106 | 107 | def get_input_features(file_patterns, mode, params): 108 | with tf.name_scope("input_queues"): 109 | with tf.device("/cpu:0"): 110 | if mode != "train": 111 | num_datashards = 1 112 | batch_size = params.eval_batch_size 113 | else: 114 | num_datashards = len(params.device_list) 115 | batch_size = params.batch_size 116 | 117 | batch_size_multiplier = 1 118 | capacity = 64 * num_datashards 119 | examples = input_pipeline(file_patterns, mode, capacity) 120 | drop_long_sequences = (mode == "train") 121 | 122 | feature_map = batch_examples( 123 | examples, 124 | batch_size, 125 | params.max_length, 126 | params.mantissa_bits, 127 | num_datashards, 128 | batch_size_multiplier, 129 | "token" if not params.constant_batch_size else "constant", 130 | drop_long_sequences 131 | ) 132 | 133 | # Final feature map. 134 | features = { 135 | "source": feature_map["inputs"], 136 | "target": feature_map["targets"], 137 | "source_length": tf.to_int32( 138 | tf.reduce_sum( 139 | tf.to_float(tf.not_equal(feature_map["inputs"], 0)), 140 | axis=1 141 | ) 142 | ), 143 | "target_length": tf.to_int32( 144 | tf.reduce_sum( 145 | tf.to_float(tf.not_equal(feature_map["targets"], 0)), 146 | axis=1 147 | ) 148 | ), 149 | } 150 | 151 | return features 152 | -------------------------------------------------------------------------------- /xmunmt/layers/nn.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Natural Language Processing Lab of Xiamen University 3 | # Author: Zhixing Tan 4 | # Contact: playinf@stu.xmu.edu.cn 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import tensorflow as tf 11 | 12 | 13 | def linear(inputs, output_size, bias, concat=False, dtype=None, scope=None): 14 | """ 15 | Linear layer 16 | 17 | Args: 18 | inputs: A Tensor or a list of Tensors with shape [batch, input_size] 19 | output_size: An integer specify the output size 20 | bias: a boolean value indicate whether to use bias term 21 | concat: a boolean value indicate whether to concatenate all inputs 22 | dtype: an instance of tf.DType, the default value is ``tf.float32'' 23 | scope: the scope of this layer, the default value is ``linear'' 24 | 25 | Returns: 26 | a Tensor with shape [batch, output_size] 27 | 28 | Raises: 29 | RuntimeError: raises ``RuntimeError'' when input sizes do not 30 | compatible with each other 31 | """ 32 | 33 | with tf.variable_scope(scope, default_name="linear", values=[inputs]): 34 | if not isinstance(inputs, (list, tuple)): 35 | inputs = [inputs] 36 | 37 | input_size = [item.get_shape()[-1].value for item in inputs] 38 | 39 | if len(inputs) != len(input_size): 40 | raise RuntimeError("inputs and input_size unmatched!") 41 | 42 | output_shape = tf.concat([tf.shape(inputs[0])[:-1], [output_size]], 43 | axis=0) 44 | # Flatten to 2D 45 | inputs = [tf.reshape(inp, [-1, inp.shape[-1].value]) for inp in inputs] 46 | 47 | results = [] 48 | 49 | if concat: 50 | 
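            # With concat=True all inputs are joined along the last axis and
            # multiplied by a single [sum(input_size), output_size] matrix;
            # the else-branch below keeps one "matrix_%d" per input and sums
            # the partial products, which is mathematically equivalent but
            # creates one variable per input.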
input_size = sum(input_size) 51 | inputs = tf.concat(inputs, 1) 52 | 53 | shape = [input_size, output_size] 54 | matrix = tf.get_variable("matrix", shape, dtype=dtype) 55 | results.append(tf.matmul(inputs, matrix)) 56 | else: 57 | for i in range(len(input_size)): 58 | shape = [input_size[i], output_size] 59 | name = "matrix_%d" % i 60 | matrix = tf.get_variable(name, shape, dtype=dtype) 61 | results.append(tf.matmul(inputs[i], matrix)) 62 | 63 | output = tf.add_n(results) 64 | 65 | if bias: 66 | shape = [output_size] 67 | bias = tf.get_variable("bias", shape, dtype=dtype) 68 | output = tf.nn.bias_add(output, bias) 69 | 70 | output = tf.reshape(output, output_shape) 71 | 72 | return output 73 | 74 | 75 | def maxout(inputs, output_size, maxpart=2, use_bias=True, concat=True, 76 | dtype=None, scope=None): 77 | """ Maxout layer 78 | Args: 79 | inputs: see the corresponding description of ``linear'' 80 | output_size: see the corresponding description of ``linear'' 81 | maxpart: an integer, the default value is 2 82 | use_bias: a boolean value indicate whether to use bias term 83 | scope: the scope of this layer, the default value is ``maxout'' 84 | 85 | Returns: 86 | a Tensor with shape [batch, output_size] 87 | 88 | Raises: 89 | RuntimeError: see the corresponding description of ``linear'' 90 | """ 91 | 92 | candidate = linear(inputs, output_size * maxpart, use_bias, concat, 93 | dtype=dtype, scope=scope or "maxout") 94 | shape = tf.concat([tf.shape(candidate)[:-1], [output_size, maxpart]], 95 | axis=0) 96 | value = tf.reshape(candidate, shape) 97 | output = tf.reduce_max(value, -1) 98 | 99 | return output 100 | 101 | 102 | def layer_norm(inputs, epsilon=1e-6, dtype=None, scope=None): 103 | """ Layer Normalization 104 | 105 | Args: 106 | inputs: A Tensor of shape [..., channel_size] 107 | epsilon: A floating number 108 | dtype: An optional instance of tf.DType 109 | scope: An optional string 110 | 111 | Returns: 112 | A Tensor with the same shape as inputs 113 | """ 114 | with tf.variable_scope(scope, default_name="layer_norm", values=[inputs], 115 | dtype=dtype): 116 | channel_size = inputs.get_shape().as_list()[-1] 117 | 118 | scale = tf.get_variable("scale", shape=[channel_size], 119 | initializer=tf.ones_initializer()) 120 | 121 | offset = tf.get_variable("offset", shape=[channel_size], 122 | initializer=tf.zeros_initializer()) 123 | 124 | mean = tf.reduce_mean(inputs, axis=-1, keep_dims=True) 125 | variance = tf.reduce_mean(tf.square(inputs - mean), axis=-1, 126 | keep_dims=True) 127 | 128 | norm_inputs = (inputs - mean) * tf.rsqrt(variance + epsilon) 129 | 130 | return norm_inputs * scale + offset 131 | 132 | 133 | def smoothed_softmax_cross_entropy_with_logits(**kwargs): 134 | logits = kwargs.get("logits") 135 | labels = kwargs.get("labels") 136 | smoothing = kwargs.get("smoothing") or 0.0 137 | normalize = kwargs.get("normalize") 138 | scope = kwargs.get("scope") 139 | 140 | if logits is None or labels is None: 141 | raise ValueError("Both logits and labels must be provided") 142 | 143 | with tf.name_scope(scope or "smoothed_softmax_cross_entropy_with_logits", 144 | values=[logits, labels]): 145 | 146 | labels = tf.reshape(labels, [-1]) 147 | 148 | if not smoothing: 149 | ce = tf.nn.sparse_softmax_cross_entropy_with_logits( 150 | logits=logits, 151 | labels=labels 152 | ) 153 | return ce 154 | 155 | # label smoothing 156 | vocab_size = tf.shape(logits)[1] 157 | 158 | n = tf.to_float(vocab_size - 1) 159 | p = 1.0 - smoothing 160 | q = smoothing / n 161 | 162 | soft_targets = 
tf.one_hot(tf.cast(labels, tf.int32), depth=vocab_size, 163 | on_value=p, off_value=q) 164 | xentropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits, 165 | labels=soft_targets) 166 | 167 | if normalize is False: 168 | return xentropy 169 | 170 | # Normalizing constant is the best cross-entropy value with soft 171 | # targets. We subtract it just for readability, makes no difference on 172 | # learning 173 | normalizing = -(p * tf.log(p) + n * q * tf.log(q + 1e-20)) 174 | 175 | return xentropy - normalizing 176 | -------------------------------------------------------------------------------- /xmunmt/bin/translator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2018 Natural Language Processing Lab of Xiamen University 4 | # Author: Zhixing Tan 5 | # Contact: playinf@stu.xmu.edu.cn 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import argparse 12 | import itertools 13 | import os 14 | 15 | import numpy as np 16 | import tensorflow as tf 17 | import xmunmt.data.dataset as dataset 18 | import xmunmt.data.vocab as vocabulary 19 | import xmunmt.models as models 20 | import xmunmt.utils.search as search 21 | 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser( 25 | description="Translate using existing NMT models", 26 | usage="translator.py [] [-h | --help]" 27 | ) 28 | 29 | # input files 30 | parser.add_argument("--input", type=str, required=True, 31 | help="Path of input file") 32 | parser.add_argument("--output", type=str, required=True, 33 | help="Path of output file") 34 | parser.add_argument("--checkpoints", type=str, nargs="+", required=True, 35 | help="Path of trained models") 36 | parser.add_argument("--vocabulary", type=str, nargs=2, required=True, 37 | help="Path of source and target vocabulary") 38 | 39 | # model and configuration 40 | parser.add_argument("--models", type=str, required=True, nargs="+", 41 | help="Name of the model") 42 | parser.add_argument("--parameters", type=str, 43 | help="Additional hyper parameters") 44 | 45 | return parser.parse_args() 46 | 47 | 48 | def default_parameters(): 49 | params = tf.contrib.training.HParams( 50 | input=None, 51 | output=None, 52 | vocabulary=None, 53 | model=None, 54 | # vocabulary specific 55 | pad="", 56 | bos="", 57 | eos="", 58 | unk="", 59 | mapping=None, 60 | append_eos=False, 61 | # decoding 62 | top_beams=1, 63 | beam_size=4, 64 | decode_alpha=0.6, 65 | decode_length=50, 66 | decode_batch_size=32, 67 | decode_constant=5.0, 68 | decode_normalize=False, 69 | device_list=[0], 70 | num_threads=6 71 | ) 72 | 73 | return params 74 | 75 | 76 | def merge_parameters(params1, params2): 77 | params = tf.contrib.training.HParams() 78 | 79 | for (k, v) in params1.values().iteritems(): 80 | params.add_hparam(k, v) 81 | 82 | params_dict = params.values() 83 | 84 | for (k, v) in params2.values().iteritems(): 85 | if k in params_dict: 86 | # Override 87 | setattr(params, k, v) 88 | else: 89 | params.add_hparam(k, v) 90 | 91 | return params 92 | 93 | 94 | def import_params(model_dir, model_name, params): 95 | model_dir = os.path.abspath(model_dir) 96 | m_name = os.path.join(model_dir, model_name + ".json") 97 | 98 | if not tf.gfile.Exists(m_name): 99 | return params 100 | 101 | with tf.gfile.Open(m_name) as fd: 102 | tf.logging.info("Restoring model parameters from %s" % m_name) 103 | json_str = fd.readline() 104 | 
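        # The model_name.json file stored next to the checkpoint holds the
        # saved hyper-parameters; parsing it here lets decoding reuse that
        # configuration, with any --parameters given on the command line
        # applied on top of it later in override_parameters().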
params.parse_json(json_str) 105 | 106 | return params 107 | 108 | 109 | def override_parameters(params, args): 110 | if args.parameters: 111 | params.parse(args.parameters) 112 | 113 | params.vocabulary = { 114 | "source": vocabulary.load_vocabulary(args.vocabulary[0]), 115 | "target": vocabulary.load_vocabulary(args.vocabulary[1]) 116 | } 117 | params.vocabulary["source"] = vocabulary.process_vocabulary( 118 | params.vocabulary["source"], params 119 | ) 120 | params.vocabulary["target"] = vocabulary.process_vocabulary( 121 | params.vocabulary["target"], params 122 | ) 123 | 124 | control_symbols = [params.pad, params.bos, params.eos, params.unk] 125 | 126 | params.mapping = { 127 | "source": vocabulary.get_control_mapping( 128 | params.vocabulary["source"], 129 | control_symbols 130 | ), 131 | "target": vocabulary.get_control_mapping( 132 | params.vocabulary["target"], 133 | control_symbols 134 | ) 135 | } 136 | 137 | return params 138 | 139 | 140 | def session_config(params): 141 | optimizer_options = tf.OptimizerOptions(opt_level=tf.OptimizerOptions.L1, 142 | do_function_inlining=False) 143 | graph_options = tf.GraphOptions(optimizer_options=optimizer_options) 144 | config = tf.ConfigProto(allow_soft_placement=True, 145 | graph_options=graph_options) 146 | if params.device_list: 147 | device_str = ",".join([str(i) for i in params.device_list]) 148 | config.gpu_options.visible_device_list = device_str 149 | 150 | return config 151 | 152 | 153 | def set_variables(var_list, value_dict, prefix): 154 | ops = [] 155 | for var in var_list: 156 | for name in value_dict: 157 | var_name = "/".join([prefix] + list(name.split("/")[1:])) 158 | 159 | if var.name[:-2] == var_name: 160 | tf.logging.info("restoring %s -> %s" % (name, var.name)) 161 | with tf.device("/cpu:0"): 162 | op = tf.assign(var, value_dict[name]) 163 | ops.append(op) 164 | break 165 | 166 | return ops 167 | 168 | 169 | def main(args): 170 | tf.logging.set_verbosity(tf.logging.INFO) 171 | # Load configs 172 | model_cls_list = [models.get_model(model) for model in args.models] 173 | params_list = [default_parameters() for _ in range(len(model_cls_list))] 174 | params_list = [ 175 | merge_parameters(params, model_cls.get_parameters()) 176 | for params, model_cls in zip(params_list, model_cls_list) 177 | ] 178 | params_list = [ 179 | import_params(args.checkpoints[i], args.models[i], params_list[i]) 180 | for i in range(len(args.checkpoints)) 181 | ] 182 | params_list = [ 183 | override_parameters(params_list[i], args) 184 | for i in range(len(model_cls_list)) 185 | ] 186 | 187 | # Build Graph 188 | with tf.Graph().as_default(): 189 | model_var_lists = [] 190 | 191 | # Load checkpoints 192 | for i, checkpoint in enumerate(args.checkpoints): 193 | print("Loading %s" % checkpoint) 194 | var_list = tf.train.list_variables(checkpoint) 195 | values = {} 196 | reader = tf.train.load_checkpoint(checkpoint) 197 | 198 | for (name, shape) in var_list: 199 | if not name.startswith(model_cls_list[i].get_name()): 200 | continue 201 | 202 | if name.find("losses_avg") >= 0: 203 | continue 204 | 205 | tensor = reader.get_tensor(name) 206 | values[name] = tensor 207 | 208 | model_var_lists.append(values) 209 | 210 | # Build models 211 | model_fns = [] 212 | 213 | for i in range(len(args.checkpoints)): 214 | name = model_cls_list[i].get_name() 215 | model = model_cls_list[i](params_list[i], name + "_%d" % i) 216 | model_fn = model.get_inference_func() 217 | model_fns.append(model_fn) 218 | 219 | params = params_list[0] 220 | # Read input file 221 | 
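        # Inputs are sorted by length (longest first) so each decoding batch
        # contains sentences of similar length and wastes little padding;
        # sorted_keys maps every original line index to its position in the
        # sorted list and is used below to restore the original order before
        # writing the output file.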
sorted_keys, sorted_inputs = dataset.sort_input_file(args.input) 222 | # Build input queue 223 | features = dataset.get_inference_input(sorted_inputs, params) 224 | predictions = search.create_inference_graph(model_fns, features, 225 | params) 226 | 227 | assign_ops = [] 228 | 229 | all_var_list = tf.trainable_variables() 230 | 231 | for i in range(len(args.checkpoints)): 232 | un_init_var_list = [] 233 | name = model_cls_list[i].get_name() 234 | 235 | for v in all_var_list: 236 | if v.name.startswith(name + "_%d" % i): 237 | un_init_var_list.append(v) 238 | 239 | ops = set_variables(un_init_var_list, model_var_lists[i], 240 | name + "_%d" % i) 241 | assign_ops.extend(ops) 242 | 243 | assign_op = tf.group(*assign_ops) 244 | 245 | sess_creator = tf.train.ChiefSessionCreator( 246 | config=session_config(params) 247 | ) 248 | 249 | results = [] 250 | 251 | # Create session 252 | with tf.train.MonitoredSession(session_creator=sess_creator) as sess: 253 | # Restore variables 254 | sess.run(assign_op) 255 | 256 | while not sess.should_stop(): 257 | results.append(sess.run(predictions)) 258 | message = "Finished batch %d" % len(results) 259 | tf.logging.log(tf.logging.INFO, message) 260 | 261 | # Convert to plain text 262 | vocab = params.vocabulary["target"] 263 | outputs = [] 264 | 265 | for result in results: 266 | outputs.append(result.tolist()) 267 | 268 | outputs = list(itertools.chain(*outputs)) 269 | 270 | restored_outputs = [] 271 | 272 | for index in range(len(sorted_inputs)): 273 | restored_outputs.append(outputs[sorted_keys[index]]) 274 | 275 | # Write to file 276 | with open(args.output, "w") as outfile: 277 | for output in restored_outputs: 278 | decoded = [] 279 | for idx in output: 280 | if idx == params.mapping["target"][params.eos]: 281 | break 282 | decoded.append(vocab[idx]) 283 | 284 | decoded = " ".join(decoded) 285 | outfile.write("%s\n" % decoded) 286 | 287 | 288 | if __name__ == "__main__": 289 | main(parse_args()) 290 | -------------------------------------------------------------------------------- /xmunmt/data/dataset.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Natural Language Processing Lab of Xiamen University 3 | # Author: Zhixing Tan 4 | # Contact: playinf@stu.xmu.edu.cn 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import math 11 | import operator 12 | 13 | import numpy as np 14 | import tensorflow as tf 15 | 16 | 17 | def batch_examples(example, batch_size, max_length, mantissa_bits, 18 | shard_multiplier=1, length_multiplier=1, constant=False, 19 | num_threads=4, drop_long_sequences=True): 20 | """ Batch examples 21 | Args: 22 | example: A dictionary of . 23 | batch_size: The number of tokens or sentences in a batch 24 | max_length: The maximum length of a example to keep 25 | mantissa_bits: An integer 26 | shard_multiplier: an integer increasing the batch_size to suit 27 | splitting across data shards. 28 | length_multiplier: an integer multiplier that is used to 29 | increase the batch sizes and sequence length tolerance. 
30 | constant: Whether to use constant batch size 31 | num_threads: Number of threads 32 | drop_long_sequences: Whether to drop long sequences 33 | 34 | Returns: 35 | A dictionary of batched examples 36 | """ 37 | 38 | with tf.name_scope("batch_examples"): 39 | max_length = max_length or batch_size 40 | min_length = 8 41 | mantissa_bits = mantissa_bits 42 | 43 | # Compute boundaries 44 | x = min_length 45 | boundaries = [] 46 | 47 | while x < max_length: 48 | boundaries.append(x) 49 | x += 2 ** max(0, int(math.log(x, 2)) - mantissa_bits) 50 | 51 | # Whether the batch size is constant 52 | if not constant: 53 | batch_sizes = [max(1, batch_size // length) 54 | for length in boundaries + [max_length]] 55 | batch_sizes = [b * shard_multiplier for b in batch_sizes] 56 | bucket_capacities = [2 * b for b in batch_sizes] 57 | else: 58 | batch_sizes = batch_size * shard_multiplier 59 | bucket_capacities = [2 * n for n in boundaries + [max_length]] 60 | 61 | max_length *= length_multiplier 62 | boundaries = [boundary * length_multiplier for boundary in boundaries] 63 | max_length = max_length if drop_long_sequences else 10 ** 9 64 | 65 | # The queue to bucket on will be chosen based on maximum length 66 | max_example_length = 0 67 | for v in example.values(): 68 | if v.shape.ndims > 0: 69 | seq_length = tf.shape(v)[0] 70 | max_example_length = tf.maximum(max_example_length, seq_length) 71 | 72 | (_, outputs) = tf.contrib.training.bucket_by_sequence_length( 73 | max_example_length, 74 | example, 75 | batch_sizes, 76 | [b + 1 for b in boundaries], 77 | num_threads=num_threads, 78 | capacity=2, # Number of full batches to store, we don't need many. 79 | bucket_capacities=bucket_capacities, 80 | dynamic_pad=True, 81 | keep_input=(max_example_length <= max_length) 82 | ) 83 | 84 | return outputs 85 | 86 | 87 | def get_training_input(filenames, params): 88 | """ Get input for training stage 89 | Args: 90 | filenames: A list contains [source_filename, target_filename] 91 | params: Hyper-parameters 92 | 93 | Returns 94 | A dictionary of pair 95 | """ 96 | 97 | with tf.device("/cpu:0"): 98 | src_dataset = tf.data.TextLineDataset(filenames[0]) 99 | tgt_dataset = tf.data.TextLineDataset(filenames[1]) 100 | 101 | dataset = tf.data.Dataset.zip((src_dataset, tgt_dataset)) 102 | dataset = dataset.shuffle(params.buffer_size) 103 | dataset = dataset.repeat() 104 | 105 | # Split string 106 | dataset = dataset.map( 107 | lambda src, tgt: ( 108 | tf.string_split([src]).values, 109 | tf.string_split([tgt]).values 110 | ), 111 | num_parallel_calls=params.num_threads 112 | ) 113 | 114 | # Append symbol 115 | dataset = dataset.map( 116 | lambda src, tgt: ( 117 | tf.concat([src, [tf.constant(params.eos)]], axis=0), 118 | tf.concat([tgt, [tf.constant(params.eos)]], axis=0) 119 | ), 120 | num_parallel_calls=params.num_threads 121 | ) 122 | 123 | # Convert to dictionary 124 | dataset = dataset.map( 125 | lambda src, tgt: { 126 | "source": src, 127 | "target": tgt, 128 | "source_length": tf.shape(src), 129 | "target_length": tf.shape(tgt) 130 | }, 131 | num_parallel_calls=params.num_threads 132 | ) 133 | 134 | # Create iterator 135 | iterator = dataset.make_one_shot_iterator() 136 | features = iterator.get_next() 137 | 138 | # Create lookup table 139 | src_table = tf.contrib.lookup.index_table_from_tensor( 140 | tf.constant(params.vocabulary["source"]), 141 | default_value=params.mapping["source"][params.unk] 142 | ) 143 | tgt_table = tf.contrib.lookup.index_table_from_tensor( 144 | tf.constant(params.vocabulary["target"]), 145 
| default_value=params.mapping["target"][params.unk] 146 | ) 147 | 148 | # String to index lookup 149 | features["source"] = src_table.lookup(features["source"]) 150 | features["target"] = tgt_table.lookup(features["target"]) 151 | 152 | # Batching 153 | features = batch_examples(features, params.batch_size, 154 | params.max_length, params.mantissa_bits, 155 | shard_multiplier=len(params.device_list), 156 | length_multiplier=params.length_multiplier, 157 | constant=params.constant_batch_size, 158 | num_threads=params.num_threads) 159 | 160 | # Convert to int32 161 | features["source"] = tf.to_int32(features["source"]) 162 | features["target"] = tf.to_int32(features["target"]) 163 | features["source_length"] = tf.to_int32(features["source_length"]) 164 | features["target_length"] = tf.to_int32(features["target_length"]) 165 | features["source_length"] = tf.squeeze(features["source_length"], 1) 166 | features["target_length"] = tf.squeeze(features["target_length"], 1) 167 | 168 | return features 169 | 170 | 171 | def sort_input_file(filename, reverse=True): 172 | # Read file 173 | with tf.gfile.Open(filename) as fd: 174 | inputs = [line.strip() for line in fd] 175 | 176 | input_lens = [ 177 | (i, len(line.strip().split())) for i, line in enumerate(inputs) 178 | ] 179 | 180 | sorted_input_lens = sorted(input_lens, key=operator.itemgetter(1), 181 | reverse=reverse) 182 | sorted_keys = {} 183 | sorted_inputs = [] 184 | 185 | for i, (index, _) in enumerate(sorted_input_lens): 186 | sorted_inputs.append(inputs[index]) 187 | sorted_keys[index] = i 188 | 189 | return sorted_keys, sorted_inputs 190 | 191 | 192 | def sort_and_zip_files(names): 193 | inputs = [] 194 | input_lens = [] 195 | files = [tf.gfile.GFile(name) for name in names] 196 | 197 | count = 0 198 | 199 | for lines in zip(*files): 200 | lines = [line.strip() for line in lines] 201 | input_lens.append((count, len(lines[0].split()))) 202 | inputs.append(lines) 203 | count += 1 204 | 205 | # Close files 206 | for fd in files: 207 | fd.close() 208 | 209 | sorted_input_lens = sorted(input_lens, key=operator.itemgetter(1), 210 | reverse=True) 211 | sorted_inputs = [] 212 | 213 | for i, (index, _) in enumerate(sorted_input_lens): 214 | sorted_inputs.append(inputs[index]) 215 | 216 | return [list(x) for x in zip(*sorted_inputs)] 217 | 218 | 219 | def get_evaluation_input(inputs, params): 220 | with tf.device("/cpu:0"): 221 | # Create datasets 222 | datasets = [] 223 | 224 | for data in inputs: 225 | dataset = tf.data.Dataset.from_tensor_slices(data) 226 | # Split string 227 | dataset = dataset.map(lambda x: tf.string_split([x]).values, 228 | num_parallel_calls=params.num_threads) 229 | # Append 230 | dataset = dataset.map( 231 | lambda x: tf.concat([x, [tf.constant(params.eos)]], axis=0), 232 | num_parallel_calls=params.num_threads 233 | ) 234 | datasets.append(dataset) 235 | 236 | dataset = tf.data.Dataset.zip(tuple(datasets)) 237 | 238 | # Convert tuple to dictionary 239 | dataset = dataset.map( 240 | lambda *x: { 241 | "source": x[0], 242 | "source_length": tf.shape(x[0])[0], 243 | "references": x[1:] 244 | }, 245 | num_parallel_calls=params.num_threads 246 | ) 247 | 248 | dataset = dataset.padded_batch( 249 | params.eval_batch_size, 250 | { 251 | "source": [tf.Dimension(None)], 252 | "source_length": [], 253 | "references": (tf.Dimension(None),) * (len(inputs) - 1) 254 | }, 255 | { 256 | "source": params.pad, 257 | "source_length": 0, 258 | "references": (params.pad,) * (len(inputs) - 1) 259 | } 260 | ) 261 | 262 | iterator = 
dataset.make_one_shot_iterator() 263 | features = iterator.get_next() 264 | 265 | src_table = tf.contrib.lookup.index_table_from_tensor( 266 | tf.constant(params.vocabulary["source"]), 267 | default_value=params.mapping["source"][params.unk] 268 | ) 269 | tgt_table = tf.contrib.lookup.index_table_from_tensor( 270 | tf.constant(params.vocabulary["target"]), 271 | default_value=params.mapping["target"][params.unk] 272 | ) 273 | features["source"] = src_table.lookup(features["source"]) 274 | features["references"] = tuple( 275 | tgt_table.lookup(item) for item in features["references"] 276 | ) 277 | 278 | return features 279 | 280 | 281 | def get_inference_input(inputs, params): 282 | dataset = tf.data.Dataset.from_tensor_slices( 283 | tf.constant(inputs) 284 | ) 285 | 286 | # Split string 287 | dataset = dataset.map(lambda x: tf.string_split([x]).values, 288 | num_parallel_calls=params.num_threads) 289 | 290 | # Append 291 | dataset = dataset.map( 292 | lambda x: tf.concat([x, [tf.constant(params.eos)]], axis=0), 293 | num_parallel_calls=params.num_threads 294 | ) 295 | 296 | # Convert tuple to dictionary 297 | dataset = dataset.map( 298 | lambda x: {"source": x, "source_length": tf.shape(x)[0]}, 299 | num_parallel_calls=params.num_threads 300 | ) 301 | 302 | dataset = dataset.padded_batch( 303 | params.decode_batch_size, 304 | {"source": [tf.Dimension(None)], "source_length": []}, 305 | {"source": params.pad, "source_length": 0} 306 | ) 307 | 308 | iterator = dataset.make_one_shot_iterator() 309 | features = iterator.get_next() 310 | 311 | src_table = tf.contrib.lookup.index_table_from_tensor( 312 | tf.constant(params.vocabulary["source"]), 313 | default_value=params.mapping["source"][params.unk] 314 | ) 315 | features["source"] = src_table.lookup(features["source"]) 316 | 317 | return features 318 | -------------------------------------------------------------------------------- /xmunmt/utils/hooks.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Natural Language Processing Lab of Xiamen University 3 | # Author: Zhixing Tan 4 | # Contact: playinf@stu.xmu.edu.cn 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import operator 11 | import os 12 | 13 | import tensorflow as tf 14 | import xmunmt.utils.bleu as bleu 15 | 16 | 17 | def _get_saver(): 18 | # Get saver from the SAVERS collection if present. 19 | collection_key = tf.GraphKeys.SAVERS 20 | savers = tf.get_collection(collection_key) 21 | 22 | if not savers: 23 | raise RuntimeError("No items in collection {}. 
" 24 | "Please add a saver to the collection ") 25 | elif len(savers) > 1: 26 | raise RuntimeError("More than one item in collection") 27 | 28 | return savers[0] 29 | 30 | 31 | def _read_checkpoint_def(filename): 32 | records = [] 33 | 34 | with tf.gfile.GFile(filename) as fd: 35 | fd.readline() 36 | 37 | for line in fd: 38 | records.append(line.strip().split(":")[-1].strip()[1:-1]) 39 | 40 | return records 41 | 42 | 43 | def _save_checkpoint_def(filename, checkpoint_names): 44 | keys = [] 45 | 46 | for checkpoint_name in checkpoint_names: 47 | step = int(checkpoint_name.strip().split("-")[-1]) 48 | keys.append((step, checkpoint_name)) 49 | 50 | sorted_names = sorted(keys, key=operator.itemgetter(0), 51 | reverse=True) 52 | 53 | with tf.gfile.GFile(filename, "w") as fd: 54 | fd.write("model_checkpoint_path: \"%s\"\n" % checkpoint_names[0]) 55 | 56 | for checkpoint_name in sorted_names: 57 | checkpoint_name = checkpoint_name[1] 58 | fd.write("all_model_checkpoint_paths: \"%s\"\n" % checkpoint_name) 59 | 60 | 61 | def _read_score_record(filename): 62 | # "checkpoint_name": score 63 | records = [] 64 | 65 | if not tf.gfile.Exists(filename): 66 | return records 67 | 68 | with tf.gfile.GFile(filename) as fd: 69 | for line in fd: 70 | name, score = line.strip().split(":") 71 | name = name.strip()[1:-1] 72 | score = float(score) 73 | records.append([name, score]) 74 | 75 | return records 76 | 77 | 78 | def _save_score_record(filename, records): 79 | keys = [] 80 | 81 | for record in records: 82 | checkpoint_name = record[0] 83 | step = int(checkpoint_name.strip().split("-")[-1]) 84 | keys.append((step, record)) 85 | 86 | sorted_keys = sorted(keys, key=operator.itemgetter(0), 87 | reverse=True) 88 | sorted_records = [item[1] for item in sorted_keys] 89 | 90 | with tf.gfile.GFile(filename, "w") as fd: 91 | for record in sorted_records: 92 | checkpoint_name, score = record 93 | fd.write("\"%s\": %f\n" % (checkpoint_name, score)) 94 | 95 | 96 | def _add_to_record(records, record, max_to_keep): 97 | added = None 98 | removed = None 99 | models = {} 100 | 101 | for (name, score) in records: 102 | models[name] = score 103 | 104 | if len(records) < max_to_keep: 105 | if record[0] not in models: 106 | added = record[0] 107 | records.append(record) 108 | else: 109 | sorted_records = sorted(records, key=lambda x: -x[1]) 110 | worst_score = sorted_records[-1][1] 111 | current_score = record[1] 112 | 113 | if current_score >= worst_score: 114 | if record[0] not in models: 115 | added = record[0] 116 | removed = sorted_records[-1][0] 117 | records = sorted_records[:-1] + [record] 118 | 119 | # Sort 120 | records = sorted(records, key=lambda x: -x[1]) 121 | 122 | return added, removed, records 123 | 124 | 125 | def _evaluate(eval_fn, input_fn, decode_fn, path, config): 126 | graph = tf.Graph() 127 | with graph.as_default(): 128 | features = input_fn() 129 | refs = features["references"] 130 | predictions = eval_fn(features) 131 | results = { 132 | "predictions": predictions, 133 | "references": refs 134 | } 135 | 136 | all_refs = [[] for _ in range(len(refs))] 137 | all_outputs = [] 138 | 139 | sess_creator = tf.train.ChiefSessionCreator( 140 | checkpoint_dir=path, 141 | config=config 142 | ) 143 | 144 | with tf.train.MonitoredSession(session_creator=sess_creator) as sess: 145 | while not sess.should_stop(): 146 | outputs = sess.run(results) 147 | # shape: [batch, len] 148 | predictions = outputs["predictions"].tolist() 149 | # shape: ([batch, len], ..., [batch, len]) 150 | references = [item.tolist() for item 
in outputs["references"]] 151 | 152 | all_outputs.extend(predictions) 153 | 154 | for i in range(len(refs)): 155 | all_refs[i].extend(references[i]) 156 | 157 | decoded_symbols = decode_fn(all_outputs) 158 | decoded_refs = [decode_fn(refs) for refs in all_refs] 159 | decoded_refs = [list(x) for x in zip(*decoded_refs)] 160 | 161 | return bleu.bleu(decoded_symbols, decoded_refs) 162 | 163 | 164 | class EvaluationHook(tf.train.SessionRunHook): 165 | """ Validate and save checkpoints every N steps or seconds. 166 | This hook only saves checkpoint according to a specific metric. 167 | """ 168 | 169 | def __init__(self, eval_fn, eval_input_fn, eval_decode_fn, base_dir, 170 | session_config, max_to_keep=5, eval_secs=None, 171 | eval_steps=None, metric="BLEU"): 172 | """ Initializes a `EvaluationHook`. 173 | Args: 174 | eval_fn: A function with signature (feature) 175 | eval_input_fn: A function with signature () 176 | eval_decode_fn: A function with signature (inputs) 177 | base_dir: A string. Base directory for the checkpoint files. 178 | session_config: An instance of tf.ConfigProto 179 | max_to_keep: An integer. The maximum of checkpoints to save 180 | eval_secs: An integer, eval every N secs. 181 | eval_steps: An integer, eval every N steps. 182 | checkpoint_basename: `str`, base name for the checkpoint files. 183 | ValueError: One of `save_steps` or `save_secs` should be set. 184 | ValueError: At most one of saver or scaffold should be set. 185 | """ 186 | tf.logging.info("Create EvaluationHook.") 187 | 188 | if metric != "BLEU": 189 | raise ValueError("Currently, EvaluationHook only support BLEU") 190 | 191 | self._base_dir = base_dir.rstrip("/") 192 | self._session_config = session_config 193 | self._save_path = os.path.join(base_dir, "eval") 194 | self._record_name = os.path.join(self._save_path, "record") 195 | self._eval_fn = eval_fn 196 | self._eval_input_fn = eval_input_fn 197 | self._eval_decode_fn = eval_decode_fn 198 | self._max_to_keep = max_to_keep 199 | self._metric = metric 200 | self._global_step = None 201 | self._timer = tf.train.SecondOrStepTimer( 202 | every_secs=eval_secs or None, every_steps=eval_steps or None 203 | ) 204 | 205 | def begin(self): 206 | if self._timer.last_triggered_step() is None: 207 | self._timer.update_last_triggered_step(0) 208 | 209 | global_step = tf.train.get_global_step() 210 | 211 | if not tf.gfile.Exists(self._save_path): 212 | tf.logging.info("Making dir: %s" % self._save_path) 213 | tf.gfile.MakeDirs(self._save_path) 214 | 215 | params_pattern = os.path.join(self._base_dir, "*.json") 216 | params_files = tf.gfile.Glob(params_pattern) 217 | 218 | for name in params_files: 219 | new_name = name.replace(self._base_dir, self._save_path) 220 | tf.gfile.Copy(name, new_name, overwrite=True) 221 | 222 | if global_step is None: 223 | raise RuntimeError("Global step should be created first") 224 | 225 | self._global_step = global_step 226 | 227 | def before_run(self, run_context): 228 | args = tf.train.SessionRunArgs(self._global_step) 229 | return args 230 | 231 | def after_run(self, run_context, run_values): 232 | stale_global_step = run_values.results 233 | 234 | if self._timer.should_trigger_for_step(stale_global_step + 1): 235 | global_step = run_context.session.run(self._global_step) 236 | 237 | # Get the real value 238 | if self._timer.should_trigger_for_step(global_step): 239 | self._timer.update_last_triggered_step(global_step) 240 | # Save model 241 | save_path = os.path.join(self._base_dir, "model.ckpt") 242 | saver = _get_saver() 243 | 
tf.logging.info("Saving checkpoints for %d into %s." % 244 | (global_step, save_path)) 245 | saver.save(run_context.session, 246 | save_path, 247 | global_step=global_step) 248 | # Do validation here 249 | tf.logging.info("Validating model at step %d" % global_step) 250 | score = _evaluate(self._eval_fn, self._eval_input_fn, 251 | self._eval_decode_fn, 252 | self._base_dir, 253 | self._session_config) 254 | tf.logging.info("%s at step %d: %f" % 255 | (self._metric, global_step, score)) 256 | 257 | checkpoint_filename = os.path.join(self._base_dir, 258 | "checkpoint") 259 | all_checkpoints = _read_checkpoint_def(checkpoint_filename) 260 | records = _read_score_record(self._record_name) 261 | latest_checkpoint = all_checkpoints[-1] 262 | record = [latest_checkpoint, score] 263 | added, removed, records = _add_to_record(records, record, 264 | self._max_to_keep) 265 | 266 | if added is not None: 267 | old_path = os.path.join(self._base_dir, added) 268 | new_path = os.path.join(self._save_path, added) 269 | old_files = tf.gfile.Glob(old_path + "*") 270 | tf.logging.info("Copying %s to %s" % (old_path, new_path)) 271 | 272 | for o_file in old_files: 273 | n_file = o_file.replace(old_path, new_path) 274 | tf.gfile.Copy(o_file, n_file, overwrite=True) 275 | 276 | if removed is not None: 277 | filename = os.path.join(self._save_path, removed) 278 | tf.logging.info("Removing %s" % filename) 279 | files = tf.gfile.Glob(filename + "*") 280 | 281 | for name in files: 282 | tf.gfile.Remove(name) 283 | 284 | _save_score_record(self._record_name, records) 285 | checkpoint_filename = checkpoint_filename.replace( 286 | self._base_dir, self._save_path 287 | ) 288 | _save_checkpoint_def(checkpoint_filename, 289 | [item[0] for item in records]) 290 | 291 | best_score = records[0][1] 292 | tf.logging.info("Best score at step %d: %f" % 293 | (global_step, best_score)) 294 | 295 | def end(self, session): 296 | last_step = session.run(self._global_step) 297 | 298 | if last_step != self._timer.last_triggered_step(): 299 | global_step = last_step 300 | tf.logging.info("Validating model at step %d" % global_step) 301 | score = _evaluate(self._eval_fn, self._eval_input_fn, 302 | self._eval_decode_fn, 303 | self._base_dir, 304 | self._session_config) 305 | tf.logging.info("%s at step %d: %f" % 306 | (self._metric, global_step, score)) 307 | 308 | checkpoint_filename = os.path.join(self._base_dir, 309 | "checkpoint") 310 | all_checkpoints = _read_checkpoint_def(checkpoint_filename) 311 | records = _read_score_record(self._record_name) 312 | latest_checkpoint = all_checkpoints[-1] 313 | record = [latest_checkpoint, score] 314 | added, removed, records = _add_to_record(records, record, 315 | self._max_to_keep) 316 | 317 | if added is not None: 318 | old_path = os.path.join(self._base_dir, added) 319 | new_path = os.path.join(self._save_path, added) 320 | old_files = tf.gfile.Glob(old_path + "*") 321 | tf.logging.info("Copying %s to %s" % (old_path, new_path)) 322 | 323 | for o_file in old_files: 324 | n_file = o_file.replace(old_path, new_path) 325 | tf.gfile.Copy(o_file, n_file, overwrite=True) 326 | 327 | if removed is not None: 328 | filename = os.path.join(self._save_path, removed) 329 | tf.logging.info("Removing %s" % filename) 330 | files = tf.gfile.Glob(filename + "*") 331 | 332 | for name in files: 333 | tf.gfile.Remove(name) 334 | 335 | _save_score_record(self._record_name, records) 336 | checkpoint_filename = checkpoint_filename.replace( 337 | self._base_dir, self._save_path 338 | ) 339 | 
_save_checkpoint_def(checkpoint_filename, 340 | [item[0] for item in records]) 341 | 342 | best_score = records[0][1] 343 | tf.logging.info("Best score: %f" % best_score) 344 | -------------------------------------------------------------------------------- /xmunmt/bin/trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2018 Natural Language Processing Lab of Xiamen University 4 | # Author: Zhixing Tan 5 | # Contact: playinf@stu.xmu.edu.cn 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import argparse 12 | import os 13 | 14 | import numpy as np 15 | import tensorflow as tf 16 | import xmunmt.data.dataset as dataset 17 | import xmunmt.data.record as record 18 | import xmunmt.data.vocab as vocabulary 19 | import xmunmt.models as models 20 | import xmunmt.utils.hooks as hooks 21 | import xmunmt.utils.parallel as parallel 22 | import xmunmt.utils.search as search 23 | 24 | 25 | def parse_args(args=None): 26 | parser = argparse.ArgumentParser( 27 | description="Training neural machine translation models", 28 | usage="trainer.py [] [-h | --help]" 29 | ) 30 | 31 | # input files 32 | parser.add_argument("--input", type=str, nargs=2, 33 | help="Path of source and target corpus") 34 | parser.add_argument("--record", type=str, 35 | help="Path to tf.Record data") 36 | parser.add_argument("--output", type=str, default="train", 37 | help="Path to saved models") 38 | parser.add_argument("--vocabulary", type=str, nargs=2, 39 | help="Path of source and target vocabulary") 40 | parser.add_argument("--validation", type=str, 41 | help="Path of validation file") 42 | parser.add_argument("--references", type=str, nargs="+", 43 | help="Path of reference files") 44 | 45 | # model and configuration 46 | parser.add_argument("--model", type=str, required=True, 47 | help="Name of the model") 48 | parser.add_argument("--parameters", type=str, default="", 49 | help="Additional hyper parameters") 50 | 51 | return parser.parse_args(args) 52 | 53 | 54 | def default_parameters(): 55 | params = tf.contrib.training.HParams( 56 | input=["", ""], 57 | output="", 58 | record="", 59 | model="rnnsearch", 60 | vocab=["", ""], 61 | # Default training hyper parameters 62 | num_threads=6, 63 | batch_size=128, 64 | max_length=60, 65 | length_multiplier=1, 66 | mantissa_bits=2, 67 | warmup_steps=4000, 68 | train_steps=100000, 69 | buffer_size=10000, 70 | constant_batch_size=True, 71 | device_list=[0], 72 | initializer="uniform", 73 | initializer_gain=0.08, 74 | adam_beta1=0.9, 75 | adam_beta2=0.999, 76 | adam_epsilon=1e-8, 77 | clip_grad_norm=5.0, 78 | learning_rate=1.0, 79 | learning_rate_decay="noam", 80 | learning_rate_boundaries=[0], 81 | learning_rate_values=[0.0], 82 | keep_checkpoint_max=20, 83 | keep_top_checkpoint_max=5, 84 | # Validation 85 | eval_steps=2000, 86 | eval_secs=0, 87 | eval_batch_size=32, 88 | top_beams=1, 89 | beam_size=4, 90 | decode_alpha=0.6, 91 | decode_length=50, 92 | decode_constant=5.0, 93 | decode_normalize=False, 94 | validation="", 95 | references=[""], 96 | save_checkpoint_secs=0, 97 | save_checkpoint_steps=1000 98 | ) 99 | 100 | return params 101 | 102 | 103 | def import_params(model_dir, model_name, params): 104 | model_dir = os.path.abspath(model_dir) 105 | p_name = os.path.join(model_dir, "params.json") 106 | m_name = os.path.join(model_dir, model_name + ".json") 107 | 108 | if not tf.gfile.Exists(p_name) or 
not tf.gfile.Exists(m_name): 109 | return params 110 | 111 | with tf.gfile.Open(p_name) as fd: 112 | tf.logging.info("Restoring hyper parameters from %s" % p_name) 113 | json_str = fd.readline() 114 | params.parse_json(json_str) 115 | 116 | with tf.gfile.Open(m_name) as fd: 117 | tf.logging.info("Restoring model parameters from %s" % m_name) 118 | json_str = fd.readline() 119 | params.parse_json(json_str) 120 | 121 | return params 122 | 123 | 124 | def export_params(output_dir, name, params): 125 | if not tf.gfile.Exists(output_dir): 126 | tf.gfile.MkDir(output_dir) 127 | 128 | # Save params as params.json 129 | filename = os.path.join(output_dir, name) 130 | with tf.gfile.Open(filename, "w") as fd: 131 | fd.write(params.to_json()) 132 | 133 | 134 | def collect_params(all_params, params): 135 | collected = tf.contrib.training.HParams() 136 | 137 | for k in params.values().iterkeys(): 138 | collected.add_hparam(k, getattr(all_params, k)) 139 | 140 | return collected 141 | 142 | 143 | def merge_parameters(params1, params2): 144 | params = tf.contrib.training.HParams() 145 | 146 | for (k, v) in params1.values().iteritems(): 147 | params.add_hparam(k, v) 148 | 149 | params_dict = params.values() 150 | 151 | for (k, v) in params2.values().iteritems(): 152 | if k in params_dict: 153 | # Override 154 | setattr(params, k, v) 155 | else: 156 | params.add_hparam(k, v) 157 | 158 | return params 159 | 160 | 161 | def override_parameters(params, args): 162 | params.model = args.model 163 | params.input = args.input or params.input 164 | params.output = args.output or params.output 165 | params.record = args.record or params.record 166 | params.vocab = args.vocabulary or params.vocab 167 | params.validation = args.validation or params.validation 168 | params.references = args.references or params.references 169 | params.parse(args.parameters) 170 | 171 | params.vocabulary = { 172 | "source": vocabulary.load_vocabulary(params.vocab[0]), 173 | "target": vocabulary.load_vocabulary(params.vocab[1]) 174 | } 175 | params.vocabulary["source"] = vocabulary.process_vocabulary( 176 | params.vocabulary["source"], params 177 | ) 178 | params.vocabulary["target"] = vocabulary.process_vocabulary( 179 | params.vocabulary["target"], params 180 | ) 181 | 182 | control_symbols = [params.pad, params.bos, params.eos, params.unk] 183 | 184 | params.mapping = { 185 | "source": vocabulary.get_control_mapping( 186 | params.vocabulary["source"], 187 | control_symbols 188 | ), 189 | "target": vocabulary.get_control_mapping( 190 | params.vocabulary["target"], 191 | control_symbols 192 | ) 193 | } 194 | 195 | return params 196 | 197 | 198 | def get_initializer(params): 199 | if params.initializer == "uniform": 200 | max_val = params.initializer_gain 201 | return tf.random_uniform_initializer(-max_val, max_val) 202 | elif params.initializer == "normal": 203 | return tf.random_normal_initializer(0.0, params.initializer_gain) 204 | elif params.initializer == "normal_unit_scaling": 205 | return tf.variance_scaling_initializer(params.initializer_gain, 206 | mode="fan_avg", 207 | distribution="normal") 208 | elif params.initializer == "uniform_unit_scaling": 209 | return tf.variance_scaling_initializer(params.initializer_gain, 210 | mode="fan_avg", 211 | distribution="uniform") 212 | else: 213 | raise ValueError("Unrecognized initializer: %s" % params.initializer) 214 | 215 | 216 | def get_learning_rate_decay(learning_rate, global_step, params): 217 | if params.learning_rate_decay == "noam": 218 | step = tf.to_float(global_step) 219 | 
warmup_steps = tf.to_float(params.warmup_steps) 220 | multiplier = params.hidden_size ** -0.5 221 | decay = multiplier * tf.minimum((step + 1) * (warmup_steps ** -1.5), 222 | (step + 1) ** -0.5) 223 | 224 | return learning_rate * decay 225 | elif params.learning_rate_decay == "piecewise_constant": 226 | return tf.train.piecewise_constant(tf.to_int32(global_step), 227 | params.learning_rate_boundaries, 228 | params.learning_rate_values) 229 | elif params.learning_rate_decay == "none": 230 | return learning_rate 231 | else: 232 | raise ValueError("Unknown learning_rate_decay") 233 | 234 | 235 | def session_config(params): 236 | optimizer_options = tf.OptimizerOptions(opt_level=tf.OptimizerOptions.L1, 237 | do_function_inlining=True) 238 | graph_options = tf.GraphOptions(optimizer_options=optimizer_options) 239 | config = tf.ConfigProto(allow_soft_placement=True, 240 | graph_options=graph_options) 241 | if params.device_list: 242 | device_str = ",".join([str(i) for i in params.device_list]) 243 | config.gpu_options.visible_device_list = device_str 244 | 245 | return config 246 | 247 | 248 | def decode_target_ids(inputs, params): 249 | decoded = [] 250 | vocab = params.vocabulary["target"] 251 | 252 | for item in inputs: 253 | syms = [] 254 | for idx in item: 255 | sym = vocab[idx] 256 | 257 | if sym == params.eos: 258 | break 259 | 260 | if sym == params.pad: 261 | break 262 | 263 | syms.append(sym) 264 | decoded.append(syms) 265 | 266 | return decoded 267 | 268 | 269 | def main(args): 270 | tf.logging.set_verbosity(tf.logging.INFO) 271 | model_cls = models.get_model(args.model) 272 | params = default_parameters() 273 | 274 | # Import and override parameters 275 | # Priorities (low -> high): 276 | # default -> saved -> command 277 | params = merge_parameters(params, model_cls.get_parameters()) 278 | params = import_params(args.output, args.model, params) 279 | override_parameters(params, args) 280 | 281 | # Export all parameters and model specific parameters 282 | export_params(params.output, "params.json", params) 283 | export_params( 284 | params.output, 285 | "%s.json" % args.model, 286 | collect_params(params, model_cls.get_parameters()) 287 | ) 288 | 289 | # Build Graph 290 | with tf.Graph().as_default(): 291 | if not params.record: 292 | # Build input queue 293 | features = dataset.get_training_input(params.input, params) 294 | else: 295 | features = record.get_input_features( 296 | os.path.join(params.record, "*train*"), "train", params 297 | ) 298 | 299 | # Build model 300 | initializer = get_initializer(params) 301 | model = model_cls(params) 302 | 303 | # Multi-GPU setting 304 | sharded_losses = parallel.parallel_model( 305 | model.get_training_func(initializer), 306 | features, 307 | params.device_list 308 | ) 309 | loss = tf.add_n(sharded_losses) / len(sharded_losses) 310 | 311 | # Create global step 312 | global_step = tf.train.get_or_create_global_step() 313 | 314 | # Print parameters 315 | all_weights = {v.name: v for v in tf.trainable_variables()} 316 | total_size = 0 317 | 318 | for v_name in sorted(list(all_weights)): 319 | v = all_weights[v_name] 320 | tf.logging.info("%s\tshape %s", v.name[:-2].ljust(80), 321 | str(v.shape).ljust(20)) 322 | v_size = np.prod(np.array(v.shape.as_list())).tolist() 323 | total_size += v_size 324 | tf.logging.info("Total trainable variables size: %d", total_size) 325 | 326 | learning_rate = get_learning_rate_decay(params.learning_rate, 327 | global_step, params) 328 | tf.summary.scalar("learning_rate", learning_rate) 329 | 330 | # Create 
optimizer 331 | opt = tf.train.AdamOptimizer(learning_rate, 332 | beta1=params.adam_beta1, 333 | beta2=params.adam_beta2, 334 | epsilon=params.adam_epsilon) 335 | 336 | train_op = tf.contrib.layers.optimize_loss( 337 | name="training", 338 | loss=loss, 339 | global_step=global_step, 340 | learning_rate=learning_rate, 341 | clip_gradients=params.clip_grad_norm or None, 342 | optimizer=opt, 343 | colocate_gradients_with_ops=True 344 | ) 345 | 346 | # Validation 347 | if params.validation and params.references[0]: 348 | files = [params.validation] + list(params.references) 349 | eval_inputs = dataset.sort_and_zip_files(files) 350 | eval_input_fn = dataset.get_evaluation_input 351 | else: 352 | eval_input_fn = None 353 | 354 | # Add hooks 355 | train_hooks = [ 356 | tf.train.StopAtStepHook(last_step=params.train_steps), 357 | tf.train.NanTensorHook(loss), 358 | tf.train.LoggingTensorHook( 359 | { 360 | "step": global_step, 361 | "loss": loss, 362 | "source": tf.shape(features["source"]), 363 | "target": tf.shape(features["target"]) 364 | }, 365 | every_n_iter=1 366 | ), 367 | tf.train.CheckpointSaverHook( 368 | checkpoint_dir=params.output, 369 | save_secs=params.save_checkpoint_secs or None, 370 | save_steps=params.save_checkpoint_steps or None, 371 | saver=tf.train.Saver( 372 | max_to_keep=params.keep_checkpoint_max, 373 | sharded=False 374 | ) 375 | ) 376 | ] 377 | 378 | config = session_config(params) 379 | 380 | if eval_input_fn is not None: 381 | train_hooks.append( 382 | hooks.EvaluationHook( 383 | lambda f: search.create_inference_graph( 384 | model.get_evaluation_func(), f, params 385 | ), 386 | lambda: eval_input_fn(eval_inputs, params), 387 | lambda x: decode_target_ids(x, params), 388 | params.output, 389 | config, 390 | params.keep_top_checkpoint_max, 391 | eval_secs=params.eval_secs, 392 | eval_steps=params.eval_steps 393 | ) 394 | ) 395 | 396 | # Create session, do not use default CheckpointSaverHook 397 | with tf.train.MonitoredTrainingSession( 398 | checkpoint_dir=params.output, hooks=train_hooks, 399 | save_checkpoint_secs=None, config=config) as sess: 400 | while not sess.should_stop(): 401 | sess.run(train_op) 402 | 403 | 404 | if __name__ == "__main__": 405 | main(parse_args()) 406 | -------------------------------------------------------------------------------- /xmunmt/models/rnnsearch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Natural Language Processing Lab of Xiamen University 3 | # Author: Zhixing Tan 4 | # Contact: playinf@stu.xmu.edu.cn 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import copy 11 | 12 | import tensorflow as tf 13 | import xmunmt.interface as interface 14 | import xmunmt.layers as layers 15 | 16 | 17 | def _copy_through(time, length, output, new_output): 18 | copy_cond = (time >= length) 19 | return tf.where(copy_cond, output, new_output) 20 | 21 | 22 | def _gru_encoder(cell, inputs, sequence_length, initial_state, dtype=None): 23 | # Assume that the underlying cell is GRUCell-like 24 | output_size = cell.output_size 25 | dtype = dtype or inputs.dtype 26 | 27 | batch = tf.shape(inputs)[0] 28 | time_steps = tf.shape(inputs)[1] 29 | 30 | zero_output = tf.zeros([batch, output_size], dtype) 31 | 32 | if initial_state is None: 33 | initial_state = cell.zero_state(batch, dtype) 34 | 35 | input_ta = tf.TensorArray(dtype, time_steps, 36 | tensor_array_name="input_array") 37 | output_ta = 
tf.TensorArray(dtype, time_steps, 38 | tensor_array_name="output_array") 39 | input_ta = input_ta.unstack(tf.transpose(inputs, [1, 0, 2])) 40 | 41 | def loop_func(t, out_ta, state): 42 | inp_t = input_ta.read(t) 43 | cell_output, new_state = cell(inp_t, state) 44 | cell_output = _copy_through(t, sequence_length, zero_output, 45 | cell_output) 46 | new_state = _copy_through(t, sequence_length, state, new_state) 47 | out_ta = out_ta.write(t, cell_output) 48 | return t + 1, out_ta, new_state 49 | 50 | time = tf.constant(0, dtype=tf.int32, name="time") 51 | loop_vars = (time, output_ta, initial_state) 52 | 53 | outputs = tf.while_loop(lambda t, *_: t < time_steps, loop_func, 54 | loop_vars, parallel_iterations=32, 55 | swap_memory=True) 56 | 57 | output_final_ta = outputs[1] 58 | final_state = outputs[2] 59 | 60 | all_output = output_final_ta.stack() 61 | all_output.set_shape([None, None, output_size]) 62 | all_output = tf.transpose(all_output, [1, 0, 2]) 63 | 64 | return all_output, final_state 65 | 66 | 67 | def _encoder(cell_fw, cell_bw, inputs, sequence_length, dtype=None, 68 | scope=None): 69 | with tf.variable_scope(scope or "encoder", 70 | values=[inputs, sequence_length]): 71 | inputs_fw = inputs 72 | inputs_bw = tf.reverse_sequence(inputs, sequence_length, 73 | batch_axis=0, seq_axis=1) 74 | 75 | with tf.variable_scope("forward"): 76 | output_fw, state_fw = _gru_encoder(cell_fw, inputs_fw, 77 | sequence_length, None, 78 | dtype=dtype) 79 | 80 | with tf.variable_scope("backward"): 81 | output_bw, state_bw = _gru_encoder(cell_bw, inputs_bw, 82 | sequence_length, None, 83 | dtype=dtype) 84 | output_bw = tf.reverse_sequence(output_bw, sequence_length, 85 | batch_axis=0, seq_axis=1) 86 | 87 | results = { 88 | "annotation": tf.concat([output_fw, output_bw], axis=2), 89 | "outputs": { 90 | "forward": output_fw, 91 | "backward": output_bw 92 | }, 93 | "final_states": { 94 | "forward": state_fw, 95 | "backward": state_bw 96 | } 97 | } 98 | 99 | return results 100 | 101 | 102 | def _decoder(cell, inputs, memory, sequence_length, initial_state, dtype=None, 103 | scope=None): 104 | # Assume that the underlying cell is GRUCell-like 105 | batch = tf.shape(inputs)[0] 106 | time_steps = tf.shape(inputs)[1] 107 | dtype = dtype or inputs.dtype 108 | output_size = cell.output_size 109 | zero_output = tf.zeros([batch, output_size], dtype) 110 | zero_value = tf.zeros([batch, memory.shape[-1].value], dtype) 111 | 112 | with tf.variable_scope(scope or "decoder", dtype=dtype): 113 | inputs = tf.transpose(inputs, [1, 0, 2]) 114 | mem_mask = tf.sequence_mask(sequence_length["source"], 115 | maxlen=tf.shape(memory)[1], 116 | dtype=tf.float32) 117 | bias = layers.attention.attention_bias(mem_mask) 118 | cache = layers.attention.attention(None, memory, None, output_size) 119 | 120 | input_ta = tf.TensorArray(tf.float32, time_steps, 121 | tensor_array_name="input_array") 122 | output_ta = tf.TensorArray(tf.float32, time_steps, 123 | tensor_array_name="output_array") 124 | value_ta = tf.TensorArray(tf.float32, time_steps, 125 | tensor_array_name="value_array") 126 | alpha_ta = tf.TensorArray(tf.float32, time_steps, 127 | tensor_array_name="alpha_array") 128 | input_ta = input_ta.unstack(inputs) 129 | initial_state = layers.nn.linear(initial_state, output_size, True, 130 | scope="s_transform") 131 | initial_state = tf.tanh(initial_state) 132 | 133 | def loop_func(t, out_ta, att_ta, val_ta, state, cache_key): 134 | inp_t = input_ta.read(t) 135 | results = layers.attention.attention(state, memory, bias, 136 | 
output_size, 137 | cache={"key": cache_key}) 138 | alpha = results["weight"] 139 | context = results["value"] 140 | cell_input = [inp_t, context] 141 | cell_output, new_state = cell(cell_input, state) 142 | cell_output = _copy_through(t, sequence_length["target"], 143 | zero_output, cell_output) 144 | new_state = _copy_through(t, sequence_length["target"], state, 145 | new_state) 146 | new_value = _copy_through(t, sequence_length["target"], zero_value, 147 | context) 148 | 149 | out_ta = out_ta.write(t, cell_output) 150 | att_ta = att_ta.write(t, alpha) 151 | val_ta = val_ta.write(t, new_value) 152 | cache_key = tf.identity(cache_key) 153 | return t + 1, out_ta, att_ta, val_ta, new_state, cache_key 154 | 155 | time = tf.constant(0, dtype=tf.int32, name="time") 156 | loop_vars = (time, output_ta, alpha_ta, value_ta, initial_state, 157 | cache["key"]) 158 | 159 | outputs = tf.while_loop(lambda t, *_: t < time_steps, 160 | loop_func, loop_vars, 161 | parallel_iterations=32, 162 | swap_memory=True) 163 | 164 | output_final_ta = outputs[1] 165 | value_final_ta = outputs[3] 166 | 167 | final_output = output_final_ta.stack() 168 | final_output.set_shape([None, None, output_size]) 169 | final_output = tf.transpose(final_output, [1, 0, 2]) 170 | 171 | final_value = value_final_ta.stack() 172 | final_value.set_shape([None, None, memory.shape[-1].value]) 173 | final_value = tf.transpose(final_value, [1, 0, 2]) 174 | 175 | result = { 176 | "outputs": final_output, 177 | "values": final_value, 178 | "initial_state": initial_state 179 | } 180 | 181 | return result 182 | 183 | 184 | def model_graph(features, labels, params): 185 | src_vocab_size = len(params.vocabulary["source"]) 186 | tgt_vocab_size = len(params.vocabulary["target"]) 187 | 188 | with tf.variable_scope("source_embedding"): 189 | src_emb = tf.get_variable("embedding", 190 | [src_vocab_size, params.embedding_size]) 191 | src_bias = tf.get_variable("bias", [params.embedding_size]) 192 | src_inputs = tf.nn.embedding_lookup(src_emb, features["source"]) 193 | 194 | with tf.variable_scope("target_embedding"): 195 | tgt_emb = tf.get_variable("embedding", 196 | [tgt_vocab_size, params.embedding_size]) 197 | tgt_bias = tf.get_variable("bias", [params.embedding_size]) 198 | tgt_inputs = tf.nn.embedding_lookup(tgt_emb, features["target"]) 199 | 200 | src_inputs = tf.nn.bias_add(src_inputs, src_bias) 201 | tgt_inputs = tf.nn.bias_add(tgt_inputs, tgt_bias) 202 | 203 | if params.dropout and not params.use_variational_dropout: 204 | src_inputs = tf.nn.dropout(src_inputs, 1.0 - params.dropout) 205 | tgt_inputs = tf.nn.dropout(tgt_inputs, 1.0 - params.dropout) 206 | 207 | # encoder 208 | cell_fw = layers.rnn_cell.LegacyGRUCell(params.hidden_size) 209 | cell_bw = layers.rnn_cell.LegacyGRUCell(params.hidden_size) 210 | 211 | if params.use_variational_dropout: 212 | cell_fw = tf.nn.rnn_cell.DropoutWrapper( 213 | cell_fw, 214 | input_keep_prob=1.0 - params.dropout, 215 | output_keep_prob=1.0 - params.dropout, 216 | state_keep_prob=1.0 - params.dropout, 217 | variational_recurrent=True, 218 | input_size=params.embedding_size, 219 | dtype=tf.float32 220 | ) 221 | cell_bw = tf.nn.rnn_cell.DropoutWrapper( 222 | cell_bw, 223 | input_keep_prob=1.0 - params.dropout, 224 | output_keep_prob=1.0 - params.dropout, 225 | state_keep_prob=1.0 - params.dropout, 226 | variational_recurrent=True, 227 | input_size=params.embedding_size, 228 | dtype=tf.float32 229 | ) 230 | 231 | encoder_output = _encoder(cell_fw, cell_bw, src_inputs, 232 | features["source_length"]) 233 | 234 
| # decoder 235 | cell = layers.rnn_cell.LegacyGRUCell(params.hidden_size) 236 | 237 | if params.use_variational_dropout: 238 | cell = tf.nn.rnn_cell.DropoutWrapper( 239 | cell, 240 | input_keep_prob=1.0 - params.dropout, 241 | output_keep_prob=1.0 - params.dropout, 242 | state_keep_prob=1.0 - params.dropout, 243 | variational_recurrent=True, 244 | # input + context 245 | input_size=params.embedding_size + 2 * params.hidden_size, 246 | dtype=tf.float32 247 | ) 248 | 249 | length = { 250 | "source": features["source_length"], 251 | "target": features["target_length"] 252 | } 253 | initial_state = encoder_output["final_states"]["backward"] 254 | decoder_output = _decoder(cell, tgt_inputs, encoder_output["annotation"], 255 | length, initial_state) 256 | 257 | # Shift left 258 | shifted_tgt_inputs = tf.pad(tgt_inputs, [[0, 0], [1, 0], [0, 0]]) 259 | shifted_tgt_inputs = shifted_tgt_inputs[:, :-1, :] 260 | 261 | all_outputs = tf.concat( 262 | [ 263 | tf.expand_dims(decoder_output["initial_state"], axis=1), 264 | decoder_output["outputs"], 265 | ], 266 | axis=1 267 | ) 268 | shifted_outputs = all_outputs[:, :-1, :] 269 | 270 | maxout_features = [ 271 | shifted_tgt_inputs, 272 | shifted_outputs, 273 | decoder_output["values"] 274 | ] 275 | maxout_size = params.hidden_size // params.maxnum 276 | 277 | if labels is None: 278 | # Special case for non-incremental decoding 279 | maxout_features = [ 280 | shifted_tgt_inputs[:, -1, :], 281 | shifted_outputs[:, -1, :], 282 | decoder_output["values"][:, -1, :] 283 | ] 284 | maxhid = layers.nn.maxout(maxout_features, maxout_size, params.maxnum, 285 | concat=False) 286 | readout = layers.nn.linear(maxhid, params.embedding_size, False, 287 | scope="deepout") 288 | 289 | # Prediction 290 | logits = layers.nn.linear(readout, tgt_vocab_size, True, 291 | scope="softmax") 292 | 293 | return logits 294 | 295 | maxhid = layers.nn.maxout(maxout_features, maxout_size, params.maxnum, 296 | concat=False) 297 | readout = layers.nn.linear(maxhid, params.embedding_size, False, 298 | scope="deepout") 299 | 300 | if params.dropout and not params.use_variational_dropout: 301 | readout = tf.nn.dropout(readout, 1.0 - params.dropout) 302 | 303 | # Prediction 304 | logits = layers.nn.linear(readout, tgt_vocab_size, True, scope="softmax") 305 | logits = tf.reshape(logits, [-1, tgt_vocab_size]) 306 | 307 | ce = layers.nn.smoothed_softmax_cross_entropy_with_logits( 308 | logits=logits, 309 | labels=labels, 310 | smoothing=params.label_smoothing, 311 | normalize=True 312 | ) 313 | 314 | ce = tf.reshape(ce, tf.shape(labels)) 315 | tgt_mask = tf.to_float( 316 | tf.sequence_mask( 317 | features["target_length"], 318 | maxlen=tf.shape(features["target"])[1] 319 | ) 320 | ) 321 | loss = tf.reduce_sum(ce * tgt_mask) / tf.reduce_sum(tgt_mask) 322 | 323 | return loss 324 | 325 | 326 | class RNNsearch(interface.NMTModel): 327 | """ 328 | Reference: 329 | Neural Machine Translation by Jointly Learning to Align and Translate 330 | """ 331 | 332 | def __init__(self, params, scope="rnnsearch"): 333 | super(RNNsearch, self).__init__(params=params, scope=scope) 334 | 335 | def get_training_func(self, initializer): 336 | def training_fn(features, params=None): 337 | if params is None: 338 | params = self.parameters 339 | with tf.variable_scope(self._scope, initializer=initializer): 340 | loss = model_graph(features, features["target"], params) 341 | return loss 342 | 343 | return training_fn 344 | 345 | def get_evaluation_func(self): 346 | def evaluation_fn(features, params=None): 347 | if params 
is None: 348 | params = copy.copy(self.parameters) 349 | else: 350 | params = copy.copy(params) 351 | params.dropout = 0.0 352 | params.use_variational_dropout = False 353 | params.label_smoothing = 0.0 354 | 355 | with tf.variable_scope(self._scope): 356 | logits = model_graph(features, None, params) 357 | 358 | return logits 359 | 360 | return evaluation_fn 361 | 362 | def get_inference_func(self): 363 | def inference_fn(features, params=None): 364 | if params is None: 365 | params = copy.copy(self.parameters) 366 | else: 367 | params = copy.copy(params) 368 | params.dropout = 0.0 369 | params.use_variational_dropout = False 370 | params.label_smoothing = 0.0 371 | 372 | with tf.variable_scope(self._scope): 373 | logits = model_graph(features, None, params) 374 | 375 | return logits 376 | 377 | return inference_fn 378 | 379 | @staticmethod 380 | def get_name(): 381 | return "rnnsearch" 382 | 383 | @staticmethod 384 | def get_parameters(): 385 | params = tf.contrib.training.HParams( 386 | # vocabulary 387 | pad="", 388 | unk="UNK", 389 | eos="", 390 | bos="", 391 | append_eos=False, 392 | # model 393 | rnn_cell="LegacyGRUCell", 394 | embedding_size=620, 395 | hidden_size=1000, 396 | maxnum=2, 397 | # regularization 398 | dropout=0.2, 399 | use_variational_dropout=False, 400 | label_smoothing=0.1, 401 | constant_batch_size=True, 402 | batch_size=128, 403 | max_length=60, 404 | clip_grad_norm=5.0 405 | ) 406 | 407 | return params 408 | -------------------------------------------------------------------------------- /xmunmt/utils/search.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Natural Language Processing Lab of Xiamen University 3 | # Author: Zhixing Tan 4 | # Contact: playinf@stu.xmu.edu.cn 5 | # Disclaimer: Part of this code is modified from the Tensor2Tensor library 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import tensorflow as tf 12 | 13 | # Default value for INF 14 | INF = 1. * 1e7 15 | 16 | 17 | def log_prob_from_logits(logits): 18 | return logits - tf.reduce_logsumexp(logits, axis=2, keep_dims=True) 19 | 20 | 21 | def compute_batch_indices(batch_size, beam_size): 22 | """Computes the i'th coordinate that contains the batch index for gathers. 23 | 24 | Batch pos is a tensor like [[0,0,0,0,],[1,1,1,1],..]. It says which 25 | batch the beam item is in. This will create the i of the i,j coordinate 26 | needed for the gather. 27 | 28 | Args: 29 | batch_size: Batch size 30 | beam_size: Size of the beam. 31 | Returns: 32 | batch_pos: [batch_size, beam_size] tensor of ids 33 | """ 34 | batch_pos = tf.range(batch_size * beam_size) // beam_size 35 | batch_pos = tf.reshape(batch_pos, [batch_size, beam_size]) 36 | return batch_pos 37 | 38 | 39 | def compute_topk_scores_and_seq(sequences, scores, scores_to_gather, flags, 40 | beam_size, batch_size): 41 | """Given sequences and scores, will gather the top k=beam size sequences. 42 | 43 | This function is used to grow alive, and finished. It takes sequences, 44 | scores, and flags, and returns the top k from sequences, scores_to_gather, 45 | and flags based on the values in scores. 46 | 47 | Args: 48 | sequences: Tensor of sequences that we need to gather from. 49 | [batch_size, beam_size, seq_length] 50 | scores: Tensor of scores for each sequence in sequences. 51 | [batch_size, beam_size]. We will use these to compute the topk. 
52 | scores_to_gather: Tensor of scores for each sequence in sequences. 53 | [batch_size, beam_size]. We will return the gathered scores from 54 | here. Scores to gather is different from scores because for 55 | grow_alive, we will need to return log_probs, while for 56 | grow_finished, we will need to return the length penalized scors. 57 | flags: Tensor of bools for sequences that say whether a sequence has 58 | reached EOS or not 59 | beam_size: int 60 | batch_size: int 61 | Returns: 62 | Tuple of 63 | (topk_seq [batch_size, beam_size, decode_length], 64 | topk_gathered_scores [batch_size, beam_size], 65 | topk_finished_flags[batch_size, beam_size]) 66 | """ 67 | _, topk_indexes = tf.nn.top_k(scores, k=beam_size) 68 | # The next three steps are to create coordinates for tf.gather_nd to pull 69 | # out the top-k sequences from sequences based on scores. 70 | # batch pos is a tensor like [[0,0,0,0,],[1,1,1,1],..]. It says which 71 | # batch the beam item is in. This will create the i of the i,j coordinate 72 | # needed for the gather 73 | batch_pos = compute_batch_indices(batch_size, beam_size) 74 | 75 | # top coordinates will give us the actual coordinates to do the gather. 76 | # stacking will create a tensor of dimension batch * beam * 2, where the 77 | # last dimension contains the i,j gathering coordinates. 78 | top_coordinates = tf.stack([batch_pos, topk_indexes], axis=2) 79 | 80 | # Gather up the highest scoring sequences 81 | topk_seq = tf.gather_nd(sequences, top_coordinates) 82 | topk_flags = tf.gather_nd(flags, top_coordinates) 83 | topk_gathered_scores = tf.gather_nd(scores_to_gather, top_coordinates) 84 | return topk_seq, topk_gathered_scores, topk_flags 85 | 86 | 87 | def beam_search(symbols_to_logits_fn, initial_ids, beam_size, decode_length, 88 | vocab_size, alpha, eos_id, lp_constant=5.0): 89 | """Beam search with length penalties. 90 | 91 | Uses an interface specific to the sequence cnn models; 92 | Requires a function that can take the currently decoded symbols and return 93 | the logits for the next symbol. The implementation is inspired by 94 | https://arxiv.org/abs/1609.08144. 95 | 96 | Args: 97 | symbols_to_logits_fn: Interface to the model, to provide logits. 98 | Should take [batch_size, decoded_ids] and return [ 99 | batch_size, vocab_size] 100 | initial_ids: Ids to start off the decoding, this will be the first 101 | thing handed to symbols_to_logits_fn 102 | (after expanding to beam size) [batch_size] 103 | beam_size: Size of the beam. 104 | decode_length: Number of steps to decode for. 105 | vocab_size: Size of the vocab, must equal the size of the logits 106 | returned by symbols_to_logits_fn 107 | alpha: alpha for length penalty. 108 | eos_id: ID for end of sentence. 109 | lp_constant: A float number for length penalty 110 | Returns: 111 | Tuple of 112 | (decoded beams [batch_size, beam_size, decode_length] 113 | decoding probabilities [batch_size, beam_size]) 114 | """ 115 | batch_size = tf.shape(initial_ids)[0] 116 | 117 | # Assume initial_ids are prob 1.0 118 | initial_log_probs = tf.constant([[0.] 
+ [-float("inf")] * (beam_size - 1)]) 119 | # Expand to beam_size (batch_size, beam_size) 120 | alive_log_probs = tf.tile(initial_log_probs, [batch_size, 1]) 121 | 122 | # Expand each batch to beam_size 123 | alive_seq = tf.tile(tf.expand_dims(initial_ids, 1), [1, beam_size]) 124 | alive_seq = tf.expand_dims(alive_seq, 2) # (batch_size, beam_size, 1) 125 | 126 | # Finished will keep track of all the sequences that have finished so far 127 | # Finished log probs will be negative infinity in the beginning 128 | # finished_flags will keep track of booleans 129 | finished_seq = tf.zeros(tf.shape(alive_seq), tf.int32) 130 | # Setting the scores of the initial to negative infinity. 131 | finished_scores = tf.ones([batch_size, beam_size]) * -INF 132 | finished_flags = tf.zeros([batch_size, beam_size], tf.bool) 133 | 134 | def grow_finished(finished_seq, finished_scores, finished_flags, curr_seq, 135 | curr_scores, curr_finished): 136 | """Given sequences and scores, will gather the top k=beam size 137 | sequences. 138 | 139 | Args: 140 | finished_seq: Current finished sequences. 141 | [batch_size, beam_size, current_decoded_length] 142 | finished_scores: scores for each of these sequences. 143 | [batch_size, beam_size] 144 | finished_flags: finished bools for each of these sequences. 145 | [batch_size, beam_size] 146 | curr_seq: current topk sequence that has been grown by one 147 | position. [batch_size, beam_size, current_decoded_length] 148 | curr_scores: scores for each of these sequences. 149 | [batch_size, beam_size] 150 | curr_finished: Finished flags for each of these sequences. 151 | [batch_size, beam_size] 152 | Returns: 153 | Tuple of 154 | (Top-k sequences based on scores, 155 | log probs of these sequences, 156 | Finished flags of these sequences) 157 | """ 158 | # First append a column of 0'ids to finished to make the same length 159 | # with finished scores 160 | finished_seq = tf.concat( 161 | [finished_seq, tf.zeros([batch_size, beam_size, 1], tf.int32)], 162 | axis=2 163 | ) 164 | 165 | # Set the scores of the unfinished seq in curr_seq to large negative 166 | # values 167 | curr_scores += (1. - tf.to_float(curr_finished)) * -INF 168 | # concatenating the sequences and scores along beam axis 169 | curr_finished_seq = tf.concat([finished_seq, curr_seq], axis=1) 170 | curr_finished_scores = tf.concat([finished_scores, curr_scores], 171 | axis=1) 172 | curr_finished_flags = tf.concat([finished_flags, curr_finished], 173 | axis=1) 174 | return compute_topk_scores_and_seq( 175 | curr_finished_seq, curr_finished_scores, curr_finished_scores, 176 | curr_finished_flags, beam_size, batch_size) 177 | 178 | def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished): 179 | """Given sequences and scores, will gather the top k=beam size 180 | sequences. 181 | 182 | Args: 183 | curr_seq: current topk sequence that has been grown by one 184 | position. [batch_size, beam_size, i+1] 185 | curr_scores: scores for each of these sequences. 186 | [batch_size, beam_size] 187 | curr_log_probs: log probs for each of these sequences. 188 | [batch_size, beam_size] 189 | curr_finished: Finished flags for each of these sequences. 
190 | [batch_size, beam_size] 191 | Returns: 192 | Tuple of 193 | (Top-k sequences based on scores, 194 | log probs of these sequences, 195 | Finished flags of these sequences) 196 | """ 197 | # Set the scores of the finished seq in curr_seq to large negative 198 | # values 199 | curr_scores += tf.to_float(curr_finished) * -INF 200 | return compute_topk_scores_and_seq(curr_seq, curr_scores, 201 | curr_log_probs, curr_finished, 202 | beam_size, batch_size) 203 | 204 | def grow_topk(i, alive_seq, alive_log_probs): 205 | r"""Inner beam search loop. 206 | 207 | This function takes the current alive sequences, and grows them to 208 | topk sequences where k = 2*beam. We use 2*beam because, we could have 209 | beam_size number of sequences that might hit and there will be 210 | no alive sequences to continue. With 2*beam_size, this will not happen. 211 | This relies on the assumption the vocab size is > beam size. 212 | If this is true, we'll have at least beam_size non extensions if 213 | we extract the next top 2*beam words. 214 | Length penalty is given by = (5+len(decode)/6) ^ -\alpha. Pls refer to 215 | https://arxiv.org/abs/1609.08144. 216 | 217 | Args: 218 | i: loop index 219 | alive_seq: Topk sequences decoded so far 220 | [batch_size, beam_size, i+1] 221 | alive_log_probs: probabilities of these sequences. 222 | [batch_size, beam_size] 223 | Returns: 224 | Tuple of 225 | (Top-k sequences extended by the next word, 226 | The log probs of these sequences, 227 | The scores with length penalty of these sequences, 228 | Flags indicating which of these sequences have finished 229 | decoding) 230 | """ 231 | # Get the logits for all the possible next symbols 232 | flat_ids = tf.reshape(alive_seq, [batch_size * beam_size, -1]) 233 | 234 | # (batch_size * beam_size, decoded_length) 235 | flat_logits_list = symbols_to_logits_fn(flat_ids) 236 | logits_list = [ 237 | tf.reshape(flat_logits, (batch_size, beam_size, -1)) 238 | for flat_logits in flat_logits_list 239 | ] 240 | 241 | # Convert logits to normalized log probs 242 | candidate_log_probs = [ 243 | log_prob_from_logits(logits) 244 | for logits in logits_list 245 | ] 246 | 247 | n_models = len(candidate_log_probs) 248 | candidate_log_probs = tf.add_n(candidate_log_probs) / float(n_models) 249 | 250 | # Multiply the probabilities by the current probabilities of the beam. 251 | # (batch_size, beam_size, vocab_size) + (batch_size, beam_size, 1) 252 | log_probs = candidate_log_probs + tf.expand_dims(alive_log_probs, 253 | axis=2) 254 | 255 | length_penalty = tf.pow( 256 | ((lp_constant + tf.to_float(i + 1)) / (1.0 + lp_constant)), alpha 257 | ) 258 | 259 | curr_scores = log_probs / length_penalty 260 | # Flatten out (beam_size, vocab_size) probs in to a list of 261 | # possibilities 262 | flat_curr_scores = tf.reshape(curr_scores, 263 | [-1, beam_size * vocab_size]) 264 | 265 | topk_scores, topk_ids = tf.nn.top_k(flat_curr_scores, k=beam_size * 2) 266 | 267 | # Recovering the log probs because we will need to send them back 268 | topk_log_probs = topk_scores * length_penalty 269 | 270 | # Work out what beam the top probs are in. 271 | topk_beam_index = topk_ids // vocab_size 272 | topk_ids %= vocab_size # Unflatten the ids 273 | 274 | # The next three steps are to create coordinates for tf.gather_nd to 275 | # pull 276 | # out the correct sequences from id's that we need to grow. 277 | # We will also use the coordinates to gather the booleans of the beam 278 | # items that survived. 
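        # Illustration: with batch_size=2 and beam_size=2,
        # compute_batch_indices(2, 4) returns
        #     [[0, 0, 0, 0],
        #      [1, 1, 1, 1]]
        # and stacking it with topk_beam_index along axis 2 yields (i, j)
        # pairs meaning "take beam j of batch item i", which is the index
        # format tf.gather_nd expects below.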
279 | batch_pos = compute_batch_indices(batch_size, beam_size * 2) 280 | 281 | # top beams will give us the actual coordinates to do the gather. 282 | # stacking will create a tensor of dimension batch * beam * 2, where 283 | # the last dimension contains the i,j gathering coordinates. 284 | topk_coordinates = tf.stack([batch_pos, topk_beam_index], axis=2) 285 | 286 | # Gather up the most probable 2*beams both for the ids and 287 | # finished_in_alive bools 288 | topk_seq = tf.gather_nd(alive_seq, topk_coordinates) 289 | 290 | # Append the most probable alive 291 | topk_seq = tf.concat([topk_seq, tf.expand_dims(topk_ids, axis=2)], 292 | axis=2) 293 | 294 | topk_finished = tf.equal(topk_ids, eos_id) 295 | 296 | return topk_seq, topk_log_probs, topk_scores, topk_finished 297 | 298 | def inner_loop(i, alive_seq, alive_log_probs, finished_seq, 299 | finished_scores, finished_flags): 300 | """Inner beam search loop. 301 | 302 | There are three groups of tensors, alive, finished, and topk. 303 | The alive group contains information about the current alive sequences 304 | The top-k group contains information about alive + topk current decoded 305 | words the finished group contains information about finished sentences, 306 | that is, the ones that have decoded to . These are what we return. 307 | The general beam search algorithm is as follows: 308 | While we haven't terminated (pls look at termination condition) 309 | 1. Grow the current alive to get beam*2 top-k sequences 310 | 2. Among the top-k, keep the top beam_size ones that haven't 311 | reached into alive 312 | 3. Among the top-k, keep the top beam_size ones have reached 313 | into finished 314 | Repeat 315 | To make things simple with using fixed size tensors, we will end 316 | up inserting unfinished sequences into finished in the beginning. To 317 | stop that we add -ve INF to the score of the unfinished sequence so 318 | that when a true finished sequence does appear, it will have a higher 319 | score than all the unfinished ones. 320 | 321 | Args: 322 | i: loop index 323 | alive_seq: Topk sequences decoded so far 324 | [batch_size, beam_size, i+1] 325 | alive_log_probs: probabilities of the beams. 326 | [batch_size, beam_size] 327 | finished_seq: Current finished sequences. 328 | [batch_size, beam_size, i+1] 329 | finished_scores: scores for each of these sequences. 330 | [batch_size, beam_size] 331 | finished_flags: finished bools for each of these sequences. 332 | [batch_size, beam_size] 333 | 334 | Returns: 335 | Tuple of 336 | (Incremented loop index 337 | New alive sequences, 338 | Log probs of the alive sequences, 339 | New finished sequences, 340 | Scores of the new finished sequences, 341 | Flags indicating which sequence in finished as reached EOS) 342 | """ 343 | 344 | # Each inner loop, we carry out three steps: 345 | # 1. Get the current topk items. 346 | # 2. Extract the ones that have finished and haven't finished 347 | # 3. Recompute the contents of finished based on scores. 
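        # Bookkeeping note: grow_alive ranks the 2*beam_size candidates by
        # their length-penalized topk_scores but carries the raw
        # topk_log_probs forward as alive_log_probs, while grow_finished
        # ranks and stores the penalized scores, since completed hypotheses
        # are compared by their final length-normalized score.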
362 |     def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq,
363 |                      finished_scores, finished_in_finished):
364 |         """Checking termination condition.
365 |
366 |         We terminate when we have decoded up to decode_length or the lowest-scoring
367 |         item in finished has a greater score than the highest-probability item in
368 |         alive divided by the maximum length penalty.
369 |
370 |         Args:
371 |             i: loop index
372 |             alive_log_probs: probabilities of the beams.
373 |                 [batch_size, beam_size]
374 |             finished_scores: scores for each of these sequences.
375 |                 [batch_size, beam_size]
376 |             finished_in_finished: finished bools for each of these sequences.
377 |                 [batch_size, beam_size]
378 |
379 |         Returns:
380 |             Bool.
381 |         """
382 |         max_length_penalty = tf.pow(((5. + tf.to_float(decode_length)) / 6.),
383 |                                     alpha)
384 |         # The best possible score of the most likely alive sequence
385 |         lower_bound_alive_scores = alive_log_probs[:, 0] / max_length_penalty
386 |
387 |         # Now compute the lowest score of a finished sequence in finished.
388 |         # If a sequence isn't finished, we multiply its score by 0. Since
389 |         # scores are all negative, taking the min will give us the score of the
390 |         # lowest finished item.
391 |         lowest_score_of_finished_in_finished = tf.reduce_min(
392 |             finished_scores * tf.to_float(finished_in_finished), axis=1
393 |         )
394 |         # If none of the sequences have finished, then the min will be 0 and
395 |         # we have to replace it by -INF. The score of any sequence in
396 |         # alive will be much higher than -INF, and the termination condition
397 |         # will not be met.
398 |         lowest_score_of_finished_in_finished += (
399 |             (1. - tf.to_float(tf.reduce_any(finished_in_finished, 1))) * -INF
400 |         )
401 |
402 |         bound_is_met = tf.reduce_all(
403 |             tf.greater(lowest_score_of_finished_in_finished,
404 |                        lower_bound_alive_scores)
405 |         )
406 |
407 |         return tf.logical_and(tf.less(i, decode_length),
408 |                               tf.logical_not(bound_is_met))
409 |
410 |     (_, alive_seq, alive_log_probs, finished_seq, finished_scores,
411 |      finished_flags) = tf.while_loop(
412 |         _is_finished,
413 |         inner_loop, [
414 |             tf.constant(0), alive_seq, alive_log_probs, finished_seq,
415 |             finished_scores, finished_flags
416 |         ],
417 |         shape_invariants=[
418 |             tf.TensorShape([]),
419 |             tf.TensorShape([None, None, None]),
420 |             alive_log_probs.get_shape(),
421 |             tf.TensorShape([None, None, None]),
422 |             finished_scores.get_shape(),
423 |             finished_flags.get_shape()
424 |         ],
425 |         parallel_iterations=1,
426 |         back_prop=False
427 |     )
428 |
429 |     alive_seq.set_shape((None, beam_size, None))
430 |     finished_seq.set_shape((None, beam_size, None))
431 |
432 |     # Accounting for corner case: It's possible that no sequence in alive for
433 |     # a particular batch item ever reached <EOS>. In that case, we should just
434 |     # copy the contents of alive for that batch item.
435 |     # If tf.reduce_any(finished_flags, 1) is 0, no sequence for that
436 |     # batch index has reached EOS. We need to do the same for the scores as
437 |     # well.
438 |     finished_seq = tf.where(
439 |         tf.reduce_any(finished_flags, 1), finished_seq, alive_seq)
440 |     finished_scores = tf.where(
441 |         tf.reduce_any(finished_flags, 1), finished_scores, alive_log_probs)
442 |     return finished_seq, finished_scores
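
# --------------------------------------------------------------------------
# Editor's aside (not part of search.py): the early-stop test in
# _is_finished() above, spelled out with plain floats. Note that the maximum
# length penalty there uses the hard-coded (5 + len) / 6 form, i.e. it
# corresponds to lp_constant = 5 rather than params.decode_constant. All
# numbers below are made up.
# --------------------------------------------------------------------------
alpha = 1.0
decode_length = 50
max_length_penalty = ((5.0 + decode_length) / 6.0) ** alpha    # ~9.17

best_alive_log_prob = -4.0      # best beam still alive
worst_finished_score = -0.3     # lowest length-penalized score in `finished`

# The best score an alive beam can still reach is bounded by dividing its
# current log prob (<= 0) by the largest possible length penalty.
lower_bound_alive = best_alive_log_prob / max_length_penalty   # ~-0.44
stop_early = worst_finished_score > lower_bound_alive          # True: -0.3 > -0.44
# --------------------------------------------------------------------------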
443 |
444 |
445 | def create_inference_graph(model_fns, features, params):
446 |     if not isinstance(model_fns, (list, tuple)):
447 |         model_fns = [model_fns]
448 |
449 |     decode_length = params.decode_length
450 |     beam_size = params.beam_size
451 |     top_beams = params.top_beams
452 |     alpha = params.decode_alpha
453 |
454 |     # [batch, decoded_ids] => [batch, vocab_size]
455 |     def symbols_to_logits_fn(decoded_ids):
456 |         features["target"] = tf.pad(decoded_ids[:, 1:], [[0, 0], [0, 1]])
457 |         features["target_length"] = tf.fill([tf.shape(features["target"])[0]],
458 |                                             tf.shape(features["target"])[1])
459 |
460 |         results = []
461 |
462 |         for i, model_fn in enumerate(model_fns):
463 |             results.append(model_fn(features))
464 |
465 |         return results
466 |
467 |     batch_size = tf.shape(features["source"])[0]
468 |     # Prepend the <bos> symbol
469 |     bos_id = params.mapping["target"][params.bos]
470 |     initial_ids = tf.fill([batch_size], tf.constant(bos_id, dtype=tf.int32))
471 |
472 |     inputs_old = features["source"]
473 |     inputs_length_old = features["source_length"]
474 |
475 |     # Expand the inputs to the beam size
476 |     # [batch, length] => [batch, beam_size, length]
477 |     features["source"] = tf.expand_dims(features["source"], 1)
478 |     features["source"] = tf.tile(features["source"], [1, beam_size, 1])
479 |     shape = tf.shape(features["source"])
480 |
481 |     # [batch, beam_size, length] => [batch * beam_size, length]
482 |     features["source"] = tf.reshape(features["source"],
483 |                                     [shape[0] * shape[1], shape[2]])
484 |
485 |     # Do the same for the source sequence length
486 |     features["source_length"] = tf.expand_dims(features["source_length"], 1)
487 |     features["source_length"] = tf.tile(features["source_length"],
488 |                                         [1, beam_size])
489 |     shape = tf.shape(features["source_length"])
490 |
491 |     # [batch, beam_size] => [batch * beam_size]
492 |     features["source_length"] = tf.reshape(features["source_length"],
493 |                                            [shape[0] * shape[1]])
494 |
495 |     vocab_size = len(params.vocabulary["target"])
496 |     # Set the decode length to input length + decode_length
497 |     decode_length = tf.shape(features["source"])[1] + decode_length
498 |
499 |     ids, scores = beam_search(symbols_to_logits_fn, initial_ids,
500 |                               beam_size, decode_length, vocab_size,
501 |                               alpha,
502 |                               eos_id=params.mapping["target"][params.eos],
503 |                               lp_constant=params.decode_constant)
504 |
505 |     # Set inputs back to the unexpanded inputs so as not to confuse the Estimator
506 |     features["source"] = inputs_old
507 |     features["source_length"] = inputs_length_old
508 |
509 |     # Return the `top_beams` best decodings
510 |     # (also remove the initial <bos> id from the beam search output)
511 |     if not params.decode_normalize:
512 |         if top_beams == 1:
513 |             return ids[:, 0, 1:]
514 |         else:
515 |             return ids[:, :top_beams, 1:]
516 |     else:
517 |         if top_beams == 1:
518 |             return ids[:, 0, 1:], scores[:, 0]
519 |         else:
520 |             return ids[:, :top_beams, 1:], scores[:, :top_beams]
521 |
--------------------------------------------------------------------------------
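
The expand/tile/reshape sequence in create_inference_graph repeats every source sentence beam_size times so that each beam hypothesis decodes against its own copy of the encoder input. Below is a minimal NumPy sketch of that transformation; it is an editor's illustration with made-up shapes, not part of the repository.

import numpy as np

batch, beam, length = 2, 3, 4
source = np.arange(batch * length).reshape(batch, length)   # [batch, length]

tiled = np.expand_dims(source, 1)                            # [batch, 1, length]
tiled = np.tile(tiled, [1, beam, 1])                         # [batch, beam, length]
tiled = tiled.reshape(batch * beam, length)                  # [batch * beam, length]

# Rows 0..beam-1 are copies of sentence 0, rows beam..2*beam-1 of sentence 1,
# i.e. row i serves hypothesis (i % beam) of sentence (i // beam).
assert np.array_equal(tiled[:beam], np.repeat(source[:1], beam, axis=0))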