├── LICENSE ├── README.md ├── UserManual.pdf ├── docs ├── UserManual.tex ├── everb.sty └── thumt.bib └── thumt ├── __init__.py ├── bin ├── scorer.py ├── trainer.py ├── trainer_ctx.py ├── translator.py └── translator_ctx.py ├── data ├── __init__.py ├── cache.py ├── dataset.py ├── record.py └── vocab.py ├── interface ├── __init__.py └── model.py ├── layers ├── __init__.py ├── attention.py ├── nn.py └── rnn_cell.py ├── models ├── __init__.py ├── contextual_transformer.py ├── rnnsearch.py ├── seq2seq.py └── transformer.py ├── scripts ├── build_vocab.py ├── change.py ├── check_param.py ├── checkpoint_averaging.py ├── combine.py ├── combine_add.py ├── compare.py ├── convert_old_model.py ├── convert_vocab.py ├── input_converter.py └── shuffle_corpus.py └── utils ├── __init__.py ├── bleu.py ├── hooks.py ├── inference.py ├── inference_ctx.py ├── optimize.py ├── parallel.py ├── sample.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018, Natural Language Processing Lab at Tsinghua University 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, 5 | are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, this 11 | list of conditions and the following disclaimer in the documentation and/or 12 | other materials provided with the distribution. 13 | 14 | * Neither the name of the copyright holder nor the names of its 15 | contributors may be used to endorse or promote products derived from this 16 | software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 22 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 23 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 24 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 25 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 27 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Improving the Transformer Translation Model with Document-Level Context 2 | ## Contents 3 | * [Introduction](#introduction) 4 | * [Usage](#usage) 5 | * [Citation](#citation) 6 | * [FAQ](#faq) 7 | 8 | ## Introduction 9 | 10 | This is the implementation of our work, which extends Transformer to integrate document-level context \[[paper](https://arxiv.org/abs/1810.03581)\]. The implementation is on top of [THUMT](https://github.com/thumt/THUMT) 11 | 12 | ## Usage 13 | 14 | Note: The usage is not user-friendly. May improve later. 15 | 16 | 1. Train a standard Transformer model, please refer to the user manual of [THUMT](https://github.com/thumt/THUMT). 
Suppose that model_baseline/model.ckpt-30000 performs best on the validation set. 17 | 18 | 2. Generate a dummy improved Transformer model with the following command: 19 | 20 |
python THUMT/thumt/bin/trainer_ctx.py --inputs [source corpus] [target corpus] \
 21 |                                       --context [context corpus] \
 22 |                                       --vocabulary [source vocabulary] [target vocabulary] \
 23 |                                       --output model_dummy --model contextual_transformer \
 24 |                                       --parameters train_steps=1
 25 | 
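The dummy model is only needed to obtain a checkpoint that already contains the full set of contextual_transformer variables; with train_steps=1, training stops almost immediately. As an optional sanity check (not part of the original recipe), you can list the variables stored in the dummy checkpoint. The snippet below is a sketch; the "context" substring is an assumption about how the context-related variables are named in thumt/models/contextual_transformer.py.

```python
# Optional sanity check (assumes TensorFlow 1.x, which THUMT is built on):
# list the variables saved in the dummy checkpoint and print those that
# appear to belong to the document-level context components.
import tensorflow as tf

for name, shape in tf.train.list_variables("model_dummy/model.ckpt-0"):
    if "context" in name:  # naming assumption; adjust to the actual variable names
        print(name, shape)
```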
26 | 27 | 3. Generate the initial model by merging the standard Transformer model into the dummy model, then create a checkpoint file: 28 | 29 |
python THUMT/thumt/scripts/combine_add.py --model model_dummy/model.ckpt-0 \
 30 |                                          --part model_baseline/model.ckpt-30000 --output train
 31 | printf 'model_checkpoint_path: "new-0"\nall_model_checkpoint_paths: "new-0"' > train/checkpoint
 32 | 
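The printf command writes the checkpoint state file train/checkpoint so that TensorFlow restores the merged parameters from the checkpoint named new-0 in the train directory. If the \n escape is awkward in your shell, the same file can be produced with a short Python sketch (assuming TensorFlow 1.x is installed):

```python
# Equivalent to the printf line above: write train/checkpoint so that the
# merged checkpoint "new-0" is the one TensorFlow restores from.
import tensorflow as tf

tf.train.update_checkpoint_state(
    save_dir="train",
    model_checkpoint_path="new-0",
    all_model_checkpoint_paths=["new-0"])
```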
33 | 34 | 35 | 4. Train the improved Transformer model with the following command: 36 | 37 |
python THUMT/thumt/bin/trainer_ctx.py --inputs [source corpus] [target corpus] \
 38 |                                       --context [context corpus] \
 39 |                                       --vocabulary [source vocabulary] [target vocabulary] \
 40 |                                       --output train --model contextual_transformer \
 41 |                                       --parameters start_steps=30000,num_context_layers=1
 42 | 
43 | 44 | 5. Translate with the improved Transformer model: 45 | 46 |
python THUMT/thumt/bin/translator_ctx.py --inputs [source corpus] --context [context corpus] \
 47 |                                          --output [translation result] \
 48 |                                          --vocabulary [source vocabulary] [target vocabulary] \
 49 |                                          --model contextual_transformer --checkpoints [model path] \
 50 |                                          --parameters num_context_layers=1
 51 | 
52 | 53 | ## Citation 54 | 55 | Please cite the following paper if you use the code: 56 | 57 |
@InProceedings{Zhang:18,
 58 |   author    = {Zhang, Jiacheng and Luan, Huanbo and Sun, Maosong and Zhai, Feifei and Xu, Jingfang and Zhang, Min and Liu, Yang},
 59 |   title     = {Improving the Transformer Translation Model with Document-Level Context},
 60 |   booktitle = {Proceedings of EMNLP},
 61 |   year      = {2018},
 62 | }
 63 | 
64 | 65 | 66 | ## FAQ 67 | 68 | 1. What is the context corpus? 69 | 70 | The context corpus file contains one context sentence per line. Normally, the context is made up of the several preceding source sentences within the same document. For example, if the original document-level corpus is: 71 | 72 |
==== source ====
 73 | <document id=XXX>
 74 | <seg id=1>source sentence #1</seg>
 75 | <seg id=2>source sentence #2</seg>
 76 | <seg id=3>source sentence #3</seg>
 77 | <seg id=4>source sentence #4</seg>
 78 | </document>
 79 | 
 80 | ==== target ====
 81 | <document id=XXX>
 82 | <seg id=1>target sentence #1</seg>
 83 | <seg id=2>target sentence #2</seg>
 84 | <seg id=3>target sentence #3</seg>
 85 | <seg id=4>target sentence #4</seg>
 86 | </document>
87 | 88 | The inputs to our system should be processed as follows (assuming the 2 preceding source sentences are used as context; a small preprocessing sketch is given after the example): 89 | 90 |
==== train.src ==== (source corpus)
 91 | source sentence #1
 92 | source sentence #2
 93 | source sentence #3
 94 | source sentence #4
 95 | 
 96 | ==== train.ctx ==== (context corpus)
 97 | (the first line is empty)
 98 | source sentence #1
 99 | source sentence #1 source sentence #2 (the two sentences are separated by a single space)
100 | source sentence #2 source sentence #3
101 | 
102 | ==== train.trg ==== (target corpus)
103 | target sentence #1
104 | target sentence #2
105 | target sentence #3
106 | target sentence #4
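Below is a small preprocessing sketch (not part of the toolkit) that writes the three files in this format. It assumes the document-level corpus has already been loaded as a list of documents, each a list of tokenized (source sentence, target sentence) pairs, and it uses the 2 preceding source sentences of the same document as context:

```python
# Sketch only: build train.src, train.ctx and train.trg as described above.
# `documents` is assumed to be a list of documents, where each document is a
# list of (source_sentence, target_sentence) strings, already tokenized.

def build_context_files(documents, context_size=2, src_path="train.src",
                        ctx_path="train.ctx", trg_path="train.trg"):
    with open(src_path, "w") as f_src, \
         open(ctx_path, "w") as f_ctx, \
         open(trg_path, "w") as f_trg:
        for doc in documents:
            sources = [src for src, _ in doc]
            for i, (src, trg) in enumerate(doc):
                # Preceding source sentences of the same document, joined by
                # single spaces; the first sentence gets an empty context line.
                ctx = " ".join(sources[max(0, i - context_size):i])
                f_src.write(src + "\n")
                f_ctx.write(ctx + "\n")
                f_trg.write(trg + "\n")
```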
107 | 108 | 109 | 110 | -------------------------------------------------------------------------------- /UserManual.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-MT/Document-Transformer/5bcc7f43cc948240fa0e3a400bffdc178f841fcd/UserManual.pdf -------------------------------------------------------------------------------- /docs/thumt.bib: -------------------------------------------------------------------------------- 1 | @inproceedings{Bahdanau:15, 2 | author = {Bahdanau, Dzmitry and Cho, KyungHyun and Bengio, Yoshua}, 3 | title = {Neural Machine Translation by Jointly Learning to Align and Translate}, 4 | booktitle = {Proceedings of ICLR}, 5 | year = {2015}, 6 | } 7 | 8 | @article{Brown:93, 9 | author = "Brown, Peter F. and Della Pietra, Stephen A. and Della Pietra, Vincent J. and Mercer, Robert L.", 10 | title = "The mathematics of statistical machine translation: Parameter estimation", 11 | journal = "Computational Linguistics", 12 | year = "1993", 13 | } 14 | 15 | @InProceedings{Cheng:16, 16 | author = {Cheng, Yong and Xu, Wei and He, Zhongjun and He, Wei and Wu, Hua and Sun, Maosong and Liu, Yang}, 17 | title = {Semi-Supervised Learning for Neural Machine Translation}, 18 | booktitle = {Proceedings of ACL}, 19 | year = {2016}, 20 | } 21 | 22 | @InProceedings{Chiang:05, 23 | author = {Chiang, David}, 24 | title = {A Hierarchical Phrase-based Model for Statistical Machine Translation}, 25 | booktitle = {Proceedings of ACL}, 26 | year = {2005}, 27 | } 28 | 29 | 30 | @InProceedings{Ding:17, 31 | author = {Ding, Yanzhuo and Liu, Yang and Luan, Huanbo and Sun, Maosong}, 32 | title = {Visualizing and Understanding Neural Machine Translation}, 33 | booktitle = {Proceedings of ACL}, 34 | year = {2017}, 35 | } 36 | 37 | @misc{Kingma:14, 38 | author = {Kingma, Diederik P. and Ba, Jimmy}, 39 | title = {Adam: A Method for Stochastic Optimization}, 40 | howpublished = {arXiv:1412.6980}, 41 | year = {2014}, 42 | } 43 | 44 | @InProceedings{Koehn:03, 45 | author = {Koehn, Philipp and Och, Franz J. 
and Marcu, Daniel}, 46 | title = {Statistical Phrase-based Translation}, 47 | booktitle = {Proceedings of NAACL}, 48 | year = {2003}, 49 | } 50 | 51 | @InProceedings{Luong:15, 52 | author = {Luong, Thang and Sutskever, Ilya and Le, Quoc and Vinyals, Oriol and Zaremba, Wojciech}, 53 | title = {Addressing the Rare Word Problem in Neural Machine Translation}, 54 | booktitle = {Proceedings of ACL}, 55 | year = {2015}, 56 | } 57 | 58 | @inproceedings{Papineni:02, 59 | author = {Papineni, Kishore and Roukos, Salim and Ward, Todd and Zhu, Wei-Jing}, 60 | title = {BLEU: A Method for Automatic Evaluation of Machine Translation}, 61 | booktitle = {Proceedings of ACL}, 62 | year = {2002}, 63 | } 64 | 65 | @InProceedings{Sennrich:16, 66 | author = {Sennrich, Rico and Haddow, Barry and Birch, Alexandra}, 67 | title = {Neural Machine Translation of Rare Words with Subword Units}, 68 | booktitle = {Proceedings of ACL}, 69 | year = {2016}, 70 | } 71 | 72 | @inproceedings{Shen:16, 73 | author = {Shen, Shiqi and Cheng, Yong and He, Zhongjun and He, Wei and Wu, Hua and Sun, Maosong and Liu, Yang}, 74 | title = {Minimum Risk Training for Neural Machine Translation}, 75 | booktitle = {Proceedings of ACL}, 76 | year = {2016}, 77 | } 78 | 79 | @inproceedings{Sutskever:14, 80 | author = {Sutskever, Ilya and Vinyals, Oriol and Le, Quoc V.}, 81 | title = {Sequence to Sequence Learning with Neural Networks}, 82 | booktitle = {Proceedings of NIPS}, 83 | year = {2014}, 84 | } 85 | 86 | @inproceedings{Vaswani:17, 87 | title={Attention Is All You Need}, 88 | author={Vaswani, Ashish and Shazeer, Noam and Parmar, Niki and Uszkoreit, Jakob and Jones, Llion and Gomez, Aidan N and Kaiser, Lukasz and Polosukhin, Illia}, 89 | booktitle={Proceedings of NIPS}, 90 | year={2017} 91 | } 92 | 93 | @misc{Wu:16, 94 | author= {Yonghui Wu and 95 | Mike Schuster and 96 | Zhifeng Chen and 97 | Quoc V. 
Le and 98 | Mohammad Norouzi and 99 | Wolfgang Macherey and 100 | Maxim Krikun and 101 | Yuan Cao and 102 | Qin Gao and 103 | Klaus Macherey and 104 | Jeff Klingner and 105 | Apurva Shah and 106 | Melvin Johnson and 107 | Xiaobing Liu and 108 | Lukasz Kaiser and 109 | Stephan Gouws and 110 | Yoshikiyo Kato and 111 | Taku Kudo and 112 | Hideto Kazawa and 113 | Keith Stevens and 114 | George Kurian and 115 | Nishant Patil and 116 | Wei Wang and 117 | Cliff Young and 118 | Jason Smith and 119 | Jason Riesa and 120 | Alex Rudnick and 121 | Oriol Vinyals and 122 | Greg Corrado and 123 | Macduff Hughes and 124 | Jeffrey Dean}, 125 | title = {Google's Neural Machine Translation System: Bridging the Gap between Human and Machine Translation}, 126 | howpublished = {arXiv:1609.08144v2}, 127 | year = {2016}, 128 | } 129 | 130 | @misc{Zeiler:12, 131 | author = {Zeiler, Matthew D.}, 132 | title = {AdaDelta: An Adaptive Learning Rate Method}, 133 | howpublished = {arXiv:1212.5701v1}, 134 | year = {2012}, 135 | } -------------------------------------------------------------------------------- /thumt/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The THUMT Authors 3 | -------------------------------------------------------------------------------- /thumt/bin/scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2018 The THUMT Authors 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import argparse 10 | import itertools 11 | import os 12 | 13 | import tensorflow as tf 14 | import thumt.data.vocab as vocabulary 15 | import thumt.models as models 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser( 20 | description="Translate using existing NMT models", 21 | usage="translator.py [] [-h | --help]" 22 | ) 23 | 24 | # input files 25 | parser.add_argument("--input", type=str, required=True, nargs=2, 26 | help="Path of input file") 27 | parser.add_argument("--output", type=str, required=True, 28 | help="Path of output file") 29 | parser.add_argument("--checkpoint", type=str, required=True, 30 | help="Path of trained models") 31 | parser.add_argument("--vocabulary", type=str, nargs=2, required=True, 32 | help="Path of source and target vocabulary") 33 | 34 | # model and configuration 35 | parser.add_argument("--model", type=str, required=True, 36 | help="Name of the model") 37 | parser.add_argument("--parameters", type=str, 38 | help="Additional hyper parameters") 39 | 40 | return parser.parse_args() 41 | 42 | 43 | def default_parameters(): 44 | params = tf.contrib.training.HParams( 45 | input=None, 46 | output=None, 47 | vocabulary=None, 48 | model=None, 49 | # vocabulary specific 50 | pad="", 51 | bos="", 52 | eos="", 53 | unk="", 54 | mapping=None, 55 | append_eos=False, 56 | device_list=[0], 57 | num_threads=6, 58 | eval_batch_size=32 59 | ) 60 | 61 | return params 62 | 63 | 64 | def merge_parameters(params1, params2): 65 | params = tf.contrib.training.HParams() 66 | 67 | for (k, v) in params1.values().iteritems(): 68 | params.add_hparam(k, v) 69 | 70 | params_dict = params.values() 71 | 72 | for (k, v) in params2.values().iteritems(): 73 | if k in params_dict: 74 | # Override 75 | setattr(params, k, v) 76 | else: 77 | params.add_hparam(k, v) 78 | 79 | return params 80 | 81 | 82 | def import_params(model_dir, model_name, 
params): 83 | model_dir = os.path.abspath(model_dir) 84 | m_name = os.path.join(model_dir, model_name + ".json") 85 | 86 | if not tf.gfile.Exists(m_name): 87 | return params 88 | 89 | with tf.gfile.Open(m_name) as fd: 90 | tf.logging.info("Restoring model parameters from %s" % m_name) 91 | json_str = fd.readline() 92 | params.parse_json(json_str) 93 | 94 | return params 95 | 96 | 97 | def override_parameters(params, args): 98 | if args.parameters: 99 | params.parse(args.parameters) 100 | 101 | params.vocabulary = { 102 | "source": vocabulary.load_vocabulary(args.vocabulary[0]), 103 | "target": vocabulary.load_vocabulary(args.vocabulary[1]) 104 | } 105 | params.vocabulary["source"] = vocabulary.process_vocabulary( 106 | params.vocabulary["source"], params 107 | ) 108 | params.vocabulary["target"] = vocabulary.process_vocabulary( 109 | params.vocabulary["target"], params 110 | ) 111 | 112 | control_symbols = [params.pad, params.bos, params.eos, params.unk] 113 | 114 | params.mapping = { 115 | "source": vocabulary.get_control_mapping( 116 | params.vocabulary["source"], 117 | control_symbols 118 | ), 119 | "target": vocabulary.get_control_mapping( 120 | params.vocabulary["target"], 121 | control_symbols 122 | ) 123 | } 124 | 125 | return params 126 | 127 | 128 | def session_config(params): 129 | optimizer_options = tf.OptimizerOptions(opt_level=tf.OptimizerOptions.L1, 130 | do_function_inlining=False) 131 | graph_options = tf.GraphOptions(optimizer_options=optimizer_options) 132 | config = tf.ConfigProto(allow_soft_placement=True, 133 | graph_options=graph_options) 134 | if params.device_list: 135 | device_str = ",".join([str(i) for i in params.device_list]) 136 | config.gpu_options.visible_device_list = device_str 137 | 138 | return config 139 | 140 | 141 | def set_variables(var_list, value_dict, prefix): 142 | ops = [] 143 | for var in var_list: 144 | for name in value_dict: 145 | var_name = "/".join([prefix] + list(name.split("/")[1:])) 146 | 147 | if var.name[:-2] == var_name: 148 | tf.logging.debug("restoring %s -> %s" % (name, var.name)) 149 | with tf.device("/cpu:0"): 150 | op = tf.assign(var, value_dict[name]) 151 | ops.append(op) 152 | break 153 | 154 | return ops 155 | 156 | 157 | def read_files(names): 158 | inputs = [[] for _ in range(len(names))] 159 | files = [tf.gfile.GFile(name) for name in names] 160 | 161 | count = 0 162 | 163 | for lines in zip(*files): 164 | lines = [line.strip() for line in lines] 165 | 166 | for i, line in enumerate(lines): 167 | inputs[i].append(line) 168 | 169 | count += 1 170 | 171 | # Close files 172 | for fd in files: 173 | fd.close() 174 | 175 | return inputs 176 | 177 | 178 | def get_features(inputs, params): 179 | with tf.device("/cpu:0"): 180 | # Create datasets 181 | datasets = [] 182 | 183 | for data in inputs: 184 | dataset = tf.data.Dataset.from_tensor_slices(data) 185 | # Split string 186 | dataset = dataset.map(lambda x: tf.string_split([x]).values, 187 | num_parallel_calls=params.num_threads) 188 | # Append 189 | dataset = dataset.map( 190 | lambda x: tf.concat([x, [tf.constant(params.eos)]], axis=0), 191 | num_parallel_calls=params.num_threads 192 | ) 193 | datasets.append(dataset) 194 | 195 | dataset = tf.data.Dataset.zip(tuple(datasets)) 196 | 197 | # Convert tuple to dictionary 198 | dataset = dataset.map( 199 | lambda *x: { 200 | "source": x[0], 201 | "source_length": tf.shape(x[0])[0], 202 | "target": x[1], 203 | "target_length": tf.shape(x[1])[0] 204 | }, 205 | num_parallel_calls=params.num_threads 206 | ) 207 | 208 | dataset = 
dataset.padded_batch( 209 | params.eval_batch_size, 210 | { 211 | "source": [tf.Dimension(None)], 212 | "source_length": [], 213 | "target": [tf.Dimension(None)], 214 | "target_length": [] 215 | }, 216 | { 217 | "source": params.pad, 218 | "source_length": 0, 219 | "target": params.pad, 220 | "target_length": 0 221 | } 222 | ) 223 | 224 | iterator = dataset.make_one_shot_iterator() 225 | features = iterator.get_next() 226 | 227 | src_table = tf.contrib.lookup.index_table_from_tensor( 228 | tf.constant(params.vocabulary["source"]), 229 | default_value=params.mapping["source"][params.unk] 230 | ) 231 | tgt_table = tf.contrib.lookup.index_table_from_tensor( 232 | tf.constant(params.vocabulary["target"]), 233 | default_value=params.mapping["target"][params.unk] 234 | ) 235 | features["source"] = src_table.lookup(features["source"]) 236 | features["target"] = tgt_table.lookup(features["target"]) 237 | 238 | return features 239 | 240 | 241 | def main(args): 242 | tf.logging.set_verbosity(tf.logging.INFO) 243 | model_cls = models.get_model(args.model) 244 | params = default_parameters() 245 | 246 | # Import and override parameters 247 | # Priorities (low -> high): 248 | # default -> saved -> command 249 | params = merge_parameters(params, model_cls.get_parameters()) 250 | params = import_params(args.checkpoint, args.model, params) 251 | override_parameters(params, args) 252 | 253 | # Build Graph 254 | with tf.Graph().as_default(): 255 | model = model_cls(params) 256 | inputs = read_files(args.input) 257 | features = get_features(inputs, params) 258 | score_fn = model.get_evaluation_func() 259 | scores = score_fn(features, params) 260 | 261 | sess_creator = tf.train.ChiefSessionCreator( 262 | config=session_config(params) 263 | ) 264 | 265 | # Load checkpoint 266 | tf.logging.info("Loading %s" % args.checkpoint) 267 | var_list = tf.train.list_variables(args.checkpoint) 268 | values = {} 269 | reader = tf.train.load_checkpoint(args.checkpoint) 270 | 271 | for (name, shape) in var_list: 272 | if not name.startswith(model_cls.get_name()): 273 | continue 274 | 275 | tensor = reader.get_tensor(name) 276 | values[name] = tensor 277 | 278 | ops = set_variables(tf.trainable_variables(), values, 279 | model_cls.get_name()) 280 | assign_op = tf.group(*ops) 281 | 282 | # Create session 283 | with tf.train.MonitoredSession(session_creator=sess_creator) as sess: 284 | # Restore variables 285 | sess.run(assign_op) 286 | fd = tf.gfile.Open(args.output, "w") 287 | 288 | while not sess.should_stop(): 289 | results = sess.run(scores) 290 | for value in results: 291 | fd.write("%f\n" % value) 292 | 293 | fd.close() 294 | 295 | 296 | if __name__ == "__main__": 297 | main(parse_args()) 298 | -------------------------------------------------------------------------------- /thumt/bin/trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2018 The THUMT Authors 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import argparse 10 | import os 11 | import six 12 | 13 | import numpy as np 14 | import tensorflow as tf 15 | import thumt.data.cache as cache 16 | import thumt.data.dataset as dataset 17 | import thumt.data.record as record 18 | import thumt.data.vocab as vocabulary 19 | import thumt.models as models 20 | import thumt.utils.hooks as hooks 21 | import thumt.utils.inference as inference 22 | import thumt.utils.optimize as optimize 23 | 
import thumt.utils.parallel as parallel 24 | import thumt.utils.utils as utils 25 | 26 | 27 | def parse_args(args=None): 28 | parser = argparse.ArgumentParser( 29 | description="Training neural machine translation models", 30 | usage="trainer.py [] [-h | --help]" 31 | ) 32 | 33 | # input files 34 | parser.add_argument("--input", type=str, nargs=2, 35 | help="Path of source and target corpus") 36 | parser.add_argument("--record", type=str, 37 | help="Path to tf.Record data") 38 | parser.add_argument("--output", type=str, default="train", 39 | help="Path to saved models") 40 | parser.add_argument("--vocabulary", type=str, nargs=2, 41 | help="Path of source and target vocabulary") 42 | parser.add_argument("--validation", type=str, 43 | help="Path of validation file") 44 | parser.add_argument("--references", type=str, nargs="+", 45 | help="Path of reference files") 46 | 47 | # model and configuration 48 | parser.add_argument("--model", type=str, required=True, 49 | help="Name of the model") 50 | parser.add_argument("--parameters", type=str, default="", 51 | help="Additional hyper parameters") 52 | 53 | return parser.parse_args(args) 54 | 55 | 56 | def default_parameters(): 57 | params = tf.contrib.training.HParams( 58 | input=["", ""], 59 | output="", 60 | record="", 61 | model="transformer", 62 | vocab=["", ""], 63 | # Default training hyper parameters 64 | num_threads=6, 65 | batch_size=4096, 66 | max_length=256, 67 | length_multiplier=1, 68 | mantissa_bits=2, 69 | warmup_steps=4000, 70 | train_steps=100000, 71 | buffer_size=10000, 72 | constant_batch_size=False, 73 | device_list=[0], 74 | update_cycle=1, 75 | initializer="uniform_unit_scaling", 76 | initializer_gain=1.0, 77 | optimizer="Adam", 78 | adam_beta1=0.9, 79 | adam_beta2=0.999, 80 | adam_epsilon=1e-8, 81 | clip_grad_norm=5.0, 82 | learning_rate=1.0, 83 | learning_rate_decay="linear_warmup_rsqrt_decay", 84 | learning_rate_boundaries=[0], 85 | learning_rate_values=[0.0], 86 | keep_checkpoint_max=20, 87 | keep_top_checkpoint_max=5, 88 | # Validation 89 | eval_steps=2000, 90 | eval_secs=0, 91 | eval_batch_size=32, 92 | top_beams=1, 93 | beam_size=4, 94 | decode_alpha=0.6, 95 | decode_length=50, 96 | validation="", 97 | references=[""], 98 | save_checkpoint_secs=0, 99 | save_checkpoint_steps=1000, 100 | # Setting this to True can save disk spaces, but cannot restore 101 | # training using the saved checkpoint 102 | only_save_trainable=False 103 | ) 104 | 105 | return params 106 | 107 | 108 | def import_params(model_dir, model_name, params): 109 | model_dir = os.path.abspath(model_dir) 110 | p_name = os.path.join(model_dir, "params.json") 111 | m_name = os.path.join(model_dir, model_name + ".json") 112 | 113 | if not tf.gfile.Exists(p_name) or not tf.gfile.Exists(m_name): 114 | return params 115 | 116 | with tf.gfile.Open(p_name) as fd: 117 | tf.logging.info("Restoring hyper parameters from %s" % p_name) 118 | json_str = fd.readline() 119 | params.parse_json(json_str) 120 | 121 | with tf.gfile.Open(m_name) as fd: 122 | tf.logging.info("Restoring model parameters from %s" % m_name) 123 | json_str = fd.readline() 124 | params.parse_json(json_str) 125 | 126 | return params 127 | 128 | 129 | def export_params(output_dir, name, params): 130 | if not tf.gfile.Exists(output_dir): 131 | tf.gfile.MkDir(output_dir) 132 | 133 | # Save params as params.json 134 | filename = os.path.join(output_dir, name) 135 | with tf.gfile.Open(filename, "w") as fd: 136 | fd.write(params.to_json()) 137 | 138 | 139 | def collect_params(all_params, params): 140 | 
collected = tf.contrib.training.HParams() 141 | 142 | for k in params.values().iterkeys(): 143 | collected.add_hparam(k, getattr(all_params, k)) 144 | 145 | return collected 146 | 147 | 148 | def merge_parameters(params1, params2): 149 | params = tf.contrib.training.HParams() 150 | 151 | for (k, v) in params1.values().iteritems(): 152 | params.add_hparam(k, v) 153 | 154 | params_dict = params.values() 155 | 156 | for (k, v) in params2.values().iteritems(): 157 | if k in params_dict: 158 | # Override 159 | setattr(params, k, v) 160 | else: 161 | params.add_hparam(k, v) 162 | 163 | return params 164 | 165 | 166 | def override_parameters(params, args): 167 | params.model = args.model 168 | params.input = args.input or params.input 169 | params.output = args.output or params.output 170 | params.record = args.record or params.record 171 | params.vocab = args.vocabulary or params.vocab 172 | params.validation = args.validation or params.validation 173 | params.references = args.references or params.references 174 | params.parse(args.parameters) 175 | 176 | params.vocabulary = { 177 | "source": vocabulary.load_vocabulary(params.vocab[0]), 178 | "target": vocabulary.load_vocabulary(params.vocab[1]) 179 | } 180 | params.vocabulary["source"] = vocabulary.process_vocabulary( 181 | params.vocabulary["source"], params 182 | ) 183 | params.vocabulary["target"] = vocabulary.process_vocabulary( 184 | params.vocabulary["target"], params 185 | ) 186 | 187 | control_symbols = [params.pad, params.bos, params.eos, params.unk] 188 | 189 | params.mapping = { 190 | "source": vocabulary.get_control_mapping( 191 | params.vocabulary["source"], 192 | control_symbols 193 | ), 194 | "target": vocabulary.get_control_mapping( 195 | params.vocabulary["target"], 196 | control_symbols 197 | ) 198 | } 199 | 200 | return params 201 | 202 | 203 | def get_initializer(params): 204 | if params.initializer == "uniform": 205 | max_val = params.initializer_gain 206 | return tf.random_uniform_initializer(-max_val, max_val) 207 | elif params.initializer == "normal": 208 | return tf.random_normal_initializer(0.0, params.initializer_gain) 209 | elif params.initializer == "normal_unit_scaling": 210 | return tf.variance_scaling_initializer(params.initializer_gain, 211 | mode="fan_avg", 212 | distribution="normal") 213 | elif params.initializer == "uniform_unit_scaling": 214 | return tf.variance_scaling_initializer(params.initializer_gain, 215 | mode="fan_avg", 216 | distribution="uniform") 217 | else: 218 | raise ValueError("Unrecognized initializer: %s" % params.initializer) 219 | 220 | 221 | def get_learning_rate_decay(learning_rate, global_step, params): 222 | if params.learning_rate_decay in ["linear_warmup_rsqrt_decay", "noam"]: 223 | step = tf.to_float(global_step) 224 | warmup_steps = tf.to_float(params.warmup_steps) 225 | multiplier = params.hidden_size ** -0.5 226 | decay = multiplier * tf.minimum((step + 1) * (warmup_steps ** -1.5), 227 | (step + 1) ** -0.5) 228 | 229 | return learning_rate * decay 230 | elif params.learning_rate_decay == "piecewise_constant": 231 | return tf.train.piecewise_constant(tf.to_int32(global_step), 232 | params.learning_rate_boundaries, 233 | params.learning_rate_values) 234 | elif params.learning_rate_decay == "none": 235 | return learning_rate 236 | else: 237 | raise ValueError("Unknown learning_rate_decay") 238 | 239 | 240 | def session_config(params): 241 | optimizer_options = tf.OptimizerOptions(opt_level=tf.OptimizerOptions.L1, 242 | do_function_inlining=True) 243 | graph_options = 
tf.GraphOptions(optimizer_options=optimizer_options) 244 | config = tf.ConfigProto(allow_soft_placement=True, 245 | graph_options=graph_options) 246 | if params.device_list: 247 | device_str = ",".join([str(i) for i in params.device_list]) 248 | config.gpu_options.visible_device_list = device_str 249 | 250 | return config 251 | 252 | 253 | def decode_target_ids(inputs, params): 254 | decoded = [] 255 | vocab = params.vocabulary["target"] 256 | 257 | for item in inputs: 258 | syms = [] 259 | for idx in item: 260 | if isinstance(idx, six.integer_types): 261 | sym = vocab[idx] 262 | else: 263 | sym = idx 264 | 265 | if sym == params.eos: 266 | break 267 | 268 | if sym == params.pad: 269 | break 270 | 271 | syms.append(sym) 272 | decoded.append(syms) 273 | 274 | return decoded 275 | 276 | 277 | def main(args): 278 | tf.logging.set_verbosity(tf.logging.INFO) 279 | model_cls = models.get_model(args.model) 280 | params = default_parameters() 281 | 282 | # Import and override parameters 283 | # Priorities (low -> high): 284 | # default -> saved -> command 285 | params = merge_parameters(params, model_cls.get_parameters()) 286 | params = import_params(args.output, args.model, params) 287 | override_parameters(params, args) 288 | 289 | # Export all parameters and model specific parameters 290 | export_params(params.output, "params.json", params) 291 | export_params( 292 | params.output, 293 | "%s.json" % args.model, 294 | collect_params(params, model_cls.get_parameters()) 295 | ) 296 | 297 | # Build Graph 298 | with tf.Graph().as_default(): 299 | if not params.record: 300 | # Build input queue 301 | features = dataset.get_training_input(params.input, params) 302 | else: 303 | features = record.get_input_features( 304 | os.path.join(params.record, "*train*"), "train", params 305 | ) 306 | 307 | features, init_op = cache.cache_features(features, 308 | params.update_cycle) 309 | 310 | # Build model 311 | initializer = get_initializer(params) 312 | model = model_cls(params) 313 | 314 | # Multi-GPU setting 315 | sharded_losses = parallel.parallel_model( 316 | model.get_training_func(initializer), 317 | features, 318 | params.device_list 319 | ) 320 | loss = tf.add_n(sharded_losses) / len(sharded_losses) 321 | 322 | # Create global step 323 | global_step = tf.train.get_or_create_global_step() 324 | 325 | # Print parameters 326 | all_weights = {v.name: v for v in tf.trainable_variables()} 327 | total_size = 0 328 | 329 | for v_name in sorted(list(all_weights)): 330 | v = all_weights[v_name] 331 | tf.logging.info("%s\tshape %s", v.name[:-2].ljust(80), 332 | str(v.shape).ljust(20)) 333 | v_size = np.prod(np.array(v.shape.as_list())).tolist() 334 | total_size += v_size 335 | tf.logging.info("Total trainable variables size: %d", total_size) 336 | 337 | learning_rate = get_learning_rate_decay(params.learning_rate, 338 | global_step, params) 339 | learning_rate = tf.convert_to_tensor(learning_rate, dtype=tf.float32) 340 | tf.summary.scalar("learning_rate", learning_rate) 341 | 342 | # Create optimizer 343 | if params.optimizer == "Adam": 344 | opt = tf.train.AdamOptimizer(learning_rate, 345 | beta1=params.adam_beta1, 346 | beta2=params.adam_beta2, 347 | epsilon=params.adam_epsilon) 348 | elif params.optimizer == "LazyAdam": 349 | opt = tf.contrib.opt.LazyAdamOptimizer(learning_rate, 350 | beta1=params.adam_beta1, 351 | beta2=params.adam_beta2, 352 | epsilon=params.adam_epsilon) 353 | else: 354 | raise RuntimeError("Optimizer %s not supported" % params.optimizer) 355 | 356 | loss, ops = 
optimize.create_train_op(loss, opt, global_step, params) 357 | 358 | # Validation 359 | if params.validation and params.references[0]: 360 | files = [params.validation] + list(params.references) 361 | eval_inputs = dataset.sort_and_zip_files(files) 362 | eval_input_fn = dataset.get_evaluation_input 363 | else: 364 | eval_input_fn = None 365 | 366 | # Add hooks 367 | save_vars = tf.trainable_variables() + [global_step] 368 | saver = tf.train.Saver( 369 | var_list=save_vars if params.only_save_trainable else None, 370 | max_to_keep=params.keep_checkpoint_max, 371 | sharded=False 372 | ) 373 | tf.add_to_collection(tf.GraphKeys.SAVERS, saver) 374 | 375 | train_hooks = [ 376 | tf.train.StopAtStepHook(last_step=params.train_steps), 377 | tf.train.NanTensorHook(loss), 378 | tf.train.LoggingTensorHook( 379 | { 380 | "step": global_step, 381 | "loss": loss, 382 | }, 383 | every_n_iter=1 384 | ), 385 | tf.train.CheckpointSaverHook( 386 | checkpoint_dir=params.output, 387 | save_secs=params.save_checkpoint_secs or None, 388 | save_steps=params.save_checkpoint_steps or None, 389 | saver=saver 390 | ) 391 | ] 392 | 393 | config = session_config(params) 394 | 395 | if eval_input_fn is not None: 396 | train_hooks.append( 397 | hooks.EvaluationHook( 398 | lambda f: inference.create_inference_graph( 399 | [model.get_inference_func()], f, params 400 | ), 401 | lambda: eval_input_fn(eval_inputs, params), 402 | lambda x: decode_target_ids(x, params), 403 | params.output, 404 | config, 405 | params.keep_top_checkpoint_max, 406 | eval_secs=params.eval_secs, 407 | eval_steps=params.eval_steps 408 | ) 409 | ) 410 | 411 | # Create session, do not use default CheckpointSaverHook 412 | with tf.train.MonitoredTrainingSession( 413 | checkpoint_dir=params.output, hooks=train_hooks, 414 | save_checkpoint_secs=None, config=config) as sess: 415 | while not sess.should_stop(): 416 | # Bypass hook calls 417 | utils.session_run(sess, [init_op, ops["zero_op"]]) 418 | for i in range(params.update_cycle): 419 | utils.session_run(sess, ops["collect_op"]) 420 | utils.session_run(sess, ops["scale_op"]) 421 | sess.run(ops["train_op"]) 422 | 423 | 424 | if __name__ == "__main__": 425 | main(parse_args()) 426 | -------------------------------------------------------------------------------- /thumt/bin/translator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2018 The THUMT Authors 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import argparse 10 | import itertools 11 | import os 12 | 13 | import numpy as np 14 | import tensorflow as tf 15 | import thumt.data.dataset as dataset 16 | import thumt.data.vocab as vocabulary 17 | import thumt.models as models 18 | import thumt.utils.inference as inference 19 | import thumt.utils.parallel as parallel 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser( 24 | description="Translate using existing NMT models", 25 | usage="translator.py [] [-h | --help]" 26 | ) 27 | 28 | # input files 29 | parser.add_argument("--input", type=str, required=True, 30 | help="Path of input file") 31 | parser.add_argument("--output", type=str, required=True, 32 | help="Path of output file") 33 | parser.add_argument("--checkpoints", type=str, nargs="+", required=True, 34 | help="Path of trained models") 35 | parser.add_argument("--vocabulary", type=str, nargs=2, required=True, 36 | help="Path of source and target 
vocabulary") 37 | 38 | # model and configuration 39 | parser.add_argument("--models", type=str, required=True, nargs="+", 40 | help="Name of the model") 41 | parser.add_argument("--parameters", type=str, 42 | help="Additional hyper parameters") 43 | parser.add_argument("--verbose", action="store_true", 44 | help="Enable verbose output") 45 | 46 | return parser.parse_args() 47 | 48 | 49 | def default_parameters(): 50 | params = tf.contrib.training.HParams( 51 | input=None, 52 | output=None, 53 | vocabulary=None, 54 | # vocabulary specific 55 | pad="", 56 | bos="", 57 | eos="", 58 | unk="", 59 | mapping=None, 60 | append_eos=False, 61 | # decoding 62 | top_beams=1, 63 | beam_size=4, 64 | decode_alpha=0.6, 65 | decode_length=50, 66 | decode_batch_size=32, 67 | device_list=[0], 68 | num_threads=1 69 | ) 70 | 71 | return params 72 | 73 | 74 | def merge_parameters(params1, params2): 75 | params = tf.contrib.training.HParams() 76 | 77 | for (k, v) in params1.values().iteritems(): 78 | params.add_hparam(k, v) 79 | 80 | params_dict = params.values() 81 | 82 | for (k, v) in params2.values().iteritems(): 83 | if k in params_dict: 84 | # Override 85 | setattr(params, k, v) 86 | else: 87 | params.add_hparam(k, v) 88 | 89 | return params 90 | 91 | 92 | def import_params(model_dir, model_name, params): 93 | if model_name.startswith("experimental_"): 94 | model_name = model_name[13:] 95 | 96 | model_dir = os.path.abspath(model_dir) 97 | m_name = os.path.join(model_dir, model_name + ".json") 98 | 99 | if not tf.gfile.Exists(m_name): 100 | return params 101 | 102 | with tf.gfile.Open(m_name) as fd: 103 | tf.logging.info("Restoring model parameters from %s" % m_name) 104 | json_str = fd.readline() 105 | params.parse_json(json_str) 106 | 107 | return params 108 | 109 | 110 | def override_parameters(params, args): 111 | if args.parameters: 112 | params.parse(args.parameters) 113 | 114 | params.vocabulary = { 115 | "source": vocabulary.load_vocabulary(args.vocabulary[0]), 116 | "target": vocabulary.load_vocabulary(args.vocabulary[1]) 117 | } 118 | params.vocabulary["source"] = vocabulary.process_vocabulary( 119 | params.vocabulary["source"], params 120 | ) 121 | params.vocabulary["target"] = vocabulary.process_vocabulary( 122 | params.vocabulary["target"], params 123 | ) 124 | 125 | control_symbols = [params.pad, params.bos, params.eos, params.unk] 126 | 127 | params.mapping = { 128 | "source": vocabulary.get_control_mapping( 129 | params.vocabulary["source"], 130 | control_symbols 131 | ), 132 | "target": vocabulary.get_control_mapping( 133 | params.vocabulary["target"], 134 | control_symbols 135 | ) 136 | } 137 | 138 | return params 139 | 140 | 141 | def session_config(params): 142 | optimizer_options = tf.OptimizerOptions(opt_level=tf.OptimizerOptions.L1, 143 | do_function_inlining=False) 144 | graph_options = tf.GraphOptions(optimizer_options=optimizer_options) 145 | config = tf.ConfigProto(allow_soft_placement=True, 146 | graph_options=graph_options) 147 | if params.device_list: 148 | device_str = ",".join([str(i) for i in params.device_list]) 149 | config.gpu_options.visible_device_list = device_str 150 | 151 | return config 152 | 153 | 154 | def set_variables(var_list, value_dict, prefix): 155 | ops = [] 156 | for var in var_list: 157 | for name in value_dict: 158 | var_name = "/".join([prefix] + list(name.split("/")[1:])) 159 | 160 | if var.name[:-2] == var_name: 161 | tf.logging.debug("restoring %s -> %s" % (name, var.name)) 162 | with tf.device("/cpu:0"): 163 | op = tf.assign(var, value_dict[name]) 
164 | ops.append(op) 165 | break 166 | 167 | return ops 168 | 169 | 170 | def shard_features(features, placeholders, predictions): 171 | num_shards = len(placeholders) 172 | feed_dict = {} 173 | n = 0 174 | 175 | for name in features: 176 | feat = features[name] 177 | batch = feat.shape[0] 178 | 179 | if batch < num_shards: 180 | feed_dict[placeholders[0][name]] = feat 181 | n = 1 182 | else: 183 | shard_size = (batch + num_shards - 1) // num_shards 184 | 185 | for i in range(num_shards): 186 | shard_feat = feat[i * shard_size:(i + 1) * shard_size] 187 | feed_dict[placeholders[i][name]] = shard_feat 188 | n = num_shards 189 | 190 | return predictions[:n], feed_dict 191 | 192 | 193 | def main(args): 194 | tf.logging.set_verbosity(tf.logging.INFO) 195 | # Load configs 196 | model_cls_list = [models.get_model(model) for model in args.models] 197 | params_list = [default_parameters() for _ in range(len(model_cls_list))] 198 | params_list = [ 199 | merge_parameters(params, model_cls.get_parameters()) 200 | for params, model_cls in zip(params_list, model_cls_list) 201 | ] 202 | params_list = [ 203 | import_params(args.checkpoints[i], args.models[i], params_list[i]) 204 | for i in range(len(args.checkpoints)) 205 | ] 206 | params_list = [ 207 | override_parameters(params_list[i], args) 208 | for i in range(len(model_cls_list)) 209 | ] 210 | 211 | # Build Graph 212 | with tf.Graph().as_default(): 213 | model_var_lists = [] 214 | 215 | # Load checkpoints 216 | for i, checkpoint in enumerate(args.checkpoints): 217 | tf.logging.info("Loading %s" % checkpoint) 218 | var_list = tf.train.list_variables(checkpoint) 219 | values = {} 220 | reader = tf.train.load_checkpoint(checkpoint) 221 | 222 | for (name, shape) in var_list: 223 | if not name.startswith(model_cls_list[i].get_name()): 224 | continue 225 | 226 | if name.find("losses_avg") >= 0: 227 | continue 228 | 229 | tensor = reader.get_tensor(name) 230 | values[name] = tensor 231 | 232 | model_var_lists.append(values) 233 | 234 | # Build models 235 | model_fns = [] 236 | 237 | for i in range(len(args.checkpoints)): 238 | name = model_cls_list[i].get_name() 239 | model = model_cls_list[i](params_list[i], name + "_%d" % i) 240 | model_fn = model.get_inference_func() 241 | model_fns.append(model_fn) 242 | 243 | params = params_list[0] 244 | # Read input file 245 | sorted_keys, sorted_inputs = dataset.sort_input_file(args.input) 246 | # Build input queue 247 | features = dataset.get_inference_input(sorted_inputs, params) 248 | # Create placeholders 249 | placeholders = [] 250 | 251 | for i in range(len(params.device_list)): 252 | placeholders.append({ 253 | "source": tf.placeholder(tf.int32, [None, None], 254 | "source_%d" % i), 255 | "source_length": tf.placeholder(tf.int32, [None], 256 | "source_length_%d" % i) 257 | }) 258 | 259 | # A list of outputs 260 | predictions = parallel.data_parallelism( 261 | params.device_list, 262 | lambda f: inference.create_inference_graph(model_fns, f, params), 263 | placeholders) 264 | 265 | # Create assign ops 266 | assign_ops = [] 267 | 268 | all_var_list = tf.trainable_variables() 269 | 270 | for i in range(len(args.checkpoints)): 271 | un_init_var_list = [] 272 | name = model_cls_list[i].get_name() 273 | 274 | for v in all_var_list: 275 | if v.name.startswith(name + "_%d" % i): 276 | un_init_var_list.append(v) 277 | 278 | ops = set_variables(un_init_var_list, model_var_lists[i], 279 | name + "_%d" % i) 280 | assign_ops.extend(ops) 281 | 282 | assign_op = tf.group(*assign_ops) 283 | results = [] 284 | 285 | # Create 
session 286 | with tf.Session(config=session_config(params)) as sess: 287 | # Restore variables 288 | sess.run(assign_op) 289 | sess.run(tf.tables_initializer()) 290 | 291 | while True: 292 | try: 293 | feats = sess.run(features) 294 | op, feed_dict = shard_features(feats, placeholders, 295 | predictions) 296 | results.append(sess.run(predictions, feed_dict=feed_dict)) 297 | message = "Finished batch %d" % len(results) 298 | tf.logging.log(tf.logging.INFO, message) 299 | except tf.errors.OutOfRangeError: 300 | break 301 | 302 | # Convert to plain text 303 | vocab = params.vocabulary["target"] 304 | outputs = [] 305 | scores = [] 306 | 307 | for result in results: 308 | for item in result[0]: 309 | outputs.append(item.tolist()) 310 | for item in result[1]: 311 | scores.append(item.tolist()) 312 | 313 | outputs = list(itertools.chain(*outputs)) 314 | scores = list(itertools.chain(*scores)) 315 | 316 | restored_inputs = [] 317 | restored_outputs = [] 318 | restored_scores = [] 319 | 320 | for index in range(len(sorted_inputs)): 321 | restored_inputs.append(sorted_inputs[sorted_keys[index]]) 322 | restored_outputs.append(outputs[sorted_keys[index]]) 323 | restored_scores.append(scores[sorted_keys[index]]) 324 | 325 | # Write to file 326 | with open(args.output, "w") as outfile: 327 | count = 0 328 | for outputs, scores in zip(restored_outputs, restored_scores): 329 | for output, score in zip(outputs, scores): 330 | decoded = [] 331 | for idx in output: 332 | if idx == params.mapping["target"][params.eos]: 333 | break 334 | decoded.append(vocab[idx]) 335 | 336 | decoded = " ".join(decoded) 337 | 338 | if not args.verbose: 339 | outfile.write("%s\n" % decoded) 340 | break 341 | else: 342 | pattern = "%d ||| %s ||| %s ||| %f\n" 343 | source = restored_inputs[count] 344 | values = (count, source, decoded, score) 345 | outfile.write(pattern % values) 346 | 347 | count += 1 348 | 349 | 350 | if __name__ == "__main__": 351 | main(parse_args()) 352 | -------------------------------------------------------------------------------- /thumt/bin/translator_ctx.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2018 The THUMT Authors 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import argparse 10 | import itertools 11 | import os 12 | 13 | import numpy as np 14 | import tensorflow as tf 15 | import thumt.data.dataset as dataset 16 | import thumt.data.vocab as vocabulary 17 | import thumt.models as models 18 | import thumt.utils.inference_ctx as inference 19 | import thumt.utils.parallel as parallel 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser( 24 | description="Translate using existing NMT models", 25 | usage="translator.py [] [-h | --help]" 26 | ) 27 | 28 | # input files 29 | parser.add_argument("--input", type=str, required=True, 30 | help="Path of input file") 31 | parser.add_argument("--context", type=str, required=True, 32 | help="Path of context file") 33 | parser.add_argument("--output", type=str, required=True, 34 | help="Path of output file") 35 | parser.add_argument("--checkpoints", type=str, nargs="+", required=True, 36 | help="Path of trained models") 37 | parser.add_argument("--vocabulary", type=str, nargs=2, required=True, 38 | help="Path of source and target vocabulary") 39 | 40 | # model and configuration 41 | parser.add_argument("--models", type=str, required=True, nargs="+", 42 | 
help="Name of the model") 43 | parser.add_argument("--parameters", type=str, 44 | help="Additional hyper parameters") 45 | parser.add_argument("--verbose", action="store_true", 46 | help="Enable verbose output") 47 | 48 | return parser.parse_args() 49 | 50 | 51 | def default_parameters(): 52 | params = tf.contrib.training.HParams( 53 | input=None, 54 | output=None, 55 | vocabulary=None, 56 | # vocabulary specific 57 | pad="", 58 | bos="", 59 | eos="", 60 | unk="", 61 | mapping=None, 62 | append_eos=False, 63 | # decoding 64 | top_beams=1, 65 | beam_size=4, 66 | decode_alpha=0.6, 67 | decode_length=50, 68 | decode_batch_size=32, 69 | device_list=[0], 70 | num_threads=1 71 | ) 72 | 73 | return params 74 | 75 | 76 | def merge_parameters(params1, params2): 77 | params = tf.contrib.training.HParams() 78 | 79 | for (k, v) in params1.values().iteritems(): 80 | params.add_hparam(k, v) 81 | 82 | params_dict = params.values() 83 | 84 | for (k, v) in params2.values().iteritems(): 85 | if k in params_dict: 86 | # Override 87 | setattr(params, k, v) 88 | else: 89 | params.add_hparam(k, v) 90 | 91 | return params 92 | 93 | 94 | def import_params(model_dir, model_name, params): 95 | if model_name.startswith("experimental_"): 96 | model_name = model_name[13:] 97 | 98 | model_dir = os.path.abspath(model_dir) 99 | m_name = os.path.join(model_dir, model_name + ".json") 100 | 101 | if not tf.gfile.Exists(m_name): 102 | return params 103 | 104 | with tf.gfile.Open(m_name) as fd: 105 | tf.logging.info("Restoring model parameters from %s" % m_name) 106 | json_str = fd.readline() 107 | params.parse_json(json_str) 108 | 109 | return params 110 | 111 | 112 | def override_parameters(params, args): 113 | if args.parameters: 114 | params.parse(args.parameters) 115 | 116 | params.vocabulary = { 117 | "source": vocabulary.load_vocabulary(args.vocabulary[0]), 118 | "target": vocabulary.load_vocabulary(args.vocabulary[1]) 119 | } 120 | params.vocabulary["source"] = vocabulary.process_vocabulary( 121 | params.vocabulary["source"], params 122 | ) 123 | params.vocabulary["target"] = vocabulary.process_vocabulary( 124 | params.vocabulary["target"], params 125 | ) 126 | 127 | control_symbols = [params.pad, params.bos, params.eos, params.unk] 128 | 129 | params.mapping = { 130 | "source": vocabulary.get_control_mapping( 131 | params.vocabulary["source"], 132 | control_symbols 133 | ), 134 | "target": vocabulary.get_control_mapping( 135 | params.vocabulary["target"], 136 | control_symbols 137 | ) 138 | } 139 | 140 | return params 141 | 142 | 143 | def session_config(params): 144 | optimizer_options = tf.OptimizerOptions(opt_level=tf.OptimizerOptions.L1, 145 | do_function_inlining=False) 146 | graph_options = tf.GraphOptions(optimizer_options=optimizer_options) 147 | config = tf.ConfigProto(allow_soft_placement=True, 148 | graph_options=graph_options) 149 | if params.device_list: 150 | device_str = ",".join([str(i) for i in params.device_list]) 151 | config.gpu_options.visible_device_list = device_str 152 | 153 | return config 154 | 155 | 156 | def set_variables(var_list, value_dict, prefix): 157 | ops = [] 158 | for var in var_list: 159 | for name in value_dict: 160 | var_name = "/".join([prefix] + list(name.split("/")[1:])) 161 | 162 | if var.name[:-2] == var_name: 163 | tf.logging.debug("restoring %s -> %s" % (name, var.name)) 164 | with tf.device("/cpu:0"): 165 | op = tf.assign(var, value_dict[name]) 166 | ops.append(op) 167 | break 168 | 169 | return ops 170 | 171 | 172 | def shard_features(features, placeholders, 
predictions): 173 | num_shards = len(placeholders) 174 | feed_dict = {} 175 | n = 0 176 | 177 | for name in features: 178 | feat = features[name] 179 | batch = feat.shape[0] 180 | 181 | if batch < num_shards: 182 | feed_dict[placeholders[0][name]] = feat 183 | n = 1 184 | else: 185 | shard_size = (batch + num_shards - 1) // num_shards 186 | 187 | for i in range(num_shards): 188 | shard_feat = feat[i * shard_size:(i + 1) * shard_size] 189 | feed_dict[placeholders[i][name]] = shard_feat 190 | n = num_shards 191 | 192 | return predictions[:n], feed_dict 193 | 194 | 195 | def main(args): 196 | tf.logging.set_verbosity(tf.logging.INFO) 197 | # Load configs 198 | model_cls_list = [models.get_model(model) for model in args.models] 199 | params_list = [default_parameters() for _ in range(len(model_cls_list))] 200 | params_list = [ 201 | merge_parameters(params, model_cls.get_parameters()) 202 | for params, model_cls in zip(params_list, model_cls_list) 203 | ] 204 | params_list = [ 205 | import_params(args.checkpoints[i], args.models[i], params_list[i]) 206 | for i in range(len(args.checkpoints)) 207 | ] 208 | params_list = [ 209 | override_parameters(params_list[i], args) 210 | for i in range(len(model_cls_list)) 211 | ] 212 | 213 | # Build Graph 214 | with tf.Graph().as_default(): 215 | model_var_lists = [] 216 | 217 | # Load checkpoints 218 | for i, checkpoint in enumerate(args.checkpoints): 219 | tf.logging.info("Loading %s" % checkpoint) 220 | var_list = tf.train.list_variables(checkpoint) 221 | values = {} 222 | reader = tf.train.load_checkpoint(checkpoint) 223 | 224 | for (name, shape) in var_list: 225 | if not name.startswith(model_cls_list[i].get_name()): 226 | continue 227 | 228 | if name.find("losses_avg") >= 0: 229 | continue 230 | 231 | tensor = reader.get_tensor(name) 232 | values[name] = tensor 233 | 234 | model_var_lists.append(values) 235 | 236 | # Build models 237 | model_fns = [] 238 | 239 | for i in range(len(args.checkpoints)): 240 | name = model_cls_list[i].get_name() 241 | model = model_cls_list[i](params_list[i], name + "_%d" % i) 242 | model_fn = model.get_inference_func() 243 | model_fns.append(model_fn) 244 | 245 | params = params_list[0] 246 | # Read input file 247 | sorted_keys, sorted_inputs, sorted_ctxs = dataset.sort_input_file_ctx(args.input, args.context) 248 | # Build input queue 249 | features = dataset.get_inference_input(sorted_inputs, params) 250 | features_ctx = dataset.get_inference_input(sorted_ctxs, params) 251 | features["context"] = features_ctx["source"] 252 | features["context_length"] = features_ctx["source_length"] 253 | # Create placeholders 254 | placeholders = [] 255 | 256 | for i in range(len(params.device_list)): 257 | placeholders.append({ 258 | "source": tf.placeholder(tf.int32, [None, None], 259 | "source_%d" % i), 260 | "source_length": tf.placeholder(tf.int32, [None], 261 | "source_length_%d" % i), 262 | "context": tf.placeholder(tf.int32, [None, None], 263 | "context_%d" % i), 264 | "context_length": tf.placeholder(tf.int32, [None], 265 | "context_length_%d" % i) 266 | }) 267 | 268 | # A list of outputs 269 | predictions = parallel.data_parallelism( 270 | params.device_list, 271 | lambda f: inference.create_inference_graph(model_fns, f, params), 272 | placeholders) 273 | 274 | # Create assign ops 275 | assign_ops = [] 276 | 277 | all_var_list = tf.all_variables() 278 | 279 | for i in range(len(args.checkpoints)): 280 | un_init_var_list = [] 281 | name = model_cls_list[i].get_name() 282 | 283 | for v in all_var_list: 284 | if 
v.name.startswith(name + "_%d" % i): 285 | un_init_var_list.append(v) 286 | 287 | ops = set_variables(un_init_var_list, model_var_lists[i], 288 | name + "_%d" % i) 289 | assign_ops.extend(ops) 290 | 291 | assign_op = tf.group(*assign_ops) 292 | results = [] 293 | 294 | # Create session 295 | with tf.Session(config=session_config(params)) as sess: 296 | # Restore variables 297 | sess.run(assign_op) 298 | sess.run(tf.tables_initializer()) 299 | 300 | while True: 301 | try: 302 | feats = sess.run(features) 303 | op, feed_dict = shard_features(feats, placeholders, 304 | predictions) 305 | results.append(sess.run(predictions, feed_dict=feed_dict)) 306 | message = "Finished batch %d" % len(results) 307 | tf.logging.log(tf.logging.INFO, message) 308 | except tf.errors.OutOfRangeError: 309 | break 310 | 311 | # Convert to plain text 312 | vocab = params.vocabulary["target"] 313 | outputs = [] 314 | scores = [] 315 | 316 | for result in results: 317 | for item in result[0]: 318 | outputs.append(item.tolist()) 319 | for item in result[1]: 320 | scores.append(item.tolist()) 321 | 322 | outputs = list(itertools.chain(*outputs)) 323 | scores = list(itertools.chain(*scores)) 324 | 325 | restored_inputs = [] 326 | restored_outputs = [] 327 | restored_scores = [] 328 | 329 | for index in range(len(sorted_inputs)): 330 | restored_inputs.append(sorted_inputs[sorted_keys[index]]) 331 | restored_outputs.append(outputs[sorted_keys[index]]) 332 | restored_scores.append(scores[sorted_keys[index]]) 333 | 334 | # Write to file 335 | with open(args.output, "w") as outfile: 336 | count = 0 337 | for outputs, scores in zip(restored_outputs, restored_scores): 338 | for output, score in zip(outputs, scores): 339 | decoded = [] 340 | for idx in output: 341 | if idx == params.mapping["target"][params.eos]: 342 | break 343 | decoded.append(vocab[idx]) 344 | 345 | decoded = " ".join(decoded) 346 | 347 | if not args.verbose: 348 | outfile.write("%s\n" % decoded) 349 | break 350 | else: 351 | pattern = "%d ||| %s ||| %s ||| %f\n" 352 | source = restored_inputs[count] 353 | values = (count, source, decoded, score) 354 | outfile.write(pattern % values) 355 | 356 | count += 1 357 | 358 | 359 | if __name__ == "__main__": 360 | main(parse_args()) 361 | -------------------------------------------------------------------------------- /thumt/data/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The THUMT Authors 3 | -------------------------------------------------------------------------------- /thumt/data/cache.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import tensorflow as tf 9 | 10 | 11 | def cache_features(features, num_shards): 12 | if num_shards == 1: 13 | return features, tf.no_op(name="init_queue") 14 | 15 | flat_features = list(features.itervalues()) 16 | queue = tf.FIFOQueue(num_shards, dtypes=[v.dtype for v in flat_features]) 17 | flat_features = [tf.split(v, num_shards, axis=0) for v in flat_features] 18 | flat_features = list(zip(*flat_features)) 19 | init_ops = [queue.enqueue(v, name="enqueue_%d" % i) 20 | for i, v in enumerate(flat_features)] 21 | flat_feature = queue.dequeue() 22 | new_features = {} 23 | 24 | for k, v in zip(features.iterkeys(), flat_feature): 25 | v.set_shape(features[k].shape) 26 | 
new_features[k] = v 27 | 28 | return new_features, tf.group(*init_ops) 29 | -------------------------------------------------------------------------------- /thumt/data/record.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Code modified from Tensor2Tensor library 3 | # Copyright 2018 The THUMT Authors 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import math 10 | 11 | import numpy as np 12 | import six 13 | import tensorflow as tf 14 | from tensorflow.contrib.slim import parallel_reader, tfexample_decoder 15 | 16 | 17 | def input_pipeline(file_pattern, mode, capacity=64): 18 | keys_to_features = { 19 | "source": tf.VarLenFeature(tf.int64), 20 | "target": tf.VarLenFeature(tf.int64), 21 | "source_length": tf.FixedLenFeature([1], tf.int64), 22 | "target_length": tf.FixedLenFeature([1], tf.int64) 23 | } 24 | 25 | items_to_handlers = { 26 | "source": tfexample_decoder.Tensor("source"), 27 | "target": tfexample_decoder.Tensor("target"), 28 | "source_length": tfexample_decoder.Tensor("source_length"), 29 | "target_length": tfexample_decoder.Tensor("target_length") 30 | } 31 | 32 | # Now the non-trivial case construction. 33 | with tf.name_scope("examples_queue"): 34 | training = (mode == "train") 35 | # Read serialized examples using slim parallel_reader. 36 | num_epochs = None if training else 1 37 | data_files = parallel_reader.get_data_files(file_pattern) 38 | num_readers = min(4 if training else 1, len(data_files)) 39 | _, examples = parallel_reader.parallel_read([file_pattern], 40 | tf.TFRecordReader, 41 | num_epochs=num_epochs, 42 | shuffle=training, 43 | capacity=2 * capacity, 44 | min_after_dequeue=capacity, 45 | num_readers=num_readers) 46 | 47 | decoder = tfexample_decoder.TFExampleDecoder(keys_to_features, 48 | items_to_handlers) 49 | 50 | decoded = decoder.decode(examples, items=list(items_to_handlers)) 51 | examples = {} 52 | 53 | for (field, tensor) in zip(keys_to_features, decoded): 54 | examples[field] = tensor 55 | 56 | # We do not want int64s as they do are not supported on GPUs. 57 | return {k: tf.to_int32(v) for (k, v) in six.iteritems(examples)} 58 | 59 | 60 | def batch_examples(examples, batch_size, max_length, mantissa_bits, 61 | shard_multiplier=1, length_multiplier=1, scheme="token", 62 | drop_long_sequences=True): 63 | with tf.name_scope("batch_examples"): 64 | max_length = max_length or batch_size 65 | min_length = 8 66 | mantissa_bits = mantissa_bits 67 | 68 | # compute boundaries 69 | x = min_length 70 | boundaries = [] 71 | 72 | while x < max_length: 73 | boundaries.append(x) 74 | x += 2 ** max(0, int(math.log(x, 2)) - mantissa_bits) 75 | 76 | if scheme is "token": 77 | batch_sizes = [max(1, batch_size // length) 78 | for length in boundaries + [max_length]] 79 | batch_sizes = [b * shard_multiplier for b in batch_sizes] 80 | bucket_capacities = [2 * b for b in batch_sizes] 81 | else: 82 | batch_sizes = batch_size * shard_multiplier 83 | bucket_capacities = [2 * n for n in boundaries + [max_length]] 84 | 85 | max_length *= length_multiplier 86 | boundaries = [boundary * length_multiplier for boundary in boundaries] 87 | max_length = max_length if drop_long_sequences else 10 ** 9 88 | 89 | # The queue to bucket on will be chosen based on maximum length. 
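# Illustrative aside (example values assumed, not read from params): with
# min_length = 8 and mantissa_bits = 2 the boundary loop above yields roughly
# exponentially spaced buckets, e.g. 8, 10, 12, 14, 16, 20, 24, 28, 32, 40, ...
# Under the "token" scheme each bucket is then assigned about
# batch_size // boundary examples, so every batch carries a roughly constant
# number of tokens regardless of sentence length.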
90 | max_example_length = 0 91 | for v in examples.values(): 92 | seq_length = tf.shape(v)[0] 93 | max_example_length = tf.maximum(max_example_length, seq_length) 94 | 95 | (_, outputs) = tf.contrib.training.bucket_by_sequence_length( 96 | max_example_length, 97 | examples, 98 | batch_sizes, 99 | [b + 1 for b in boundaries], 100 | capacity=2, 101 | bucket_capacities=bucket_capacities, 102 | dynamic_pad=True, 103 | keep_input=(max_example_length <= max_length) 104 | ) 105 | 106 | return outputs 107 | 108 | 109 | def get_input_features(file_patterns, mode, params): 110 | with tf.name_scope("input_queues"): 111 | with tf.device("/cpu:0"): 112 | if mode != "train": 113 | num_datashards = 1 114 | batch_size = params.eval_batch_size 115 | else: 116 | num_datashards = len(params.device_list) 117 | batch_size = params.batch_size 118 | 119 | batch_size_multiplier = 1 120 | capacity = 64 * num_datashards 121 | examples = input_pipeline(file_patterns, mode, capacity) 122 | drop_long_sequences = (mode == "train") 123 | 124 | feature_map = batch_examples( 125 | examples, 126 | batch_size, 127 | params.max_length, 128 | params.mantissa_bits, 129 | num_datashards, 130 | batch_size_multiplier, 131 | "token" if not params.constant_batch_size else "constant", 132 | drop_long_sequences 133 | ) 134 | 135 | features = { 136 | "source": feature_map["source"], 137 | "target": feature_map["target"], 138 | "source_length": tf.squeeze(feature_map["source_length"], axis=1), 139 | "target_length": tf.squeeze(feature_map["target_length"], axis=1) 140 | } 141 | 142 | return features 143 | -------------------------------------------------------------------------------- /thumt/data/vocab.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import tensorflow as tf 9 | 10 | 11 | def load_vocabulary(filename): 12 | vocab = [] 13 | with tf.gfile.GFile(filename) as fd: 14 | for line in fd: 15 | word = line.strip() 16 | vocab.append(word) 17 | 18 | return vocab 19 | 20 | 21 | def process_vocabulary(vocab, params): 22 | if params.append_eos: 23 | vocab.append(params.eos) 24 | 25 | return vocab 26 | 27 | 28 | def get_control_mapping(vocab, symbols): 29 | mapping = {} 30 | 31 | for i, token in enumerate(vocab): 32 | for symbol in symbols: 33 | if symbol.decode("utf-8") == token.decode("utf-8"): 34 | mapping[symbol] = i 35 | 36 | return mapping 37 | -------------------------------------------------------------------------------- /thumt/interface/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | from thumt.interface.model import NMTModel 9 | -------------------------------------------------------------------------------- /thumt/interface/model.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | 9 | class NMTModel(object): 10 | """ Abstract object representing an NMT model """ 11 | 12 | def __init__(self, params, scope): 13 | self._scope = scope 14 | self._params 
= params 15 | 16 | def get_training_func(self, initializer): 17 | """ 18 | :param initializer: the initializer used to initialize the model 19 | :return: a function with the following signature: 20 | (features, params, reuse) -> loss 21 | """ 22 | raise NotImplementedError("Not implemented") 23 | 24 | def get_evaluation_func(self): 25 | """ 26 | :return: a function with the following signature: 27 | (features, params) -> score 28 | """ 29 | raise NotImplementedError("Not implemented") 30 | 31 | def get_inference_func(self): 32 | """ 33 | :returns: 34 | If a model implements incremental decoding, this function should 35 | returns a tuple of (encoding_fn, decoding_fn), with the following 36 | requirements: 37 | encoding_fn: (features, params) -> initial_state 38 | decoding_fn: (feature, state, params) -> log_prob, next_state 39 | 40 | If a model does not implement the incremental decoding (slower 41 | decoding speed but easier to write the code), then this 42 | function should returns a single function with the following 43 | signature: 44 | (features, params) -> log_prob 45 | 46 | See models/transformer.py and models/rnnsearch.py 47 | for comparison. 48 | """ 49 | raise NotImplementedError("Not implemented") 50 | 51 | @staticmethod 52 | def get_name(): 53 | raise NotImplementedError("Not implemented") 54 | 55 | @staticmethod 56 | def get_parameters(): 57 | raise NotImplementedError("Not implemented") 58 | 59 | @property 60 | def parameters(self): 61 | return self._params 62 | -------------------------------------------------------------------------------- /thumt/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The THUMT Authors 3 | 4 | import thumt.layers.attention 5 | import thumt.layers.nn 6 | import thumt.layers.rnn_cell 7 | -------------------------------------------------------------------------------- /thumt/layers/attention.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import math 9 | import tensorflow as tf 10 | 11 | from thumt.layers.nn import linear 12 | 13 | 14 | def add_timing_signal(x, min_timescale=1.0, max_timescale=1.0e4, name=None): 15 | """ 16 | This function adds a bunch of sinusoids of different frequencies to a 17 | Tensor. See paper: `Attention is all you need' 18 | 19 | :param x: A tensor with shape [batch, length, channels] 20 | :param min_timescale: A floating point number 21 | :param max_timescale: A floating point number 22 | :param name: An optional string 23 | 24 | :returns: a Tensor the same shape as x. 
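    Note: the sin and cos halves are concatenated rather than interleaved;
    the first ``channels // 2'' features hold sin(position * inv_timescale)
    and the remaining features hold the matching cos terms, with the
    timescales spaced geometrically between ``min_timescale'' and
    ``max_timescale''.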
25 | """ 26 | 27 | with tf.name_scope(name, default_name="add_timing_signal", values=[x]): 28 | length = tf.shape(x)[1] 29 | channels = tf.shape(x)[2] 30 | position = tf.to_float(tf.range(length)) 31 | num_timescales = channels // 2 32 | 33 | log_timescale_increment = ( 34 | math.log(float(max_timescale) / float(min_timescale)) / 35 | (tf.to_float(num_timescales) - 1) 36 | ) 37 | inv_timescales = min_timescale * tf.exp( 38 | tf.to_float(tf.range(num_timescales)) * -log_timescale_increment 39 | ) 40 | 41 | scaled_time = (tf.expand_dims(position, 1) * 42 | tf.expand_dims(inv_timescales, 0)) 43 | signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1) 44 | signal = tf.pad(signal, [[0, 0], [0, tf.mod(channels, 2)]]) 45 | signal = tf.reshape(signal, [1, length, channels]) 46 | 47 | return x + signal 48 | 49 | 50 | def split_heads(inputs, num_heads, name=None): 51 | """ Split heads 52 | :param inputs: A tensor with shape [batch, ..., channels] 53 | :param num_heads: An integer 54 | :param name: An optional string 55 | :returns: A tensor with shape [batch, heads, ..., channels / heads] 56 | """ 57 | 58 | with tf.name_scope(name, default_name="split_heads", values=[inputs]): 59 | x = inputs 60 | n = num_heads 61 | old_shape = x.get_shape().dims 62 | ndims = x.shape.ndims 63 | 64 | last = old_shape[-1] 65 | new_shape = old_shape[:-1] + [n] + [last // n if last else None] 66 | ret = tf.reshape(x, tf.concat([tf.shape(x)[:-1], [n, -1]], 0)) 67 | ret.set_shape(new_shape) 68 | perm = [0, ndims - 1] + [i for i in range(1, ndims - 1)] + [ndims] 69 | return tf.transpose(ret, perm) 70 | 71 | 72 | def combine_heads(inputs, name=None): 73 | """ Combine heads 74 | :param inputs: A tensor with shape [batch, heads, length, channels] 75 | :param name: An optional string 76 | :returns: A tensor with shape [batch, length, heads * channels] 77 | """ 78 | 79 | with tf.name_scope(name, default_name="combine_heads", values=[inputs]): 80 | x = inputs 81 | x = tf.transpose(x, [0, 2, 1, 3]) 82 | old_shape = x.get_shape().dims 83 | a, b = old_shape[-2:] 84 | new_shape = old_shape[:-2] + [a * b if a and b else None] 85 | x = tf.reshape(x, tf.concat([tf.shape(x)[:-2], [-1]], 0)) 86 | x.set_shape(new_shape) 87 | 88 | return x 89 | 90 | 91 | def attention_bias(inputs, mode, inf=-1e9, name=None): 92 | """ A bias tensor used in attention mechanism 93 | :param inputs: A tensor 94 | :param mode: one of "causal", "masking", "proximal" or "distance" 95 | :param inf: A floating value 96 | :param name: optional string 97 | :returns: A 4D tensor with shape [batch, heads, queries, memories] 98 | """ 99 | 100 | with tf.name_scope(name, default_name="attention_bias", values=[inputs]): 101 | if mode == "causal": 102 | length = inputs 103 | lower_triangle = tf.matrix_band_part( 104 | tf.ones([length, length]), -1, 0 105 | ) 106 | ret = inf * (1.0 - lower_triangle) 107 | return tf.reshape(ret, [1, 1, length, length]) 108 | elif mode == "masking": 109 | mask = inputs 110 | ret = (1.0 - mask) * inf 111 | return tf.expand_dims(tf.expand_dims(ret, 1), 1) 112 | elif mode == "proximal": 113 | length = inputs 114 | r = tf.to_float(tf.range(length)) 115 | diff = tf.expand_dims(r, 0) - tf.expand_dims(r, 1) 116 | m = tf.expand_dims(tf.expand_dims(-tf.log(1 + tf.abs(diff)), 0), 0) 117 | return m 118 | elif mode == "distance": 119 | length, distance = inputs 120 | distance = tf.where(distance > length, 0, distance) 121 | distance = tf.cast(distance, tf.int64) 122 | lower_triangle = tf.matrix_band_part( 123 | tf.ones([length, length]), -1, 
0 124 | ) 125 | mask_triangle = 1.0 - tf.matrix_band_part( 126 | tf.ones([length, length]), distance - 1, 0 127 | ) 128 | ret = inf * (1.0 - lower_triangle + mask_triangle) 129 | return tf.reshape(ret, [1, 1, length, length]) 130 | else: 131 | raise ValueError("Unknown mode %s" % mode) 132 | 133 | 134 | def attention(query, memories, bias, hidden_size, cache=None, reuse=None, 135 | dtype=None, scope=None): 136 | """ Standard attention layer 137 | 138 | :param query: A tensor with shape [batch, key_size] 139 | :param memories: A tensor with shape [batch, memory_size, key_size] 140 | :param bias: A tensor with shape [batch, memory_size] 141 | :param hidden_size: An integer 142 | :param cache: A dictionary of precomputed value 143 | :param reuse: A boolean value, whether to reuse the scope 144 | :param dtype: An optional instance of tf.DType 145 | :param scope: An optional string, the scope of this layer 146 | :return: A tensor with shape [batch, value_size] and 147 | a Tensor with shape [batch, memory_size] 148 | """ 149 | 150 | with tf.variable_scope(scope or "attention", reuse=reuse, 151 | values=[query, memories, bias], dtype=dtype): 152 | mem_shape = tf.shape(memories) 153 | key_size = memories.get_shape().as_list()[-1] 154 | 155 | if cache is None: 156 | k = tf.reshape(memories, [-1, key_size]) 157 | k = linear(k, hidden_size, False, False, scope="k_transform") 158 | 159 | if query is None: 160 | return {"key": k} 161 | else: 162 | k = cache["key"] 163 | 164 | q = linear(query, hidden_size, False, False, scope="q_transform") 165 | k = tf.reshape(k, [mem_shape[0], mem_shape[1], hidden_size]) 166 | 167 | hidden = tf.tanh(q[:, None, :] + k) 168 | hidden = tf.reshape(hidden, [-1, hidden_size]) 169 | 170 | # Shape: [batch, mem_size, 1] 171 | logits = linear(hidden, 1, False, False, scope="logits") 172 | logits = tf.reshape(logits, [-1, mem_shape[1]]) 173 | 174 | if bias is not None: 175 | logits = logits + bias 176 | 177 | alpha = tf.nn.softmax(logits) 178 | 179 | outputs = { 180 | "value": tf.reduce_sum(alpha[:, :, None] * memories, axis=1), 181 | "weight": alpha 182 | } 183 | 184 | return outputs 185 | 186 | 187 | def additive_attention(queries, keys, values, bias, hidden_size, concat=False, 188 | keep_prob=None, dtype=None, scope=None): 189 | """ Additive attention mechanism. This layer is implemented using a 190 | one layer feed forward neural network 191 | 192 | :param queries: A tensor with shape [batch, heads, length_q, depth_k] 193 | :param keys: A tensor with shape [batch, heads, length_kv, depth_k] 194 | :param values: A tensor with shape [batch, heads, length_kv, depth_v] 195 | :param bias: A tensor 196 | :param hidden_size: An integer 197 | :param concat: A boolean value. If ``concat'' is set to True, then 198 | the computation of attention mechanism is following $tanh(W[q, k])$. 
199 | When ``concat'' is set to False, the computation is following 200 | $tanh(Wq + Vk)$ 201 | :param keep_prob: a scalar in [0, 1] 202 | :param dtype: An optional instance of tf.DType 203 | :param scope: An optional string, the scope of this layer 204 | 205 | :returns: A dict with the following keys: 206 | weights: A tensor with shape [batch, length_q] 207 | outputs: A tensor with shape [batch, length_q, depth_v] 208 | """ 209 | 210 | with tf.variable_scope(scope, default_name="additive_attention", 211 | values=[queries, keys, values, bias], dtype=dtype): 212 | length_q = tf.shape(queries)[2] 213 | length_kv = tf.shape(keys)[2] 214 | q = tf.tile(tf.expand_dims(queries, 3), [1, 1, 1, length_kv, 1]) 215 | k = tf.tile(tf.expand_dims(keys, 2), [1, 1, length_q, 1, 1]) 216 | 217 | if concat: 218 | combined = tf.tanh(linear(tf.concat([q, k], axis=-1), hidden_size, 219 | True, True, name="qk_transform")) 220 | else: 221 | q = linear(queries, hidden_size, True, True, name="q_transform") 222 | k = linear(keys, hidden_size, True, True, name="key_transform") 223 | combined = tf.tanh(q + k) 224 | 225 | # shape: [batch, heads, length_q, length_kv] 226 | logits = tf.squeeze(linear(combined, 1, True, True, name="logits"), 227 | axis=-1) 228 | 229 | if bias is not None: 230 | logits += bias 231 | 232 | weights = tf.nn.softmax(logits, name="attention_weights") 233 | 234 | if keep_prob or keep_prob < 1.0: 235 | weights = tf.nn.dropout(weights, keep_prob) 236 | 237 | outputs = tf.matmul(weights, values) 238 | 239 | return {"weights": weights, "outputs": outputs} 240 | 241 | 242 | def multiplicative_attention(queries, keys, values, bias, keep_prob=None, 243 | name=None): 244 | """ Multiplicative attention mechanism. This layer is implemented using 245 | dot-product operation. 246 | 247 | :param queries: A tensor with shape [batch, heads, length_q, depth_k] 248 | :param keys: A tensor with shape [batch, heads, length_kv, depth_k] 249 | :param values: A tensor with shape [batch, heads, length_kv, depth_v] 250 | :param bias: A tensor 251 | :param keep_prob: a scalar in (0, 1] 252 | :param name: the name of this operation 253 | 254 | :returns: A dict with the following keys: 255 | weights: A tensor with shape [batch, heads, length_q, length_kv] 256 | outputs: A tensor with shape [batch, heads, length_q, depth_v] 257 | """ 258 | 259 | with tf.name_scope(name, default_name="multiplicative_attention", 260 | values=[queries, keys, values, bias]): 261 | # shape: [batch, heads, length_q, length_kv] 262 | logits = tf.matmul(queries, keys, transpose_b=True) 263 | 264 | if bias is not None: 265 | logits += bias 266 | 267 | weights = tf.nn.softmax(logits, name="attention_weights") 268 | 269 | if keep_prob is not None and keep_prob < 1.0: 270 | weights = tf.nn.dropout(weights, keep_prob) 271 | 272 | outputs = tf.matmul(weights, values) 273 | 274 | return {"weights": weights, "outputs": outputs} 275 | 276 | 277 | def multihead_attention(queries, memories, bias, num_heads, key_size, 278 | value_size, output_size, keep_prob=None, output=True, 279 | state=None, dtype=None, scope=None, trainable=True): 280 | """ Multi-head scaled-dot-product attention with input/output 281 | transformations. 
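    Queries (and memories, when provided) are projected into ``num_heads''
    sets of queries, keys and values; scaled dot-product attention is applied
    per head and the heads are concatenated back, optionally followed by an
    output transformation of size ``output_size''.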
282 | 283 | :param queries: A tensor with shape [batch, length_q, depth_q] 284 | :param memories: A tensor with shape [batch, length_m, depth_m] 285 | :param bias: A tensor (see attention_bias) 286 | :param num_heads: An integer dividing key_size and value_size 287 | :param key_size: An integer 288 | :param value_size: An integer 289 | :param output_size: An integer 290 | :param keep_prob: A floating point number in (0, 1] 291 | :param output: Whether to use output transformation 292 | :param state: An optional dictionary used for incremental decoding 293 | :param dtype: An optional instance of tf.DType 294 | :param scope: An optional string 295 | 296 | :returns: A dict with the following keys: 297 | weights: A tensor with shape [batch, heads, length_q, length_kv] 298 | outputs: A tensor with shape [batch, length_q, depth_v] 299 | """ 300 | 301 | if key_size % num_heads != 0: 302 | raise ValueError("Key size (%d) must be divisible by the number of " 303 | "attention heads (%d)." % (key_size, num_heads)) 304 | 305 | if value_size % num_heads != 0: 306 | raise ValueError("Value size (%d) must be divisible by the number of " 307 | "attention heads (%d)." % (value_size, num_heads)) 308 | 309 | with tf.variable_scope(scope, default_name="multihead_attention", 310 | values=[queries, memories], dtype=dtype): 311 | next_state = {} 312 | 313 | if memories is None: 314 | # self attention 315 | size = key_size * 2 + value_size 316 | combined = linear(queries, size, True, True, scope="qkv_transform", trainable=trainable) 317 | q, k, v = tf.split(combined, [key_size, key_size, value_size], 318 | axis=-1) 319 | 320 | if state is not None: 321 | k = tf.concat([state["key"], k], axis=1) 322 | v = tf.concat([state["value"], v], axis=1) 323 | next_state["key"] = k 324 | next_state["value"] = v 325 | else: 326 | q = linear(queries, key_size, True, True, scope="q_transform", trainable=trainable) 327 | combined = linear(memories, key_size + value_size, True, 328 | scope="kv_transform", trainable=trainable) 329 | k, v = tf.split(combined, [key_size, value_size], axis=-1) 330 | 331 | # split heads 332 | q = split_heads(q, num_heads) 333 | k = split_heads(k, num_heads) 334 | v = split_heads(v, num_heads) 335 | 336 | # scale query 337 | key_depth_per_head = key_size // num_heads 338 | q *= key_depth_per_head ** -0.5 339 | 340 | # attention 341 | results = multiplicative_attention(q, k, v, bias, keep_prob) 342 | 343 | # combine heads 344 | weights = results["weights"] 345 | x = combine_heads(results["outputs"]) 346 | 347 | if output: 348 | outputs = linear(x, output_size, True, True, 349 | scope="output_transform", trainable=trainable) 350 | else: 351 | outputs = x 352 | 353 | outputs = {"weights": weights, "outputs": outputs} 354 | 355 | if state is not None: 356 | outputs["state"] = next_state 357 | 358 | return outputs 359 | -------------------------------------------------------------------------------- /thumt/layers/nn.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import tensorflow as tf 9 | 10 | 11 | def linear(inputs, output_size, bias, concat=True, dtype=None, scope=None, trainable=True): 12 | """ 13 | Linear layer 14 | :param inputs: A Tensor or a list of Tensors with shape [batch, input_size] 15 | :param output_size: An integer specify the output size 16 | :param bias: a boolean 
value indicate whether to use bias term 17 | :param concat: a boolean value indicate whether to concatenate all inputs 18 | :param dtype: an instance of tf.DType, the default value is ``tf.float32'' 19 | :param scope: the scope of this layer, the default value is ``linear'' 20 | :returns: a Tensor with shape [batch, output_size] 21 | :raises RuntimeError: raises ``RuntimeError'' when input sizes do not 22 | compatible with each other 23 | """ 24 | 25 | with tf.variable_scope(scope, default_name="linear", values=[inputs]): 26 | if not isinstance(inputs, (list, tuple)): 27 | inputs = [inputs] 28 | 29 | input_size = [item.get_shape()[-1].value for item in inputs] 30 | 31 | if len(inputs) != len(input_size): 32 | raise RuntimeError("inputs and input_size unmatched!") 33 | 34 | output_shape = tf.concat([tf.shape(inputs[0])[:-1], [output_size]], 35 | axis=0) 36 | # Flatten to 2D 37 | inputs = [tf.reshape(inp, [-1, inp.shape[-1].value]) for inp in inputs] 38 | 39 | results = [] 40 | 41 | if concat: 42 | input_size = sum(input_size) 43 | inputs = tf.concat(inputs, 1) 44 | 45 | shape = [input_size, output_size] 46 | matrix = tf.get_variable("matrix", shape, dtype=dtype, trainable=trainable) 47 | results.append(tf.matmul(inputs, matrix)) 48 | else: 49 | for i in range(len(input_size)): 50 | shape = [input_size[i], output_size] 51 | name = "matrix_%d" % i 52 | matrix = tf.get_variable(name, shape, dtype=dtype, trainable=trainable) 53 | results.append(tf.matmul(inputs[i], matrix)) 54 | 55 | output = tf.add_n(results) 56 | 57 | if bias: 58 | shape = [output_size] 59 | bias = tf.get_variable("bias", shape, dtype=dtype, trainable=trainable) 60 | output = tf.nn.bias_add(output, bias) 61 | 62 | output = tf.reshape(output, output_shape) 63 | 64 | return output 65 | 66 | 67 | def maxout(inputs, output_size, maxpart=2, use_bias=True, concat=True, 68 | dtype=None, scope=None): 69 | """ 70 | Maxout layer 71 | :param inputs: see the corresponding description of ``linear'' 72 | :param output_size: see the corresponding description of ``linear'' 73 | :param maxpart: an integer, the default value is 2 74 | :param use_bias: a boolean value indicate whether to use bias term 75 | :param concat: concat all tensors if inputs is a list of tensors 76 | :param dtype: an optional instance of tf.Dtype 77 | :param scope: the scope of this layer, the default value is ``maxout'' 78 | :returns: a Tensor with shape [batch, output_size] 79 | :raises RuntimeError: see the corresponding description of ``linear'' 80 | """ 81 | 82 | candidate = linear(inputs, output_size * maxpart, use_bias, concat, 83 | dtype=dtype, scope=scope or "maxout") 84 | shape = tf.concat([tf.shape(candidate)[:-1], [output_size, maxpart]], 85 | axis=0) 86 | value = tf.reshape(candidate, shape) 87 | output = tf.reduce_max(value, -1) 88 | 89 | return output 90 | 91 | 92 | def layer_norm(inputs, epsilon=1e-6, dtype=None, scope=None, trainable=True): 93 | """ 94 | Layer Normalization 95 | :param inputs: A Tensor of shape [..., channel_size] 96 | :param epsilon: A floating number 97 | :param dtype: An optional instance of tf.DType 98 | :param scope: An optional string 99 | :returns: A Tensor with the same shape as inputs 100 | """ 101 | with tf.variable_scope(scope, default_name="layer_norm", values=[inputs], 102 | dtype=dtype): 103 | channel_size = inputs.get_shape().as_list()[-1] 104 | 105 | scale = tf.get_variable("scale", shape=[channel_size], 106 | initializer=tf.ones_initializer(), trainable=trainable) 107 | 108 | offset = tf.get_variable("offset", 
shape=[channel_size], 109 | initializer=tf.zeros_initializer(), trainable=trainable) 110 | 111 | mean = tf.reduce_mean(inputs, -1, True) 112 | variance = tf.reduce_mean(tf.square(inputs - mean), -1, True) 113 | 114 | norm_inputs = (inputs - mean) * tf.rsqrt(variance + epsilon) 115 | 116 | return norm_inputs * scale + offset 117 | 118 | 119 | def smoothed_softmax_cross_entropy_with_logits(**kwargs): 120 | logits = kwargs.get("logits") 121 | labels = kwargs.get("labels") 122 | smoothing = kwargs.get("smoothing") or 0.0 123 | normalize = kwargs.get("normalize") 124 | scope = kwargs.get("scope") 125 | 126 | if logits is None or labels is None: 127 | raise ValueError("Both logits and labels must be provided") 128 | 129 | with tf.name_scope(scope or "smoothed_softmax_cross_entropy_with_logits", 130 | values=[logits, labels]): 131 | 132 | labels = tf.reshape(labels, [-1]) 133 | 134 | if not smoothing: 135 | ce = tf.nn.sparse_softmax_cross_entropy_with_logits( 136 | logits=logits, 137 | labels=labels 138 | ) 139 | return ce 140 | 141 | # label smoothing 142 | vocab_size = tf.shape(logits)[1] 143 | 144 | n = tf.to_float(vocab_size - 1) 145 | p = 1.0 - smoothing 146 | q = smoothing / n 147 | 148 | soft_targets = tf.one_hot(tf.cast(labels, tf.int32), depth=vocab_size, 149 | on_value=p, off_value=q) 150 | xentropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits, 151 | labels=soft_targets) 152 | 153 | if normalize is False: 154 | return xentropy 155 | 156 | # Normalizing constant is the best cross-entropy value with soft 157 | # targets. We subtract it just for readability, makes no difference on 158 | # learning 159 | normalizing = -(p * tf.log(p) + n * q * tf.log(q + 1e-20)) 160 | 161 | return xentropy - normalizing 162 | -------------------------------------------------------------------------------- /thumt/layers/rnn_cell.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import tensorflow as tf 9 | 10 | from .nn import linear 11 | 12 | 13 | class LegacyGRUCell(tf.nn.rnn_cell.RNNCell): 14 | """ Groundhog's implementation of GRUCell 15 | 16 | :param num_units: int, The number of units in the RNN cell. 17 | :param reuse: (optional) Python boolean describing whether to reuse 18 | variables in an existing scope. If not `True`, and the existing 19 | scope already has the given variables, an error is raised. 
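    The update rule implemented in ``__call__'' is:

        r  = sigmoid(W_r [x, h])
        u  = sigmoid(W_u [x, h])
        c  = W_c [x, r * h] + b
        h' = (1 - u) * h + u * tanh(c)

    Note that the reset and update gates use no bias term and that ``u''
    weights the candidate (many GRU implementations use the opposite
    convention, with the update gate weighting the previous state).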
20 | """ 21 | 22 | def __init__(self, num_units, reuse=None): 23 | super(LegacyGRUCell, self).__init__(_reuse=reuse) 24 | self._num_units = num_units 25 | 26 | def __call__(self, inputs, state, scope=None): 27 | with tf.variable_scope(scope, default_name="gru_cell", 28 | values=[inputs, state]): 29 | if not isinstance(inputs, (list, tuple)): 30 | inputs = [inputs] 31 | 32 | all_inputs = list(inputs) + [state] 33 | r = tf.nn.sigmoid(linear(all_inputs, self._num_units, False, False, 34 | scope="reset_gate")) 35 | u = tf.nn.sigmoid(linear(all_inputs, self._num_units, False, False, 36 | scope="update_gate")) 37 | all_inputs = list(inputs) + [r * state] 38 | c = linear(all_inputs, self._num_units, True, False, 39 | scope="candidate") 40 | 41 | new_state = (1.0 - u) * state + u * tf.tanh(c) 42 | 43 | return new_state, new_state 44 | 45 | @property 46 | def state_size(self): 47 | return self._num_units 48 | 49 | @property 50 | def output_size(self): 51 | return self._num_units 52 | 53 | 54 | class StateToOutputWrapper(tf.nn.rnn_cell.RNNCell): 55 | """ Copy state to the output of RNNCell so that all states can be obtained 56 | when using tf.nn.dynamic_rnn 57 | 58 | :param cell: An instance of tf.nn.rnn_cell.RNNCell 59 | :param reuse: (optional) Python boolean describing whether to reuse 60 | variables in an existing scope. If not `True`, and the existing 61 | scope already has the given variables, an error is raised. 62 | """ 63 | 64 | def __init__(self, cell, reuse=None): 65 | super(StateToOutputWrapper, self).__init__(_reuse=reuse) 66 | self._cell = cell 67 | 68 | def __call__(self, inputs, state, scope=None): 69 | output, new_state = self._cell(inputs, state, scope=scope) 70 | 71 | return (output, new_state), new_state 72 | 73 | @property 74 | def state_size(self): 75 | return self._cell.state_size 76 | 77 | @property 78 | def output_size(self): 79 | return tuple([self._cell.output_size, self.state_size]) 80 | 81 | 82 | class AttentionWrapper(tf.nn.rnn_cell.RNNCell): 83 | """ Wrap an RNNCell with attention mechanism 84 | 85 | :param cell: An instance of tf.nn.rnn_cell.RNNCell 86 | :param memory: A tensor with shape [batch, mem_size, mem_dim] 87 | :param bias: A tensor with shape [batch, mem_size] 88 | :param attention_fn: A callable function with signature 89 | (inputs, state, memory, bias) -> (output, state, weight, value) 90 | :param output_weight: Whether to output attention weights 91 | :param output_value: Whether to output attention values 92 | :param reuse: (optional) Python boolean describing whether to reuse 93 | variables in an existing scope. If not `True`, and the existing 94 | scope already has the given variables, an error is raised. 
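    At each step ``attention_fn'' is invoked first; its outputs
    (cell_inputs, cell_state, weight, value) feed the wrapped cell, and the
    attention weight and/or value are appended to the cell output when
    ``output_weight'' / ``output_value'' are enabled.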
95 | """ 96 | 97 | def __init__(self, cell, memory, bias, attention_fn, output_weight=False, 98 | output_value=False, reuse=None): 99 | super(AttentionWrapper, self).__init__(_reuse=reuse) 100 | memory.shape.assert_has_rank(3) 101 | self._cell = cell 102 | self._memory = memory 103 | self._bias = bias 104 | self._attention_fn = attention_fn 105 | self._output_weight = output_weight 106 | self._output_value = output_value 107 | 108 | def __call__(self, inputs, state, scope=None): 109 | outputs = self._attention_fn(inputs, state, self._memory, self._bias) 110 | cell_inputs, cell_state, weight, value = outputs 111 | cell_output, new_state = self._cell(cell_inputs, cell_state, 112 | scope=scope) 113 | 114 | if not self._output_weight and not self._output_value: 115 | return cell_output, new_state 116 | 117 | new_output = [cell_output] 118 | 119 | if self._output_weight: 120 | new_output.append(weights) 121 | 122 | if self._output_value: 123 | new_output.append(value) 124 | 125 | return tuple(new_output), new_state 126 | 127 | @property 128 | def state_size(self): 129 | return self._cell.state_size 130 | 131 | @property 132 | def output_size(self): 133 | if not self._output_weight and not self._output_value: 134 | return self._cell.output_size 135 | 136 | new_output_size = [self._cell.output_size] 137 | 138 | if self._output_weight: 139 | new_output_size.append(self._memory.shape[1]) 140 | 141 | if self._output_value: 142 | new_output_size.append(self._memory.shape[2].value) 143 | 144 | return tuple(new_output_size) 145 | -------------------------------------------------------------------------------- /thumt/models/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import thumt.models.seq2seq 9 | import thumt.models.rnnsearch 10 | import thumt.models.transformer 11 | import thumt.models.contextual_transformer 12 | 13 | 14 | def get_model(name): 15 | name = name.lower() 16 | 17 | if name == "rnnsearch": 18 | return thumt.models.rnnsearch.RNNsearch 19 | elif name == "seq2seq": 20 | return thumt.models.seq2seq.Seq2Seq 21 | elif name == "transformer": 22 | return thumt.models.transformer.Transformer 23 | elif name == "contextual_transformer": 24 | return thumt.models.contextual_transformer.Contextual_Transformer 25 | else: 26 | raise LookupError("Unknown model %s" % name) 27 | -------------------------------------------------------------------------------- /thumt/models/rnnsearch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import copy 9 | 10 | import tensorflow as tf 11 | import thumt.interface as interface 12 | import thumt.layers as layers 13 | 14 | 15 | def _copy_through(time, length, output, new_output): 16 | copy_cond = (time >= length) 17 | return tf.where(copy_cond, output, new_output) 18 | 19 | 20 | def _gru_encoder(cell, inputs, sequence_length, initial_state, dtype=None): 21 | # Assume that the underlying cell is GRUCell-like 22 | output_size = cell.output_size 23 | dtype = dtype or inputs.dtype 24 | 25 | batch = tf.shape(inputs)[0] 26 | time_steps = tf.shape(inputs)[1] 27 | 28 | zero_output = tf.zeros([batch, output_size], dtype) 
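    # Note on masking: _copy_through(t, length, old, new) keeps ``old'' once
    # t >= length, so for padded positions the emitted output stays
    # zero_output and the hidden state is frozen at its last valid value;
    # padding therefore never leaks into the final encoder states.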
29 | 30 | if initial_state is None: 31 | initial_state = cell.zero_state(batch, dtype) 32 | 33 | input_ta = tf.TensorArray(dtype, time_steps, 34 | tensor_array_name="input_array") 35 | output_ta = tf.TensorArray(dtype, time_steps, 36 | tensor_array_name="output_array") 37 | input_ta = input_ta.unstack(tf.transpose(inputs, [1, 0, 2])) 38 | 39 | def loop_func(t, out_ta, state): 40 | inp_t = input_ta.read(t) 41 | cell_output, new_state = cell(inp_t, state) 42 | cell_output = _copy_through(t, sequence_length, zero_output, 43 | cell_output) 44 | new_state = _copy_through(t, sequence_length, state, new_state) 45 | out_ta = out_ta.write(t, cell_output) 46 | return t + 1, out_ta, new_state 47 | 48 | time = tf.constant(0, dtype=tf.int32, name="time") 49 | loop_vars = (time, output_ta, initial_state) 50 | 51 | outputs = tf.while_loop(lambda t, *_: t < time_steps, loop_func, 52 | loop_vars, parallel_iterations=32, 53 | swap_memory=True) 54 | 55 | output_final_ta = outputs[1] 56 | final_state = outputs[2] 57 | 58 | all_output = output_final_ta.stack() 59 | all_output.set_shape([None, None, output_size]) 60 | all_output = tf.transpose(all_output, [1, 0, 2]) 61 | 62 | return all_output, final_state 63 | 64 | 65 | def _encoder(cell_fw, cell_bw, inputs, sequence_length, dtype=None, 66 | scope=None): 67 | with tf.variable_scope(scope or "encoder", 68 | values=[inputs, sequence_length]): 69 | inputs_fw = inputs 70 | inputs_bw = tf.reverse_sequence(inputs, sequence_length, 71 | batch_axis=0, seq_axis=1) 72 | 73 | with tf.variable_scope("forward"): 74 | output_fw, state_fw = _gru_encoder(cell_fw, inputs_fw, 75 | sequence_length, None, 76 | dtype=dtype) 77 | 78 | with tf.variable_scope("backward"): 79 | output_bw, state_bw = _gru_encoder(cell_bw, inputs_bw, 80 | sequence_length, None, 81 | dtype=dtype) 82 | output_bw = tf.reverse_sequence(output_bw, sequence_length, 83 | batch_axis=0, seq_axis=1) 84 | 85 | results = { 86 | "annotation": tf.concat([output_fw, output_bw], axis=2), 87 | "outputs": { 88 | "forward": output_fw, 89 | "backward": output_bw 90 | }, 91 | "final_states": { 92 | "forward": state_fw, 93 | "backward": state_bw 94 | } 95 | } 96 | 97 | return results 98 | 99 | 100 | def _decoder(cell, inputs, memory, sequence_length, initial_state, dtype=None, 101 | scope=None): 102 | # Assume that the underlying cell is GRUCell-like 103 | batch = tf.shape(inputs)[0] 104 | time_steps = tf.shape(inputs)[1] 105 | dtype = dtype or inputs.dtype 106 | output_size = cell.output_size 107 | zero_output = tf.zeros([batch, output_size], dtype) 108 | zero_value = tf.zeros([batch, memory.shape[-1].value], dtype) 109 | 110 | with tf.variable_scope(scope or "decoder", dtype=dtype): 111 | inputs = tf.transpose(inputs, [1, 0, 2]) 112 | mem_mask = tf.sequence_mask(sequence_length["source"], 113 | maxlen=tf.shape(memory)[1], 114 | dtype=tf.float32) 115 | bias = layers.attention.attention_bias(mem_mask, "masking") 116 | bias = tf.squeeze(bias, axis=[1, 2]) 117 | cache = layers.attention.attention(None, memory, None, output_size) 118 | 119 | input_ta = tf.TensorArray(tf.float32, time_steps, 120 | tensor_array_name="input_array") 121 | output_ta = tf.TensorArray(tf.float32, time_steps, 122 | tensor_array_name="output_array") 123 | value_ta = tf.TensorArray(tf.float32, time_steps, 124 | tensor_array_name="value_array") 125 | alpha_ta = tf.TensorArray(tf.float32, time_steps, 126 | tensor_array_name="alpha_array") 127 | input_ta = input_ta.unstack(inputs) 128 | initial_state = layers.nn.linear(initial_state, output_size, True, 129 
| False, scope="s_transform") 130 | initial_state = tf.tanh(initial_state) 131 | 132 | def loop_func(t, out_ta, att_ta, val_ta, state, cache_key): 133 | inp_t = input_ta.read(t) 134 | results = layers.attention.attention(state, memory, bias, 135 | output_size, 136 | cache={"key": cache_key}) 137 | alpha = results["weight"] 138 | context = results["value"] 139 | cell_input = [inp_t, context] 140 | cell_output, new_state = cell(cell_input, state) 141 | cell_output = _copy_through(t, sequence_length["target"], 142 | zero_output, cell_output) 143 | new_state = _copy_through(t, sequence_length["target"], state, 144 | new_state) 145 | new_value = _copy_through(t, sequence_length["target"], zero_value, 146 | context) 147 | 148 | out_ta = out_ta.write(t, cell_output) 149 | att_ta = att_ta.write(t, alpha) 150 | val_ta = val_ta.write(t, new_value) 151 | cache_key = tf.identity(cache_key) 152 | return t + 1, out_ta, att_ta, val_ta, new_state, cache_key 153 | 154 | time = tf.constant(0, dtype=tf.int32, name="time") 155 | loop_vars = (time, output_ta, alpha_ta, value_ta, initial_state, 156 | cache["key"]) 157 | 158 | outputs = tf.while_loop(lambda t, *_: t < time_steps, 159 | loop_func, loop_vars, 160 | parallel_iterations=32, 161 | swap_memory=True) 162 | 163 | output_final_ta = outputs[1] 164 | value_final_ta = outputs[3] 165 | 166 | final_output = output_final_ta.stack() 167 | final_output.set_shape([None, None, output_size]) 168 | final_output = tf.transpose(final_output, [1, 0, 2]) 169 | 170 | final_value = value_final_ta.stack() 171 | final_value.set_shape([None, None, memory.shape[-1].value]) 172 | final_value = tf.transpose(final_value, [1, 0, 2]) 173 | 174 | result = { 175 | "outputs": final_output, 176 | "values": final_value, 177 | "initial_state": initial_state 178 | } 179 | 180 | return result 181 | 182 | 183 | def model_graph(features, mode, params): 184 | src_vocab_size = len(params.vocabulary["source"]) 185 | tgt_vocab_size = len(params.vocabulary["target"]) 186 | 187 | with tf.variable_scope("source_embedding"): 188 | src_emb = tf.get_variable("embedding", 189 | [src_vocab_size, params.embedding_size]) 190 | src_bias = tf.get_variable("bias", [params.embedding_size]) 191 | src_inputs = tf.nn.embedding_lookup(src_emb, features["source"]) 192 | 193 | with tf.variable_scope("target_embedding"): 194 | tgt_emb = tf.get_variable("embedding", 195 | [tgt_vocab_size, params.embedding_size]) 196 | tgt_bias = tf.get_variable("bias", [params.embedding_size]) 197 | tgt_inputs = tf.nn.embedding_lookup(tgt_emb, features["target"]) 198 | 199 | src_inputs = tf.nn.bias_add(src_inputs, src_bias) 200 | tgt_inputs = tf.nn.bias_add(tgt_inputs, tgt_bias) 201 | 202 | if params.dropout and not params.use_variational_dropout: 203 | src_inputs = tf.nn.dropout(src_inputs, 1.0 - params.dropout) 204 | tgt_inputs = tf.nn.dropout(tgt_inputs, 1.0 - params.dropout) 205 | 206 | # encoder 207 | cell_fw = layers.rnn_cell.LegacyGRUCell(params.hidden_size) 208 | cell_bw = layers.rnn_cell.LegacyGRUCell(params.hidden_size) 209 | 210 | if params.use_variational_dropout: 211 | cell_fw = tf.nn.rnn_cell.DropoutWrapper( 212 | cell_fw, 213 | input_keep_prob=1.0 - params.dropout, 214 | output_keep_prob=1.0 - params.dropout, 215 | state_keep_prob=1.0 - params.dropout, 216 | variational_recurrent=True, 217 | input_size=params.embedding_size, 218 | dtype=tf.float32 219 | ) 220 | cell_bw = tf.nn.rnn_cell.DropoutWrapper( 221 | cell_bw, 222 | input_keep_prob=1.0 - params.dropout, 223 | output_keep_prob=1.0 - params.dropout, 224 | 
state_keep_prob=1.0 - params.dropout, 225 | variational_recurrent=True, 226 | input_size=params.embedding_size, 227 | dtype=tf.float32 228 | ) 229 | 230 | encoder_output = _encoder(cell_fw, cell_bw, src_inputs, 231 | features["source_length"]) 232 | 233 | # decoder 234 | cell = layers.rnn_cell.LegacyGRUCell(params.hidden_size) 235 | 236 | if params.use_variational_dropout: 237 | cell = tf.nn.rnn_cell.DropoutWrapper( 238 | cell, 239 | input_keep_prob=1.0 - params.dropout, 240 | output_keep_prob=1.0 - params.dropout, 241 | state_keep_prob=1.0 - params.dropout, 242 | variational_recurrent=True, 243 | # input + context 244 | input_size=params.embedding_size + 2 * params.hidden_size, 245 | dtype=tf.float32 246 | ) 247 | 248 | length = { 249 | "source": features["source_length"], 250 | "target": features["target_length"] 251 | } 252 | initial_state = encoder_output["final_states"]["backward"] 253 | decoder_output = _decoder(cell, tgt_inputs, encoder_output["annotation"], 254 | length, initial_state) 255 | 256 | # Shift left 257 | shifted_tgt_inputs = tf.pad(tgt_inputs, [[0, 0], [1, 0], [0, 0]]) 258 | shifted_tgt_inputs = shifted_tgt_inputs[:, :-1, :] 259 | 260 | all_outputs = tf.concat( 261 | [ 262 | tf.expand_dims(decoder_output["initial_state"], axis=1), 263 | decoder_output["outputs"], 264 | ], 265 | axis=1 266 | ) 267 | shifted_outputs = all_outputs[:, :-1, :] 268 | 269 | maxout_features = [ 270 | shifted_tgt_inputs, 271 | shifted_outputs, 272 | decoder_output["values"] 273 | ] 274 | maxout_size = params.hidden_size // params.maxnum 275 | 276 | if mode is "infer": 277 | # Special case for non-incremental decoding 278 | maxout_features = [ 279 | shifted_tgt_inputs[:, -1, :], 280 | shifted_outputs[:, -1, :], 281 | decoder_output["values"][:, -1, :] 282 | ] 283 | maxhid = layers.nn.maxout(maxout_features, maxout_size, params.maxnum, 284 | concat=False) 285 | readout = layers.nn.linear(maxhid, params.embedding_size, False, 286 | False, scope="deepout") 287 | 288 | # Prediction 289 | logits = layers.nn.linear(readout, tgt_vocab_size, True, False, 290 | scope="softmax") 291 | 292 | return tf.nn.log_softmax(logits) 293 | 294 | maxhid = layers.nn.maxout(maxout_features, maxout_size, params.maxnum, 295 | concat=False) 296 | readout = layers.nn.linear(maxhid, params.embedding_size, False, False, 297 | scope="deepout") 298 | 299 | if params.dropout and not params.use_variational_dropout: 300 | readout = tf.nn.dropout(readout, 1.0 - params.dropout) 301 | 302 | # Prediction 303 | logits = layers.nn.linear(readout, tgt_vocab_size, True, False, 304 | scope="softmax") 305 | logits = tf.reshape(logits, [-1, tgt_vocab_size]) 306 | labels = features["target"] 307 | 308 | ce = layers.nn.smoothed_softmax_cross_entropy_with_logits( 309 | logits=logits, 310 | labels=labels, 311 | smoothing=params.label_smoothing, 312 | normalize=True 313 | ) 314 | 315 | ce = tf.reshape(ce, tf.shape(labels)) 316 | tgt_mask = tf.to_float( 317 | tf.sequence_mask( 318 | features["target_length"], 319 | maxlen=tf.shape(features["target"])[1] 320 | ) 321 | ) 322 | 323 | if mode == "eval": 324 | return -tf.reduce_sum(ce * tgt_mask, axis=1) 325 | 326 | loss = tf.reduce_sum(ce * tgt_mask) / tf.reduce_sum(tgt_mask) 327 | 328 | return loss 329 | 330 | 331 | class RNNsearch(interface.NMTModel): 332 | 333 | def __init__(self, params, scope="rnnsearch"): 334 | super(RNNsearch, self).__init__(params=params, scope=scope) 335 | 336 | def get_training_func(self, initializer): 337 | def training_fn(features, params=None, reuse=None): 338 | if params 
is None: 339 | params = self.parameters 340 | with tf.variable_scope(self._scope, initializer=initializer, 341 | reuse=reuse): 342 | loss = model_graph(features, "train", params) 343 | return loss 344 | 345 | return training_fn 346 | 347 | def get_evaluation_func(self): 348 | def evaluation_fn(features, params=None): 349 | if params is None: 350 | params = copy.copy(self.parameters) 351 | else: 352 | params = copy.copy(params) 353 | 354 | params.dropout = 0.0 355 | params.use_variational_dropout = False 356 | params.label_smoothing = 0.0 357 | 358 | with tf.variable_scope(self._scope): 359 | score = model_graph(features, "eval", params) 360 | 361 | return score 362 | 363 | return evaluation_fn 364 | 365 | def get_inference_func(self): 366 | def inference_fn(features, params=None): 367 | if params is None: 368 | params = copy.copy(self.parameters) 369 | else: 370 | params = copy.copy(params) 371 | 372 | params.dropout = 0.0 373 | params.use_variational_dropout = False 374 | params.label_smoothing = 0.0 375 | 376 | with tf.variable_scope(self._scope): 377 | log_prob = model_graph(features, "infer", params) 378 | 379 | return log_prob 380 | 381 | return inference_fn 382 | 383 | @staticmethod 384 | def get_name(): 385 | return "rnnsearch" 386 | 387 | @staticmethod 388 | def get_parameters(): 389 | params = tf.contrib.training.HParams( 390 | # vocabulary 391 | pad="", 392 | unk="", 393 | eos="", 394 | bos="", 395 | append_eos=False, 396 | # model 397 | rnn_cell="LegacyGRUCell", 398 | embedding_size=620, 399 | hidden_size=1000, 400 | maxnum=2, 401 | # regularization 402 | dropout=0.2, 403 | use_variational_dropout=False, 404 | label_smoothing=0.1, 405 | constant_batch_size=True, 406 | batch_size=128, 407 | max_length=60, 408 | clip_grad_norm=5.0 409 | ) 410 | 411 | return params 412 | -------------------------------------------------------------------------------- /thumt/models/seq2seq.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import copy 9 | 10 | import tensorflow as tf 11 | import thumt.interface as interface 12 | import thumt.layers as layers 13 | 14 | 15 | def model_graph(features, mode, params): 16 | src_vocab_size = len(params.vocabulary["source"]) 17 | tgt_vocab_size = len(params.vocabulary["target"]) 18 | 19 | src_seq = features["source"] 20 | tgt_seq = features["target"] 21 | 22 | if params.reverse_source: 23 | src_seq = tf.reverse_sequence(src_seq, seq_dim=1, 24 | seq_lengths=features["source_length"]) 25 | 26 | with tf.device("/cpu:0"): 27 | with tf.variable_scope("source_embedding"): 28 | src_emb = tf.get_variable("embedding", 29 | [src_vocab_size, params.embedding_size]) 30 | src_bias = tf.get_variable("bias", [params.embedding_size]) 31 | src_inputs = tf.nn.embedding_lookup(src_emb, src_seq) 32 | 33 | with tf.variable_scope("target_embedding"): 34 | tgt_emb = tf.get_variable("embedding", 35 | [tgt_vocab_size, params.embedding_size]) 36 | tgt_bias = tf.get_variable("bias", [params.embedding_size]) 37 | tgt_inputs = tf.nn.embedding_lookup(tgt_emb, tgt_seq) 38 | 39 | src_inputs = tf.nn.bias_add(src_inputs, src_bias) 40 | tgt_inputs = tf.nn.bias_add(tgt_inputs, tgt_bias) 41 | 42 | if params.dropout and not params.use_variational_dropout: 43 | src_inputs = tf.nn.dropout(src_inputs, 1.0 - params.dropout) 44 | tgt_inputs = tf.nn.dropout(tgt_inputs, 1.0 - 
params.dropout) 45 | 46 | cell_enc = [] 47 | cell_dec = [] 48 | 49 | for _ in range(params.num_hidden_layers): 50 | if params.rnn_cell == "LSTMCell": 51 | cell_e = tf.nn.rnn_cell.BasicLSTMCell(params.hidden_size) 52 | cell_d = tf.nn.rnn_cell.BasicLSTMCell(params.hidden_size) 53 | elif params.rnn_cell == "GRUCell": 54 | cell_e = tf.nn.rnn_cell.GRUCell(params.hidden_size) 55 | cell_d = tf.nn.rnn_cell.GRUCell(params.hidden_size) 56 | else: 57 | raise ValueError("%s not supported" % params.rnn_cell) 58 | 59 | cell_e = tf.nn.rnn_cell.DropoutWrapper( 60 | cell_e, 61 | output_keep_prob=1.0 - params.dropout, 62 | variational_recurrent=params.use_variational_dropout, 63 | input_size=params.embedding_size, 64 | dtype=tf.float32 65 | ) 66 | cell_d = tf.nn.rnn_cell.DropoutWrapper( 67 | cell_d, 68 | output_keep_prob=1.0 - params.dropout, 69 | variational_recurrent=params.use_variational_dropout, 70 | input_size=params.embedding_size, 71 | dtype=tf.float32 72 | ) 73 | 74 | if params.use_residual: 75 | cell_e = tf.nn.rnn_cell.ResidualWrapper(cell_e) 76 | cell_d = tf.nn.rnn_cell.ResidualWrapper(cell_d) 77 | 78 | cell_enc.append(cell_e) 79 | cell_dec.append(cell_d) 80 | 81 | cell_enc = tf.nn.rnn_cell.MultiRNNCell(cell_enc) 82 | cell_dec = tf.nn.rnn_cell.MultiRNNCell(cell_dec) 83 | 84 | with tf.variable_scope("encoder"): 85 | _, final_state = tf.nn.dynamic_rnn(cell_enc, src_inputs, 86 | features["source_length"], 87 | dtype=tf.float32) 88 | # Shift left 89 | shifted_tgt_inputs = tf.pad(tgt_inputs, [[0, 0], [1, 0], [0, 0]]) 90 | shifted_tgt_inputs = shifted_tgt_inputs[:, :-1, :] 91 | 92 | with tf.variable_scope("decoder"): 93 | outputs, _ = tf.nn.dynamic_rnn(cell_dec, shifted_tgt_inputs, 94 | features["target_length"], 95 | initial_state=final_state) 96 | 97 | if params.dropout: 98 | outputs = tf.nn.dropout(outputs, 1.0 - params.dropout) 99 | 100 | if mode == "infer": 101 | # Prediction 102 | logits = layers.nn.linear(outputs[:, -1, :], tgt_vocab_size, True, 103 | scope="softmax") 104 | 105 | return tf.nn.log_softmax(logits) 106 | 107 | # Prediction 108 | logits = layers.nn.linear(outputs, tgt_vocab_size, True, scope="softmax") 109 | logits = tf.reshape(logits, [-1, tgt_vocab_size]) 110 | labels = features["target"] 111 | 112 | ce = layers.nn.smoothed_softmax_cross_entropy_with_logits( 113 | logits=logits, 114 | labels=labels, 115 | smoothing=params.label_smoothing, 116 | normalize=True 117 | ) 118 | 119 | ce = tf.reshape(ce, tf.shape(labels)) 120 | tgt_mask = tf.to_float( 121 | tf.sequence_mask( 122 | features["target_length"], 123 | maxlen=tf.shape(features["target"])[1] 124 | ) 125 | ) 126 | 127 | if mode == "eval": 128 | return -tf.reduce_sum(ce * tgt_mask, axis=1) 129 | 130 | loss = tf.reduce_sum(ce * tgt_mask) / tf.reduce_sum(tgt_mask) 131 | 132 | return loss 133 | 134 | 135 | class Seq2Seq(interface.NMTModel): 136 | 137 | def __init__(self, params, scope="seq2seq"): 138 | super(Seq2Seq, self).__init__(params=params, scope=scope) 139 | 140 | def get_training_func(self, initializer): 141 | def training_fn(features, params=None, reuse=None): 142 | if params is None: 143 | params = self.parameters 144 | with tf.variable_scope(self._scope, initializer=initializer, 145 | reuse=reuse): 146 | loss = model_graph(features, "train", params) 147 | return loss 148 | 149 | return training_fn 150 | 151 | def get_evaluation_func(self): 152 | def evaluation_fn(features, params=None): 153 | if params is None: 154 | params = copy.copy(self.parameters) 155 | else: 156 | params = copy.copy(params) 157 | params.dropout = 
0.0 158 | params.use_variational_dropout = False 159 | params.label_smoothing = 0.0 160 | 161 | with tf.variable_scope(self._scope): 162 | score = model_graph(features, "eval", params) 163 | 164 | return score 165 | 166 | return evaluation_fn 167 | 168 | def get_inference_func(self): 169 | def inference_fn(features, params=None): 170 | if params is None: 171 | params = copy.copy(self.parameters) 172 | else: 173 | params = copy.copy(params) 174 | params.dropout = 0.0 175 | params.use_variational_dropout = False 176 | params.label_smoothing = 0.0 177 | 178 | with tf.variable_scope(self._scope): 179 | logits = model_graph(features, "infer", params) 180 | 181 | return logits 182 | 183 | return inference_fn 184 | 185 | @staticmethod 186 | def get_name(): 187 | return "seq2seq" 188 | 189 | @staticmethod 190 | def get_parameters(): 191 | params = tf.contrib.training.HParams( 192 | # vocabulary 193 | pad="", 194 | bos="", 195 | eos="", 196 | unk="", 197 | append_eos=False, 198 | # model 199 | rnn_cell="LSTMCell", 200 | embedding_size=1000, 201 | hidden_size=1000, 202 | num_hidden_layers=4, 203 | # regularization 204 | dropout=0.2, 205 | use_variational_dropout=False, 206 | label_smoothing=0.1, 207 | constant_batch_size=True, 208 | batch_size=128, 209 | max_length=80, 210 | reverse_source=True, 211 | use_residual=True, 212 | clip_grad_norm=5.0 213 | ) 214 | 215 | return params 216 | -------------------------------------------------------------------------------- /thumt/scripts/build_vocab.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2018 The THUMT Authors 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import argparse 10 | import collections 11 | 12 | 13 | def count_words(filename): 14 | counter = collections.Counter() 15 | 16 | with open(filename, "r") as fd: 17 | for line in fd: 18 | words = line.strip().split() 19 | counter.update(words) 20 | 21 | count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0])) 22 | words, counts = list(zip(*count_pairs)) 23 | 24 | return words, counts 25 | 26 | 27 | def control_symbols(string): 28 | if not string: 29 | return [] 30 | else: 31 | return string.strip().split(",") 32 | 33 | 34 | def save_vocab(name, vocab): 35 | if name.split(".")[-1] != "txt": 36 | name = name + ".txt" 37 | 38 | pairs = sorted(vocab.items(), key=lambda x: (x[1], x[0])) 39 | words, ids = list(zip(*pairs)) 40 | 41 | with open(name, "w") as f: 42 | for word in words: 43 | f.write(word + "\n") 44 | 45 | 46 | def parse_args(): 47 | parser = argparse.ArgumentParser(description="Create vocabulary") 48 | 49 | parser.add_argument("corpus", help="input corpus") 50 | parser.add_argument("output", default="vocab.txt", 51 | help="Output vocabulary name") 52 | parser.add_argument("--limit", default=0, type=int, help="Vocabulary size") 53 | parser.add_argument("--control", type=str, default=",,", 54 | help="Add control symbols to vocabulary. 
" 55 | "Control symbols are separated by comma.") 56 | 57 | return parser.parse_args() 58 | 59 | 60 | def main(args): 61 | vocab = {} 62 | limit = args.limit 63 | count = 0 64 | 65 | words, counts = count_words(args.corpus) 66 | ctrl_symbols = control_symbols(args.control) 67 | 68 | for sym in ctrl_symbols: 69 | vocab[sym] = len(vocab) 70 | 71 | for word, freq in zip(words, counts): 72 | if limit and len(vocab) >= limit: 73 | break 74 | 75 | if word in vocab: 76 | print("Warning: found duplicate token %s, ignored" % word) 77 | continue 78 | 79 | vocab[word] = len(vocab) 80 | count += freq 81 | 82 | save_vocab(args.output, vocab) 83 | 84 | print("Total words: %d" % sum(counts)) 85 | print("Unique words: %d" % len(words)) 86 | print("Vocabulary coverage: %4.2f%%" % (100.0 * count / sum(counts))) 87 | 88 | 89 | if __name__ == "__main__": 90 | main(parse_args()) 91 | -------------------------------------------------------------------------------- /thumt/scripts/change.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2018 The THUMT Authors 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import argparse 10 | import operator 11 | import os 12 | 13 | import numpy as np 14 | import tensorflow as tf 15 | 16 | 17 | def parseargs(): 18 | msg = "Average checkpoints" 19 | usage = "average.py [] [-h | --help]" 20 | parser = argparse.ArgumentParser(description=msg, usage=usage) 21 | 22 | parser.add_argument("--model", type=str, required=True, 23 | help="checkpoint dir") 24 | parser.add_argument("--output", type=str, help="output path") 25 | 26 | return parser.parse_args() 27 | 28 | def main(_): 29 | tf.logging.set_verbosity(tf.logging.INFO) 30 | 31 | var_list = tf.contrib.framework.list_variables(FLAGS.model) 32 | var_values, var_dtypes = {}, {} 33 | model_from = "transformer" 34 | model_to = "contextual_transformer" 35 | 36 | for (name, shape) in var_list: 37 | if True:#not name.startswith("global_step") and not 'Adam' in name: 38 | name = name.replace(model_from, model_to) 39 | var_values[name] = np.zeros(shape) 40 | print(name) 41 | 42 | reader = tf.contrib.framework.load_checkpoint(FLAGS.model) 43 | for name in var_values: 44 | name_ori = name.replace(model_to, model_from) 45 | tensor = reader.get_tensor(name_ori) 46 | var_dtypes[name] = tensor.dtype 47 | var_values[name] += tensor 48 | tf.logging.info("Read from %s", FLAGS.model) 49 | 50 | tf_vars = [ 51 | tf.get_variable(name, shape=var_values[name].shape, 52 | dtype=var_dtypes[name]) for name in var_values 53 | ] 54 | placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars] 55 | assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)] 56 | global_step = tf.Variable(0, name="global_step", trainable=False, 57 | dtype=tf.int64) 58 | saver = tf.train.Saver(tf.global_variables()) 59 | 60 | with tf.Session() as sess: 61 | sess.run(tf.global_variables_initializer()) 62 | for p, assign_op, (name, value) in zip(placeholders, assign_ops, 63 | var_values.iteritems()): 64 | sess.run(assign_op, {p: value}) 65 | saved_name = os.path.join(FLAGS.output, "new") 66 | saver.save(sess, saved_name, global_step=global_step) 67 | 68 | tf.logging.info("Averaged checkpoints saved in %s", saved_name) 69 | 70 | if __name__ == "__main__": 71 | FLAGS = parseargs() 72 | tf.app.run() 73 | -------------------------------------------------------------------------------- 
/thumt/scripts/check_param.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2018 The THUMT Authors 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import argparse 10 | import operator 11 | import os 12 | 13 | import numpy as np 14 | import tensorflow as tf 15 | 16 | 17 | def parseargs(): 18 | msg = "Average checkpoints" 19 | usage = "average.py [] [-h | --help]" 20 | parser = argparse.ArgumentParser(description=msg, usage=usage) 21 | 22 | parser.add_argument("--model", type=str, required=True, 23 | help="checkpoint dir") 24 | 25 | return parser.parse_args() 26 | 27 | def main(_): 28 | tf.logging.set_verbosity(tf.logging.INFO) 29 | 30 | var_list = tf.contrib.framework.list_variables(FLAGS.model) 31 | var_values, var_dtypes = {}, {} 32 | model_from = "transformer_cov" 33 | model_to = "transformer_lrp" 34 | 35 | count = 0 36 | for (name, shape) in var_list: 37 | if True:#not name.startswith("global_step") and not 'Adam' in name: 38 | count += 1 39 | print(name, shape) 40 | name = name.replace(model_from, model_to) 41 | var_values[name] = np.zeros(shape) 42 | print(len(var_list)) 43 | print(count) 44 | 45 | 46 | if __name__ == "__main__": 47 | FLAGS = parseargs() 48 | tf.app.run() 49 | -------------------------------------------------------------------------------- /thumt/scripts/checkpoint_averaging.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2018 The THUMT Authors 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import argparse 10 | import operator 11 | import os 12 | 13 | import numpy as np 14 | import tensorflow as tf 15 | 16 | 17 | def parseargs(): 18 | msg = "Average checkpoints" 19 | usage = "average.py [] [-h | --help]" 20 | parser = argparse.ArgumentParser(description=msg, usage=usage) 21 | 22 | parser.add_argument("--path", type=str, required=True, 23 | help="checkpoint dir") 24 | parser.add_argument("--checkpoints", type=int, required=True, 25 | help="number of checkpoints to use") 26 | parser.add_argument("--output", type=str, help="output path") 27 | 28 | return parser.parse_args() 29 | 30 | 31 | def get_checkpoints(path): 32 | if not tf.gfile.Exists(os.path.join(path, "checkpoint")): 33 | raise ValueError("Cannot find checkpoints in %s" % path) 34 | 35 | checkpoint_names = [] 36 | 37 | with tf.gfile.GFile(os.path.join(path, "checkpoint")) as fd: 38 | # Skip the first line 39 | fd.readline() 40 | for line in fd: 41 | name = line.strip().split(":")[-1].strip()[1:-1] 42 | key = int(name.split("-")[-1]) 43 | checkpoint_names.append((key, os.path.join(path, name))) 44 | 45 | sorted_names = sorted(checkpoint_names, key=operator.itemgetter(0), 46 | reverse=True) 47 | 48 | return [item[-1] for item in sorted_names] 49 | 50 | 51 | def checkpoint_exists(path): 52 | return (tf.gfile.Exists(path) or tf.gfile.Exists(path + ".meta") or 53 | tf.gfile.Exists(path + ".index")) 54 | 55 | 56 | def main(_): 57 | tf.logging.set_verbosity(tf.logging.INFO) 58 | checkpoints = get_checkpoints(FLAGS.path) 59 | checkpoints = checkpoints[:FLAGS.checkpoints] 60 | 61 | if not checkpoints: 62 | raise ValueError("No checkpoints provided for averaging.") 63 | 64 | checkpoints = [c for c in checkpoints if checkpoint_exists(c)] 65 | 66 | if not 
checkpoints: 67 | raise ValueError( 68 | "None of the provided checkpoints exist. %s" % FLAGS.checkpoints 69 | ) 70 | 71 | var_list = tf.contrib.framework.list_variables(checkpoints[0]) 72 | var_values, var_dtypes = {}, {} 73 | 74 | for (name, shape) in var_list: 75 | if not name.startswith("global_step"): 76 | var_values[name] = np.zeros(shape) 77 | 78 | for checkpoint in checkpoints: 79 | reader = tf.contrib.framework.load_checkpoint(checkpoint) 80 | for name in var_values: 81 | tensor = reader.get_tensor(name) 82 | var_dtypes[name] = tensor.dtype 83 | var_values[name] += tensor 84 | tf.logging.info("Read from checkpoint %s", checkpoint) 85 | 86 | # Average checkpoints 87 | for name in var_values: 88 | var_values[name] /= len(checkpoints) 89 | 90 | tf_vars = [ 91 | tf.get_variable(name, shape=var_values[name].shape, 92 | dtype=var_dtypes[name]) for name in var_values 93 | ] 94 | placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars] 95 | assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)] 96 | global_step = tf.Variable(0, name="global_step", trainable=False, 97 | dtype=tf.int64) 98 | saver = tf.train.Saver(tf.global_variables()) 99 | 100 | with tf.Session() as sess: 101 | sess.run(tf.global_variables_initializer()) 102 | for p, assign_op, (name, value) in zip(placeholders, assign_ops, 103 | var_values.iteritems()): 104 | sess.run(assign_op, {p: value}) 105 | saved_name = os.path.join(FLAGS.output, "average") 106 | saver.save(sess, saved_name, global_step=global_step) 107 | 108 | tf.logging.info("Averaged checkpoints saved in %s", saved_name) 109 | 110 | params_pattern = os.path.join(FLAGS.path, "*.json") 111 | params_files = tf.gfile.Glob(params_pattern) 112 | 113 | for name in params_files: 114 | new_name = name.replace(FLAGS.path.rstrip("/"), 115 | FLAGS.output.rstrip("/")) 116 | tf.gfile.Copy(name, new_name, overwrite=True) 117 | 118 | 119 | if __name__ == "__main__": 120 | FLAGS = parseargs() 121 | tf.app.run() 122 | -------------------------------------------------------------------------------- /thumt/scripts/combine.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2018 The THUMT Authors 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import argparse 10 | import operator 11 | import os 12 | 13 | import numpy as np 14 | import tensorflow as tf 15 | 16 | 17 | def parseargs(): 18 | msg = "Average checkpoints" 19 | usage = "average.py [] [-h | --help]" 20 | parser = argparse.ArgumentParser(description=msg, usage=usage) 21 | 22 | parser.add_argument("--model", type=str, required=True, 23 | help="checkpoint dir") 24 | parser.add_argument("--part", type=str, required=True, 25 | help="partial model dir") 26 | parser.add_argument("--output", type=str, help="output path") 27 | 28 | return parser.parse_args() 29 | 30 | def main(_): 31 | tf.logging.set_verbosity(tf.logging.INFO) 32 | 33 | var_list = tf.contrib.framework.list_variables(FLAGS.model) 34 | var_part = tf.contrib.framework.list_variables(FLAGS.part) 35 | var_values, var_dtypes = {}, {} 36 | var_values_part = {} 37 | 38 | for (name, shape) in var_list: 39 | if True:#not name.startswith("global_step") and not 'Adam' in name: 40 | var_values[name] = np.zeros(shape) 41 | for (name, shape) in var_part: 42 | var_values_part[name] = np.zeros(shape) 43 | 44 | reader = tf.contrib.framework.load_checkpoint(FLAGS.model) 45 | 
reader_part = tf.contrib.framework.load_checkpoint(FLAGS.part) 46 | for name in var_values: 47 | if name in var_values_part: 48 | tensor = reader_part.get_tensor(name) 49 | var_dtypes[name] = tensor.dtype 50 | var_values[name] += tensor 51 | print(name+' in part') 52 | else: 53 | tensor = reader.get_tensor(name) 54 | var_dtypes[name] = tensor.dtype 55 | var_values[name] += tensor 56 | print(name+' is new') 57 | tf.logging.info("Read from %s and %s", FLAGS.model, FLAGS.part) 58 | 59 | tf_vars = [ 60 | tf.get_variable(name, shape=var_values[name].shape, 61 | dtype=var_dtypes[name]) for name in var_values 62 | ] 63 | placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars] 64 | assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)] 65 | global_step = tf.Variable(0, name="global_step", trainable=False, 66 | dtype=tf.int64) 67 | saver = tf.train.Saver(tf.global_variables()) 68 | 69 | with tf.Session() as sess: 70 | sess.run(tf.global_variables_initializer()) 71 | for p, assign_op, (name, value) in zip(placeholders, assign_ops, 72 | var_values.iteritems()): 73 | sess.run(assign_op, {p: value}) 74 | saved_name = os.path.join(FLAGS.output, "new") 75 | saver.save(sess, saved_name, global_step=global_step) 76 | 77 | tf.logging.info("Averaged checkpoints saved in %s", saved_name) 78 | 79 | if __name__ == "__main__": 80 | FLAGS = parseargs() 81 | tf.app.run() 82 | -------------------------------------------------------------------------------- /thumt/scripts/combine_add.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2018 The THUMT Authors 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import argparse 10 | import operator 11 | import os 12 | 13 | import numpy as np 14 | import tensorflow as tf 15 | 16 | 17 | def parseargs(): 18 | msg = "Average checkpoints" 19 | usage = "average.py [] [-h | --help]" 20 | parser = argparse.ArgumentParser(description=msg, usage=usage) 21 | 22 | parser.add_argument("--model", type=str, required=True, 23 | help="checkpoint dir") 24 | parser.add_argument("--part", type=str, required=True, 25 | help="partial model dir") 26 | parser.add_argument("--output", type=str, help="output path") 27 | 28 | return parser.parse_args() 29 | 30 | def main(_): 31 | tf.logging.set_verbosity(tf.logging.INFO) 32 | 33 | var_list = tf.contrib.framework.list_variables(FLAGS.model) 34 | var_part = tf.contrib.framework.list_variables(FLAGS.part) 35 | var_values, var_dtypes = {}, {} 36 | var_values_part = {} 37 | 38 | for (name, shape) in var_list: 39 | if True:#not name.startswith("global_step") and not 'Adam' in name: 40 | var_values[name] = np.zeros(shape) 41 | for (name, shape) in var_part: 42 | var_values[name] = np.zeros(shape) 43 | var_values_part[name] = np.zeros(shape) 44 | 45 | reader = tf.contrib.framework.load_checkpoint(FLAGS.model) 46 | reader_part = tf.contrib.framework.load_checkpoint(FLAGS.part) 47 | for name in var_values: 48 | if name in var_values_part: 49 | tensor = reader_part.get_tensor(name) 50 | var_dtypes[name] = tensor.dtype 51 | var_values[name] += tensor 52 | print(name+' in part') 53 | else: 54 | tensor = reader.get_tensor(name) 55 | var_dtypes[name] = tensor.dtype 56 | var_values[name] += tensor 57 | print(name+' is new') 58 | tf.logging.info("Read from %s and %s", FLAGS.model, FLAGS.part) 59 | 60 | tf_vars = [ 61 | tf.get_variable(name, 
shape=var_values[name].shape, 62 | dtype=var_dtypes[name]) for name in var_values 63 | ] 64 | placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars] 65 | assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)] 66 | global_step = tf.Variable(0, name="global_step", trainable=False, 67 | dtype=tf.int64) 68 | saver = tf.train.Saver(tf.global_variables()) 69 | 70 | with tf.Session() as sess: 71 | sess.run(tf.global_variables_initializer()) 72 | for p, assign_op, (name, value) in zip(placeholders, assign_ops, 73 | var_values.iteritems()): 74 | sess.run(assign_op, {p: value}) 75 | saved_name = os.path.join(FLAGS.output, "new") 76 | saver.save(sess, saved_name, global_step=global_step) 77 | 78 | tf.logging.info("Averaged checkpoints saved in %s", saved_name) 79 | 80 | if __name__ == "__main__": 81 | FLAGS = parseargs() 82 | tf.app.run() 83 | -------------------------------------------------------------------------------- /thumt/scripts/compare.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2018 The THUMT Authors 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import argparse 10 | import operator 11 | import os 12 | 13 | import numpy as np 14 | import tensorflow as tf 15 | 16 | 17 | def parseargs(): 18 | msg = "Average checkpoints" 19 | usage = "average.py [] [-h | --help]" 20 | parser = argparse.ArgumentParser(description=msg, usage=usage) 21 | 22 | parser.add_argument("--model", type=str, required=True, 23 | help="checkpoint dir") 24 | parser.add_argument("--part", type=str, required=True, 25 | help="partial model dir") 26 | 27 | return parser.parse_args() 28 | 29 | def main(_): 30 | tf.logging.set_verbosity(tf.logging.INFO) 31 | 32 | var_list = tf.contrib.framework.list_variables(FLAGS.model) 33 | var_part = tf.contrib.framework.list_variables(FLAGS.part) 34 | var_values, var_dtypes = {}, {} 35 | var_values_part = {} 36 | 37 | for (name, shape) in var_list: 38 | var_values[name] = np.zeros(shape) 39 | for (name, shape) in var_part: 40 | var_values_part[name] = np.zeros(shape) 41 | 42 | reader = tf.contrib.framework.load_checkpoint(FLAGS.model) 43 | reader_part = tf.contrib.framework.load_checkpoint(FLAGS.part) 44 | for name in var_values: 45 | if name in var_values_part: 46 | tensor_part = reader_part.get_tensor(name) 47 | tensor = reader.get_tensor(name) 48 | print(type(tensor)) 49 | if tensor.equal(tensor_part): 50 | print('name '+name+' equals') 51 | else: 52 | print('name '+name+' is different') 53 | tf.logging.info("Read from %s and %s", FLAGS.model, FLAGS.part) 54 | 55 | if __name__ == "__main__": 56 | FLAGS = parseargs() 57 | tf.app.run() 58 | -------------------------------------------------------------------------------- /thumt/scripts/convert_old_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2018 The THUMT Authors 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import argparse 10 | 11 | import numpy as np 12 | import tensorflow as tf 13 | 14 | 15 | def parseargs(): 16 | parser = argparse.ArgumentParser(description="Convert old models") 17 | 18 | parser.add_argument("--input", type=str, required=True, 19 | help="Path of old model") 20 | parser.add_argument("--output", type=str, 
required=True, 21 | help="Path of output checkpoint") 22 | 23 | return parser.parse_args() 24 | 25 | 26 | def old_keys(): 27 | keys = [ 28 | "GRU_dec_attcontext", 29 | "GRU_dec_att", 30 | "GRU_dec_atthidden", 31 | "GRU_dec_inputoffset", 32 | "GRU_dec_inputemb", 33 | "GRU_dec_inputcontext", 34 | "GRU_dec_inputhidden", 35 | "GRU_dec_resetemb", 36 | "GRU_dec_resetcontext", 37 | "GRU_dec_resethidden", 38 | "GRU_dec_gateemb", 39 | "GRU_dec_gatecontext", 40 | "GRU_dec_gatehidden", 41 | "initer_b", 42 | "initer_W", 43 | "GRU_dec_probsemb", 44 | "GRU_enc_back_inputoffset", 45 | "GRU_enc_back_inputemb", 46 | "GRU_enc_back_inputhidden", 47 | "GRU_enc_back_resetemb", 48 | "GRU_enc_back_resethidden", 49 | "GRU_enc_back_gateemb", 50 | "GRU_enc_back_gatehidden", 51 | "GRU_enc_inputoffset", 52 | "GRU_enc_inputemb", 53 | "GRU_enc_inputhidden", 54 | "GRU_enc_resetemb", 55 | "GRU_enc_resethidden", 56 | "GRU_enc_gateemb", 57 | "GRU_enc_gatehidden", 58 | "GRU_dec_readoutoffset", 59 | "GRU_dec_readoutemb", 60 | "GRU_dec_readouthidden", 61 | "GRU_dec_readoutcontext", 62 | "GRU_dec_probsoffset", 63 | "GRU_dec_probs", 64 | "emb_src_b", 65 | "emb_src_emb", 66 | "emb_trg_b", 67 | "emb_trg_emb" 68 | ] 69 | 70 | return keys 71 | 72 | 73 | def new_keys(): 74 | keys = [ 75 | "rnnsearch/decoder/attention/k_transform/matrix_0", 76 | "rnnsearch/decoder/attention/logits/matrix_0", 77 | "rnnsearch/decoder/attention/q_transform/matrix_0", 78 | "rnnsearch/decoder/gru_cell/candidate/bias", 79 | "rnnsearch/decoder/gru_cell/candidate/matrix_0", 80 | "rnnsearch/decoder/gru_cell/candidate/matrix_1", 81 | "rnnsearch/decoder/gru_cell/candidate/matrix_2", 82 | "rnnsearch/decoder/gru_cell/reset_gate/matrix_0", 83 | "rnnsearch/decoder/gru_cell/reset_gate/matrix_1", 84 | "rnnsearch/decoder/gru_cell/reset_gate/matrix_2", 85 | "rnnsearch/decoder/gru_cell/update_gate/matrix_0", 86 | "rnnsearch/decoder/gru_cell/update_gate/matrix_1", 87 | "rnnsearch/decoder/gru_cell/update_gate/matrix_2", 88 | "rnnsearch/decoder/s_transform/bias", 89 | "rnnsearch/decoder/s_transform/matrix_0", 90 | "rnnsearch/deepout/matrix_0", 91 | "rnnsearch/encoder/backward/gru_cell/candidate/bias", 92 | "rnnsearch/encoder/backward/gru_cell/candidate/matrix_0", 93 | "rnnsearch/encoder/backward/gru_cell/candidate/matrix_1", 94 | "rnnsearch/encoder/backward/gru_cell/reset_gate/matrix_0", 95 | "rnnsearch/encoder/backward/gru_cell/reset_gate/matrix_1", 96 | "rnnsearch/encoder/backward/gru_cell/update_gate/matrix_0", 97 | "rnnsearch/encoder/backward/gru_cell/update_gate/matrix_1", 98 | "rnnsearch/encoder/forward/gru_cell/candidate/bias", 99 | "rnnsearch/encoder/forward/gru_cell/candidate/matrix_0", 100 | "rnnsearch/encoder/forward/gru_cell/candidate/matrix_1", 101 | "rnnsearch/encoder/forward/gru_cell/reset_gate/matrix_0", 102 | "rnnsearch/encoder/forward/gru_cell/reset_gate/matrix_1", 103 | "rnnsearch/encoder/forward/gru_cell/update_gate/matrix_0", 104 | "rnnsearch/encoder/forward/gru_cell/update_gate/matrix_1", 105 | "rnnsearch/maxout/bias", 106 | "rnnsearch/maxout/matrix_0", 107 | "rnnsearch/maxout/matrix_1", 108 | "rnnsearch/maxout/matrix_2", 109 | "rnnsearch/softmax/bias", 110 | "rnnsearch/softmax/matrix_0", 111 | "rnnsearch/source_embedding/bias", 112 | "rnnsearch/source_embedding/embedding", 113 | "rnnsearch/target_embedding/bias", 114 | "rnnsearch/target_embedding/embedding", 115 | ] 116 | 117 | return keys 118 | 119 | 120 | def main(args): 121 | values = dict(np.load(args.input)) 122 | variables = {} 123 | o_keys = old_keys() 124 | n_keys = new_keys() 125 | 126 | for 
i, key in enumerate(o_keys): 127 | v = values[key] 128 | variables[n_keys[i]] = v 129 | 130 | with tf.Graph().as_default(): 131 | with tf.device("/cpu:0"): 132 | tf_vars = [ 133 | tf.get_variable(v, initializer=variables[v], dtype=tf.float32) 134 | for v in variables 135 | ] 136 | global_step = tf.Variable(0, name="global_step", trainable=False, 137 | dtype=tf.int64) 138 | 139 | saver = tf.train.Saver(tf_vars) 140 | 141 | with tf.Session() as sess: 142 | sess.run(tf.global_variables_initializer()) 143 | saver.save(sess, args.output, global_step=global_step) 144 | 145 | 146 | if __name__ == "__main__": 147 | main(parseargs()) 148 | -------------------------------------------------------------------------------- /thumt/scripts/convert_vocab.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2018 The THUMT Authors 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import cPickle 10 | import sys 11 | 12 | if __name__ == "__main__": 13 | with open(sys.argv[1]) as fd: 14 | voc = cPickle.load(fd) 15 | 16 | ivoc = {} 17 | 18 | for key in voc: 19 | ivoc[voc[key]] = key 20 | 21 | with open(sys.argv[2], "w") as fd: 22 | for key in ivoc: 23 | val = ivoc[key] 24 | fd.write(val + "\n") 25 | -------------------------------------------------------------------------------- /thumt/scripts/input_converter.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import argparse 9 | import os 10 | import random 11 | import six 12 | 13 | import tensorflow as tf 14 | 15 | 16 | def load_vocab(filename): 17 | with tf.gfile.Open(filename) as fd: 18 | count = 0 19 | vocab = {} 20 | for line in fd: 21 | word = line.strip() 22 | vocab[word] = count 23 | count += 1 24 | 25 | return vocab 26 | 27 | 28 | def to_example(dictionary): 29 | """ Convert python dictionary to tf.train.Example """ 30 | features = {} 31 | 32 | for (k, v) in six.iteritems(dictionary): 33 | if not v: 34 | raise ValueError("Empty generated field: %s", str((k, v))) 35 | 36 | if isinstance(v[0], six.integer_types): 37 | int64_list = tf.train.Int64List(value=v) 38 | features[k] = tf.train.Feature(int64_list=int64_list) 39 | elif isinstance(v[0], float): 40 | float_list = tf.train.FloatList(value=v) 41 | features[k] = tf.train.Feature(float_list=float_list) 42 | elif isinstance(v[0], six.string_types): 43 | bytes_list = tf.train.BytesList(value=v) 44 | features[k] = tf.train.Feature(bytes_list=bytes_list) 45 | else: 46 | raise ValueError("Value is neither an int nor a float; " 47 | "v: %s type: %s" % (str(v[0]), str(type(v[0])))) 48 | 49 | return tf.train.Example(features=tf.train.Features(feature=features)) 50 | 51 | 52 | def write_records(records, out_filename): 53 | """ Write to TensorFlow record """ 54 | writer = tf.python_io.TFRecordWriter(out_filename) 55 | 56 | for count, record in enumerate(records): 57 | writer.write(record) 58 | if count % 10000 == 0: 59 | tf.logging.info("write: %d", count) 60 | 61 | writer.close() 62 | 63 | 64 | def convert_to_record(inputs, vocab, output_name, output_dir, num_shards, 65 | shuffle=False): 66 | """ Convert plain parallel text to TensorFlow record """ 67 | source, target = inputs 68 | svocab, tvocab = vocab 69 | records = 
[] 70 | 71 | with tf.gfile.Open(source) as src: 72 | with tf.gfile.Open(target) as tgt: 73 | for sline, tline in zip(src, tgt): 74 | sline = sline.strip().split() 75 | sline = [svocab[item] if item in svocab else svocab[FLAGS.unk] 76 | for item in sline] + [svocab[FLAGS.eos]] 77 | tline = tline.strip().split() 78 | tline = [tvocab[item] if item in tvocab else tvocab[FLAGS.unk] 79 | for item in tline] + [tvocab[FLAGS.eos]] 80 | 81 | feature = { 82 | "source": sline, 83 | "target": tline, 84 | "source_length": [len(sline)], 85 | "target_length": [len(tline)] 86 | } 87 | records.append(feature) 88 | 89 | output_files = [] 90 | writers = [] 91 | 92 | for shard in xrange(num_shards): 93 | output_filename = "%s-%.5d-of-%.5d" % (output_name, shard, num_shards) 94 | output_file = os.path.join(output_dir, output_filename) 95 | output_files.append(output_file) 96 | writers.append(tf.python_io.TFRecordWriter(output_file)) 97 | 98 | counter, shard = 0, 0 99 | 100 | if shuffle: 101 | random.shuffle(records) 102 | 103 | for record in records: 104 | counter += 1 105 | example = to_example(record) 106 | writers[shard].write(example.SerializeToString()) 107 | shard = (shard + 1) % num_shards 108 | 109 | for writer in writers: 110 | writer.close() 111 | 112 | 113 | def parse_args(): 114 | msg = "convert inputs to tf.Record format" 115 | usage = "input_converter.py [] [-h | --help]" 116 | parser = argparse.ArgumentParser(description=msg, usage=usage) 117 | 118 | parser.add_argument("--input", required=True, type=str, nargs=2, 119 | help="Path of input file") 120 | parser.add_argument("--output_name", required=True, type=str, 121 | help="Output name") 122 | parser.add_argument("--output_dir", required=True, type=str, 123 | help="Output directory") 124 | parser.add_argument("--vocab", nargs=2, required=True, type=str, 125 | help="Path of vocabulary") 126 | parser.add_argument("--num_shards", default=100, type=int, 127 | help="Number of output shards") 128 | parser.add_argument("--shuffle", action="store_true", 129 | help="Shuffle inputs") 130 | parser.add_argument("--unk", default="", type=str, 131 | help="Unknown word symbol") 132 | parser.add_argument("--eos", default="", type=str, 133 | help="End of sentence symbol") 134 | 135 | return parser.parse_args() 136 | 137 | 138 | def main(_): 139 | svocab = load_vocab(FLAGS.vocab[0]) 140 | tvocab = load_vocab(FLAGS.vocab[1]) 141 | 142 | # convert data 143 | convert_to_record(FLAGS.input, [svocab, tvocab], FLAGS.output_name, 144 | FLAGS.output_dir, FLAGS.num_shards, FLAGS.shuffle) 145 | 146 | 147 | if __name__ == "__main__": 148 | FLAGS = parse_args() 149 | tf.app.run() 150 | -------------------------------------------------------------------------------- /thumt/scripts/shuffle_corpus.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import argparse 9 | import numpy 10 | 11 | 12 | def parseargs(): 13 | parser = argparse.ArgumentParser(description="Shuffle corpus") 14 | 15 | parser.add_argument("--corpus", nargs="+", required=True, 16 | help="input corpora") 17 | parser.add_argument("--suffix", type=str, default="shuf", 18 | help="Suffix of output files") 19 | parser.add_argument("--seed", type=int, help="Random seed") 20 | 21 | return parser.parse_args() 22 | 23 | 24 | def main(args): 25 | name = args.corpus 26 | suffix = "." 
+ args.suffix 27 | stream = [open(item, "r") for item in name] 28 | data = [fd.readlines() for fd in stream] 29 | minlen = min([len(lines) for lines in data]) 30 | 31 | if args.seed: 32 | numpy.random.seed(args.seed) 33 | 34 | indices = numpy.arange(minlen) 35 | numpy.random.shuffle(indices) 36 | 37 | newstream = [open(item + suffix, "w") for item in name] 38 | 39 | for idx in indices.tolist(): 40 | lines = [item[idx] for item in data] 41 | 42 | for line, fd in zip(lines, newstream): 43 | fd.write(line) 44 | 45 | for fdr, fdw in zip(stream, newstream): 46 | fdr.close() 47 | fdw.close() 48 | 49 | 50 | if __name__ == "__main__": 51 | parsed_args = parseargs() 52 | main(parsed_args) 53 | -------------------------------------------------------------------------------- /thumt/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The THUMT Authors 3 | -------------------------------------------------------------------------------- /thumt/utils/bleu.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import math 9 | 10 | from collections import Counter 11 | 12 | 13 | def closest_length(candidate, references): 14 | clen = len(candidate) 15 | closest_diff = 9999 16 | closest_len = 9999 17 | 18 | for reference in references: 19 | rlen = len(reference) 20 | diff = abs(rlen - clen) 21 | 22 | if diff < closest_diff: 23 | closest_diff = diff 24 | closest_len = rlen 25 | elif diff == closest_diff: 26 | closest_len = rlen if rlen < closest_len else closest_len 27 | 28 | return closest_len 29 | 30 | 31 | def shortest_length(references): 32 | return min([len(ref) for ref in references]) 33 | 34 | 35 | def modified_precision(candidate, references, n): 36 | tngrams = len(candidate) + 1 - n 37 | counts = Counter([tuple(candidate[i:i+n]) for i in range(tngrams)]) 38 | 39 | if len(counts) == 0: 40 | return 0, 0 41 | 42 | max_counts = {} 43 | for reference in references: 44 | rngrams = len(reference) + 1 - n 45 | ngrams = [tuple(reference[i:i+n]) for i in range(rngrams)] 46 | ref_counts = Counter(ngrams) 47 | for ngram in counts: 48 | mcount = 0 if ngram not in max_counts else max_counts[ngram] 49 | rcount = 0 if ngram not in ref_counts else ref_counts[ngram] 50 | max_counts[ngram] = max(mcount, rcount) 51 | 52 | clipped_counts = {} 53 | 54 | for ngram, count in counts.items(): 55 | clipped_counts[ngram] = min(count, max_counts[ngram]) 56 | 57 | return float(sum(clipped_counts.values())), float(sum(counts.values())) 58 | 59 | 60 | def brevity_penalty(trans, refs, mode="closest"): 61 | bp_c = 0.0 62 | bp_r = 0.0 63 | 64 | for candidate, references in zip(trans, refs): 65 | bp_c += len(candidate) 66 | 67 | if mode == "shortest": 68 | bp_r += shortest_length(references) 69 | else: 70 | bp_r += closest_length(candidate, references) 71 | 72 | # Prevent zero divide 73 | bp_c = bp_c or 1.0 74 | 75 | return math.exp(min(0, 1.0 - bp_r / bp_c)) 76 | 77 | 78 | def bleu(trans, refs, bp="closest", smooth=False, n=4, weights=None): 79 | p_norm = [0 for _ in range(n)] 80 | p_denorm = [0 for _ in range(n)] 81 | 82 | for candidate, references in zip(trans, refs): 83 | for i in range(n): 84 | ccount, tcount = modified_precision(candidate, references, i + 1) 85 | p_norm[i] += ccount 86 | p_denorm[i] += tcount 87 | 88 | bleu_n = 
[0 for _ in range(n)] 89 | 90 | for i in range(n): 91 | # add one smoothing 92 | if smooth and i > 0: 93 | p_norm[i] += 1 94 | p_denorm[i] += 1 95 | 96 | if p_norm[i] == 0 or p_denorm[i] == 0: 97 | bleu_n[i] = -9999 98 | else: 99 | bleu_n[i] = math.log(float(p_norm[i]) / float(p_denorm[i])) 100 | 101 | if weights: 102 | if len(weights) != n: 103 | raise ValueError("len(weights) != n: invalid weight number") 104 | log_precision = sum([bleu_n[i] * weights[i] for i in range(n)]) 105 | else: 106 | log_precision = sum(bleu_n) / float(n) 107 | 108 | bp = brevity_penalty(trans, refs, bp) 109 | 110 | score = bp * math.exp(log_precision) 111 | 112 | return score 113 | -------------------------------------------------------------------------------- /thumt/utils/hooks.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import datetime 9 | import operator 10 | import os 11 | 12 | import tensorflow as tf 13 | import thumt.utils.bleu as bleu 14 | 15 | 16 | def _get_saver(): 17 | # Get saver from the SAVERS collection if present. 18 | collection_key = tf.GraphKeys.SAVERS 19 | savers = tf.get_collection(collection_key) 20 | 21 | if not savers: 22 | raise RuntimeError("No items in collection {}. " 23 | "Please add a saver to the collection ") 24 | elif len(savers) > 1: 25 | raise RuntimeError("More than one item in collection") 26 | 27 | return savers[0] 28 | 29 | 30 | def _save_log(filename, result): 31 | metric, global_step, score = result 32 | 33 | with open(filename, "a") as fd: 34 | time = datetime.datetime.now() 35 | msg = "%s: %s at step %d: %f\n" % (time, metric, global_step, score) 36 | fd.write(msg) 37 | 38 | 39 | def _read_checkpoint_def(filename): 40 | records = [] 41 | 42 | with tf.gfile.GFile(filename) as fd: 43 | fd.readline() 44 | 45 | for line in fd: 46 | records.append(line.strip().split(":")[-1].strip()[1:-1]) 47 | 48 | return records 49 | 50 | 51 | def _save_checkpoint_def(filename, checkpoint_names): 52 | keys = [] 53 | 54 | for checkpoint_name in checkpoint_names: 55 | step = int(checkpoint_name.strip().split("-")[-1]) 56 | keys.append((step, checkpoint_name)) 57 | 58 | sorted_names = sorted(keys, key=operator.itemgetter(0), 59 | reverse=True) 60 | 61 | with tf.gfile.GFile(filename, "w") as fd: 62 | fd.write("model_checkpoint_path: \"%s\"\n" % checkpoint_names[0]) 63 | 64 | for checkpoint_name in sorted_names: 65 | checkpoint_name = checkpoint_name[1] 66 | fd.write("all_model_checkpoint_paths: \"%s\"\n" % checkpoint_name) 67 | 68 | 69 | def _read_score_record(filename): 70 | # "checkpoint_name": score 71 | records = [] 72 | 73 | if not tf.gfile.Exists(filename): 74 | return records 75 | 76 | with tf.gfile.GFile(filename) as fd: 77 | for line in fd: 78 | name, score = line.strip().split(":") 79 | name = name.strip()[1:-1] 80 | score = float(score) 81 | records.append([name, score]) 82 | 83 | return records 84 | 85 | 86 | def _save_score_record(filename, records): 87 | keys = [] 88 | 89 | for record in records: 90 | checkpoint_name = record[0] 91 | step = int(checkpoint_name.strip().split("-")[-1]) 92 | keys.append((step, record)) 93 | 94 | sorted_keys = sorted(keys, key=operator.itemgetter(0), 95 | reverse=True) 96 | sorted_records = [item[1] for item in sorted_keys] 97 | 98 | with tf.gfile.GFile(filename, "w") as fd: 99 | for record in sorted_records: 100 | 
checkpoint_name, score = record 101 | fd.write("\"%s\": %f\n" % (checkpoint_name, score)) 102 | 103 | 104 | def _add_to_record(records, record, max_to_keep): 105 | added = None 106 | removed = None 107 | models = {} 108 | 109 | for (name, score) in records: 110 | models[name] = score 111 | 112 | if len(records) < max_to_keep: 113 | if record[0] not in models: 114 | added = record[0] 115 | records.append(record) 116 | else: 117 | sorted_records = sorted(records, key=lambda x: -x[1]) 118 | worst_score = sorted_records[-1][1] 119 | current_score = record[1] 120 | 121 | if current_score >= worst_score: 122 | if record[0] not in models: 123 | added = record[0] 124 | removed = sorted_records[-1][0] 125 | records = sorted_records[:-1] + [record] 126 | 127 | # Sort 128 | records = sorted(records, key=lambda x: -x[1]) 129 | 130 | return added, removed, records 131 | 132 | 133 | def _evaluate(eval_fn, input_fn, decode_fn, path, config): 134 | graph = tf.Graph() 135 | with graph.as_default(): 136 | features = input_fn() 137 | refs = features["references"] 138 | placeholders = { 139 | "source": tf.placeholder(tf.int32, [None, None], "source"), 140 | "source_length": tf.placeholder(tf.int32, [None], "source_length") 141 | } 142 | predictions = eval_fn(placeholders) 143 | predictions = predictions[0][:, 0, :] 144 | 145 | all_refs = [[] for _ in range(len(refs))] 146 | all_outputs = [] 147 | 148 | sess_creator = tf.train.ChiefSessionCreator( 149 | checkpoint_dir=path, 150 | config=config 151 | ) 152 | 153 | with tf.train.MonitoredSession(session_creator=sess_creator) as sess: 154 | while not sess.should_stop(): 155 | feats = sess.run(features) 156 | outputs = sess.run(predictions, feed_dict={ 157 | placeholders["source"]: feats["source"], 158 | placeholders["source_length"]: feats["source_length"] 159 | }) 160 | # shape: [batch, len] 161 | outputs = outputs.tolist() 162 | # shape: ([batch, len], ..., [batch, len]) 163 | references = [item.tolist() for item in feats["references"]] 164 | 165 | all_outputs.extend(outputs) 166 | 167 | for i in range(len(refs)): 168 | all_refs[i].extend(references[i]) 169 | 170 | decoded_symbols = decode_fn(all_outputs) 171 | decoded_refs = [decode_fn(refs) for refs in all_refs] 172 | decoded_refs = [list(x) for x in zip(*decoded_refs)] 173 | 174 | return bleu.bleu(decoded_symbols, decoded_refs) 175 | 176 | 177 | class EvaluationHook(tf.train.SessionRunHook): 178 | """ Validate and save checkpoints every N steps or seconds. 179 | This hook only saves checkpoint according to a specific metric. 180 | """ 181 | 182 | def __init__(self, eval_fn, eval_input_fn, eval_decode_fn, base_dir, 183 | session_config, max_to_keep=5, eval_secs=None, 184 | eval_steps=None, metric="BLEU"): 185 | """ Initializes a `EvaluationHook`. 186 | :param eval_fn: A function with signature (feature) 187 | :param eval_input_fn: A function with signature () 188 | :param eval_decode_fn: A function with signature (inputs) 189 | :param base_dir: A string. Base directory for the checkpoint files. 190 | :param session_config: An instance of tf.ConfigProto 191 | :param max_to_keep: An integer. The maximum of checkpoints to save 192 | :param eval_secs: An integer, eval every N secs. 193 | :param eval_steps: An integer, eval every N steps. 194 | :param checkpoint_basename: `str`, base name for the checkpoint files. 195 | :raises ValueError: One of `save_steps` or `save_secs` should be set. 196 | :raises ValueError: At most one of saver or scaffold should be set. 
197 | """ 198 | tf.logging.info("Create EvaluationHook.") 199 | 200 | if metric != "BLEU": 201 | raise ValueError("Currently, EvaluationHook only support BLEU") 202 | 203 | self._base_dir = base_dir.rstrip("/") 204 | self._session_config = session_config 205 | self._save_path = os.path.join(base_dir, "eval") 206 | self._record_name = os.path.join(self._save_path, "record") 207 | self._log_name = os.path.join(self._save_path, "log") 208 | self._eval_fn = eval_fn 209 | self._eval_input_fn = eval_input_fn 210 | self._eval_decode_fn = eval_decode_fn 211 | self._max_to_keep = max_to_keep 212 | self._metric = metric 213 | self._global_step = None 214 | self._timer = tf.train.SecondOrStepTimer( 215 | every_secs=eval_secs or None, every_steps=eval_steps or None 216 | ) 217 | 218 | def begin(self): 219 | if self._timer.last_triggered_step() is None: 220 | self._timer.update_last_triggered_step(0) 221 | 222 | global_step = tf.train.get_global_step() 223 | 224 | if not tf.gfile.Exists(self._save_path): 225 | tf.logging.info("Making dir: %s" % self._save_path) 226 | tf.gfile.MakeDirs(self._save_path) 227 | 228 | params_pattern = os.path.join(self._base_dir, "*.json") 229 | params_files = tf.gfile.Glob(params_pattern) 230 | 231 | for name in params_files: 232 | new_name = name.replace(self._base_dir, self._save_path) 233 | tf.gfile.Copy(name, new_name, overwrite=True) 234 | 235 | if global_step is None: 236 | raise RuntimeError("Global step should be created first") 237 | 238 | self._global_step = global_step 239 | 240 | def before_run(self, run_context): 241 | args = tf.train.SessionRunArgs(self._global_step) 242 | return args 243 | 244 | def after_run(self, run_context, run_values): 245 | stale_global_step = run_values.results 246 | 247 | if self._timer.should_trigger_for_step(stale_global_step + 1): 248 | global_step = run_context.session.run(self._global_step) 249 | 250 | # Get the real value 251 | if self._timer.should_trigger_for_step(global_step): 252 | self._timer.update_last_triggered_step(global_step) 253 | # Save model 254 | save_path = os.path.join(self._base_dir, "model.ckpt") 255 | saver = _get_saver() 256 | tf.logging.info("Saving checkpoints for %d into %s." 
% 257 | (global_step, save_path)) 258 | saver.save(run_context.session, 259 | save_path, 260 | global_step=global_step) 261 | # Do validation here 262 | tf.logging.info("Validating model at step %d" % global_step) 263 | score = _evaluate(self._eval_fn, self._eval_input_fn, 264 | self._eval_decode_fn, 265 | self._base_dir, 266 | self._session_config) 267 | tf.logging.info("%s at step %d: %f" % 268 | (self._metric, global_step, score)) 269 | 270 | _save_log(self._log_name, (self._metric, global_step, score)) 271 | 272 | checkpoint_filename = os.path.join(self._base_dir, 273 | "checkpoint") 274 | all_checkpoints = _read_checkpoint_def(checkpoint_filename) 275 | records = _read_score_record(self._record_name) 276 | latest_checkpoint = all_checkpoints[-1] 277 | record = [latest_checkpoint, score] 278 | added, removed, records = _add_to_record(records, record, 279 | self._max_to_keep) 280 | 281 | if added is not None: 282 | old_path = os.path.join(self._base_dir, added) 283 | new_path = os.path.join(self._save_path, added) 284 | old_files = tf.gfile.Glob(old_path + "*") 285 | tf.logging.info("Copying %s to %s" % (old_path, new_path)) 286 | 287 | for o_file in old_files: 288 | n_file = o_file.replace(old_path, new_path) 289 | tf.gfile.Copy(o_file, n_file, overwrite=True) 290 | 291 | if removed is not None: 292 | filename = os.path.join(self._save_path, removed) 293 | tf.logging.info("Removing %s" % filename) 294 | files = tf.gfile.Glob(filename + "*") 295 | 296 | for name in files: 297 | tf.gfile.Remove(name) 298 | 299 | _save_score_record(self._record_name, records) 300 | checkpoint_filename = checkpoint_filename.replace( 301 | self._base_dir, self._save_path 302 | ) 303 | _save_checkpoint_def(checkpoint_filename, 304 | [item[0] for item in records]) 305 | 306 | best_score = records[0][1] 307 | tf.logging.info("Best score at step %d: %f" % 308 | (global_step, best_score)) 309 | 310 | def end(self, session): 311 | last_step = session.run(self._global_step) 312 | 313 | if last_step != self._timer.last_triggered_step(): 314 | global_step = last_step 315 | tf.logging.info("Validating model at step %d" % global_step) 316 | score = _evaluate(self._eval_fn, self._eval_input_fn, 317 | self._eval_decode_fn, 318 | self._base_dir, 319 | self._session_config) 320 | tf.logging.info("%s at step %d: %f" % 321 | (self._metric, global_step, score)) 322 | 323 | checkpoint_filename = os.path.join(self._base_dir, 324 | "checkpoint") 325 | all_checkpoints = _read_checkpoint_def(checkpoint_filename) 326 | records = _read_score_record(self._record_name) 327 | latest_checkpoint = all_checkpoints[-1] 328 | record = [latest_checkpoint, score] 329 | added, removed, records = _add_to_record(records, record, 330 | self._max_to_keep) 331 | 332 | if added is not None: 333 | old_path = os.path.join(self._base_dir, added) 334 | new_path = os.path.join(self._save_path, added) 335 | old_files = tf.gfile.Glob(old_path + "*") 336 | tf.logging.info("Copying %s to %s" % (old_path, new_path)) 337 | 338 | for o_file in old_files: 339 | n_file = o_file.replace(old_path, new_path) 340 | tf.gfile.Copy(o_file, n_file, overwrite=True) 341 | 342 | if removed is not None: 343 | filename = os.path.join(self._save_path, removed) 344 | tf.logging.info("Removing %s" % filename) 345 | files = tf.gfile.Glob(filename + "*") 346 | 347 | for name in files: 348 | tf.gfile.Remove(name) 349 | 350 | _save_score_record(self._record_name, records) 351 | checkpoint_filename = checkpoint_filename.replace( 352 | self._base_dir, self._save_path 353 | ) 354 | 
_save_checkpoint_def(checkpoint_filename, 355 | [item[0] for item in records]) 356 | 357 | best_score = records[0][1] 358 | tf.logging.info("Best score: %f" % best_score) 359 | -------------------------------------------------------------------------------- /thumt/utils/inference.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import copy 9 | import tensorflow as tf 10 | 11 | from collections import namedtuple 12 | from tensorflow.python.util import nest 13 | 14 | 15 | class BeamSearchState(namedtuple("BeamSearchState", 16 | ("inputs", "state", "finish"))): 17 | pass 18 | 19 | 20 | def _get_inference_fn(model_fns, features): 21 | def inference_fn(inputs, state): 22 | local_features = { 23 | "source": features["source"], 24 | "source_length": features["source_length"], 25 | # [bos_id, ...] => [..., 0] 26 | "target": tf.pad(inputs[:, 1:], [[0, 0], [0, 1]]), 27 | "target_length": tf.fill([tf.shape(inputs)[0]], 28 | tf.shape(inputs)[1]) 29 | } 30 | 31 | outputs = [] 32 | next_state = [] 33 | 34 | for (model_fn, model_state) in zip(model_fns, state): 35 | if model_state: 36 | output, new_state = model_fn(local_features, model_state) 37 | outputs.append(output) 38 | next_state.append(new_state) 39 | else: 40 | output = model_fn(local_features) 41 | outputs.append(output) 42 | next_state.append({}) 43 | 44 | # Ensemble 45 | log_prob = tf.add_n(outputs) / float(len(outputs)) 46 | 47 | return log_prob, next_state 48 | 49 | return inference_fn 50 | 51 | 52 | def _infer_shape(x): 53 | x = tf.convert_to_tensor(x) 54 | 55 | # If unknown rank, return dynamic shape 56 | if x.shape.dims is None: 57 | return tf.shape(x) 58 | 59 | static_shape = x.shape.as_list() 60 | dynamic_shape = tf.shape(x) 61 | 62 | ret = [] 63 | for i in range(len(static_shape)): 64 | dim = static_shape[i] 65 | if dim is None: 66 | dim = dynamic_shape[i] 67 | ret.append(dim) 68 | 69 | return ret 70 | 71 | 72 | def _infer_shape_invariants(tensor): 73 | shape = tensor.shape.as_list() 74 | for i in range(1, len(shape) - 1): 75 | shape[i] = None 76 | return tf.TensorShape(shape) 77 | 78 | 79 | def _merge_first_two_dims(tensor): 80 | shape = _infer_shape(tensor) 81 | shape[0] *= shape[1] 82 | shape.pop(1) 83 | return tf.reshape(tensor, shape) 84 | 85 | 86 | def _split_first_two_dims(tensor, dim_0, dim_1): 87 | shape = _infer_shape(tensor) 88 | new_shape = [dim_0] + [dim_1] + shape[1:] 89 | return tf.reshape(tensor, new_shape) 90 | 91 | 92 | def _tile_to_beam_size(tensor, beam_size): 93 | """Tiles a given tensor by beam_size. """ 94 | tensor = tf.expand_dims(tensor, axis=1) 95 | tile_dims = [1] * tensor.shape.ndims 96 | tile_dims[1] = beam_size 97 | 98 | return tf.tile(tensor, tile_dims) 99 | 100 | 101 | def _gather_2d(params, indices, name=None): 102 | """ Gather the 2nd dimension given indices 103 | :param params: A tensor with shape [batch_size, M, ...] 104 | :param indices: A tensor with shape [batch_size, N] 105 | :return: A tensor with shape [batch_size, N, ...] 
106 | """ 107 | batch_size = tf.shape(params)[0] 108 | range_size = tf.shape(indices)[1] 109 | batch_pos = tf.range(batch_size * range_size) // range_size 110 | batch_pos = tf.reshape(batch_pos, [batch_size, range_size]) 111 | indices = tf.stack([batch_pos, indices], axis=-1) 112 | output = tf.gather_nd(params, indices, name=name) 113 | 114 | return output 115 | 116 | 117 | def _beam_search_step(time, func, state, batch_size, beam_size, alpha, 118 | pad_id, eos_id): 119 | # Compute log probabilities 120 | seqs, log_probs = state.inputs[:2] 121 | flat_seqs = _merge_first_two_dims(seqs) 122 | flat_state = nest.map_structure(lambda x: _merge_first_two_dims(x), 123 | state.state) 124 | step_log_probs, next_state = func(flat_seqs, flat_state) 125 | step_log_probs = _split_first_two_dims(step_log_probs, batch_size, 126 | beam_size) 127 | next_state = nest.map_structure( 128 | lambda x: _split_first_two_dims(x, batch_size, beam_size), next_state) 129 | curr_log_probs = tf.expand_dims(log_probs, 2) + step_log_probs 130 | 131 | # Apply length penalty 132 | length_penalty = tf.pow((5.0 + tf.to_float(time + 1)) / 6.0, alpha) 133 | curr_scores = curr_log_probs / length_penalty 134 | vocab_size = curr_scores.shape[-1].value or tf.shape(curr_scores)[-1] 135 | 136 | # Select top-k candidates 137 | # [batch_size, beam_size * vocab_size] 138 | curr_scores = tf.reshape(curr_scores, [-1, beam_size * vocab_size]) 139 | # [batch_size, 2 * beam_size] 140 | top_scores, top_indices = tf.nn.top_k(curr_scores, k=2 * beam_size) 141 | # Shape: [batch_size, 2 * beam_size] 142 | beam_indices = top_indices // vocab_size 143 | symbol_indices = top_indices % vocab_size 144 | # Expand sequences 145 | # [batch_size, 2 * beam_size, time] 146 | candidate_seqs = _gather_2d(seqs, beam_indices) 147 | candidate_seqs = tf.concat([candidate_seqs, 148 | tf.expand_dims(symbol_indices, 2)], 2) 149 | 150 | # Expand sequences 151 | # Suppress finished sequences 152 | flags = tf.equal(symbol_indices, eos_id) 153 | # [batch, 2 * beam_size] 154 | alive_scores = top_scores + tf.to_float(flags) * tf.float32.min 155 | # [batch, beam_size] 156 | alive_scores, alive_indices = tf.nn.top_k(alive_scores, beam_size) 157 | alive_symbols = _gather_2d(symbol_indices, alive_indices) 158 | alive_indices = _gather_2d(beam_indices, alive_indices) 159 | alive_seqs = _gather_2d(seqs, alive_indices) 160 | # [batch_size, beam_size, time + 1] 161 | alive_seqs = tf.concat([alive_seqs, tf.expand_dims(alive_symbols, 2)], 2) 162 | alive_state = nest.map_structure(lambda x: _gather_2d(x, alive_indices), 163 | next_state) 164 | alive_log_probs = alive_scores * length_penalty 165 | 166 | # Select finished sequences 167 | prev_fin_flags, prev_fin_seqs, prev_fin_scores = state.finish 168 | # [batch, 2 * beam_size] 169 | step_fin_scores = top_scores + (1.0 - tf.to_float(flags)) * tf.float32.min 170 | # [batch, 3 * beam_size] 171 | fin_flags = tf.concat([prev_fin_flags, flags], axis=1) 172 | fin_scores = tf.concat([prev_fin_scores, step_fin_scores], axis=1) 173 | # [batch, beam_size] 174 | fin_scores, fin_indices = tf.nn.top_k(fin_scores, beam_size) 175 | fin_flags = _gather_2d(fin_flags, fin_indices) 176 | pad_seqs = tf.fill([batch_size, beam_size, 1], 177 | tf.constant(pad_id, tf.int32)) 178 | prev_fin_seqs = tf.concat([prev_fin_seqs, pad_seqs], axis=2) 179 | fin_seqs = tf.concat([prev_fin_seqs, candidate_seqs], axis=1) 180 | fin_seqs = _gather_2d(fin_seqs, fin_indices) 181 | 182 | new_state = BeamSearchState( 183 | inputs=(alive_seqs, alive_log_probs, alive_scores), 
184 | state=alive_state, 185 | finish=(fin_flags, fin_seqs, fin_scores), 186 | ) 187 | 188 | return time + 1, new_state 189 | 190 | 191 | def beam_search(func, state, batch_size, beam_size, max_length, alpha, 192 | pad_id, bos_id, eos_id): 193 | init_seqs = tf.fill([batch_size, beam_size, 1], bos_id) 194 | init_log_probs = tf.constant([[0.] + [tf.float32.min] * (beam_size - 1)]) 195 | init_log_probs = tf.tile(init_log_probs, [batch_size, 1]) 196 | init_scores = tf.zeros_like(init_log_probs) 197 | fin_seqs = tf.zeros([batch_size, beam_size, 1], tf.int32) 198 | fin_scores = tf.fill([batch_size, beam_size], tf.float32.min) 199 | fin_flags = tf.zeros([batch_size, beam_size], tf.bool) 200 | 201 | state = BeamSearchState( 202 | inputs=(init_seqs, init_log_probs, init_scores), 203 | state=state, 204 | finish=(fin_flags, fin_seqs, fin_scores), 205 | ) 206 | 207 | max_step = tf.reduce_max(max_length) 208 | 209 | def _is_finished(t, s): 210 | log_probs = s.inputs[1] 211 | finished_flags = s.finish[0] 212 | finished_scores = s.finish[2] 213 | max_lp = tf.pow(((5.0 + tf.to_float(max_step)) / 6.0), alpha) 214 | best_alive_score = log_probs[:, 0] / max_lp 215 | worst_finished_score = tf.reduce_min( 216 | finished_scores * tf.to_float(finished_flags), axis=1) 217 | add_mask = 1.0 - tf.to_float(tf.reduce_any(finished_flags, 1)) 218 | worst_finished_score += tf.float32.min * add_mask 219 | bound_is_met = tf.reduce_all(tf.greater(worst_finished_score, 220 | best_alive_score)) 221 | 222 | cond = tf.logical_and(tf.less(t, max_step), 223 | tf.logical_not(bound_is_met)) 224 | 225 | return cond 226 | 227 | def _loop_fn(t, s): 228 | outs = _beam_search_step(t, func, s, batch_size, beam_size, alpha, 229 | pad_id, eos_id) 230 | return outs 231 | 232 | time = tf.constant(0, name="time") 233 | shape_invariants = BeamSearchState( 234 | inputs=(tf.TensorShape([None, None, None]), 235 | tf.TensorShape([None, None]), 236 | tf.TensorShape([None, None])), 237 | state=nest.map_structure(_infer_shape_invariants, state.state), 238 | finish=(tf.TensorShape([None, None]), 239 | tf.TensorShape([None, None, None]), 240 | tf.TensorShape([None, None])) 241 | ) 242 | outputs = tf.while_loop(_is_finished, _loop_fn, [time, state], 243 | shape_invariants=[tf.TensorShape([]), 244 | shape_invariants], 245 | parallel_iterations=1, 246 | back_prop=False) 247 | 248 | final_state = outputs[1] 249 | alive_seqs = final_state.inputs[0] 250 | alive_scores = final_state.inputs[2] 251 | final_flags = final_state.finish[0] 252 | final_seqs = final_state.finish[1] 253 | final_scores = final_state.finish[2] 254 | 255 | alive_seqs.set_shape([None, beam_size, None]) 256 | final_seqs.set_shape((None, beam_size, None)) 257 | 258 | final_seqs = tf.where(tf.reduce_any(final_flags, 1), final_seqs, 259 | alive_seqs) 260 | final_scores = tf.where(tf.reduce_any(final_flags, 1), final_scores, 261 | alive_scores) 262 | 263 | return final_seqs, final_scores 264 | 265 | 266 | def create_inference_graph(model_fns, features, params): 267 | if not isinstance(model_fns, (list, tuple)): 268 | raise ValueError("mode_fns must be a list or tuple") 269 | 270 | features = copy.copy(features) 271 | 272 | decode_length = params.decode_length 273 | beam_size = params.beam_size 274 | top_beams = params.top_beams 275 | alpha = params.decode_alpha 276 | 277 | # Compute initial state if necessary 278 | states = [] 279 | funcs = [] 280 | 281 | for model_fn in model_fns: 282 | if callable(model_fn): 283 | # For non-incremental decoding 284 | states.append({}) 285 | 
funcs.append(model_fn) 286 | else: 287 | # For incremental decoding where model_fn is a tuple: 288 | # (encoding_fn, decoding_fn) 289 | states.append(model_fn[0](features)) 290 | funcs.append(model_fn[1]) 291 | 292 | batch_size = tf.shape(features["source"])[0] 293 | pad_id = params.mapping["target"][params.pad] 294 | bos_id = params.mapping["target"][params.bos] 295 | eos_id = params.mapping["target"][params.eos] 296 | 297 | # Expand the inputs in to the beam size 298 | # [batch, length] => [batch, beam_size, length] 299 | features["source"] = tf.expand_dims(features["source"], 1) 300 | features["source"] = tf.tile(features["source"], [1, beam_size, 1]) 301 | shape = tf.shape(features["source"]) 302 | 303 | # [batch, beam_size, length] => [batch * beam_size, length] 304 | features["source"] = tf.reshape(features["source"], 305 | [shape[0] * shape[1], shape[2]]) 306 | 307 | # For source sequence length 308 | features["source_length"] = tf.expand_dims(features["source_length"], 1) 309 | features["source_length"] = tf.tile(features["source_length"], 310 | [1, beam_size]) 311 | shape = tf.shape(features["source_length"]) 312 | 313 | max_length = features["source_length"] + decode_length 314 | 315 | # [batch, beam_size, length] => [batch * beam_size, length] 316 | features["source_length"] = tf.reshape(features["source_length"], 317 | [shape[0] * shape[1]]) 318 | decoding_fn = _get_inference_fn(funcs, features) 319 | states = nest.map_structure(lambda x: _tile_to_beam_size(x, beam_size), 320 | states) 321 | 322 | seqs, scores = beam_search(decoding_fn, states, batch_size, beam_size, 323 | max_length, alpha, pad_id, bos_id, eos_id) 324 | 325 | return seqs[:, :top_beams, 1:], scores[:, :top_beams] 326 | -------------------------------------------------------------------------------- /thumt/utils/inference_ctx.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The THUMT Authors 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import copy 9 | import tensorflow as tf 10 | 11 | from collections import namedtuple 12 | from tensorflow.python.util import nest 13 | 14 | 15 | class BeamSearchState(namedtuple("BeamSearchState", 16 | ("inputs", "state", "finish"))): 17 | pass 18 | 19 | 20 | def _get_inference_fn(model_fns, features): 21 | def inference_fn(inputs, state): 22 | local_features = { 23 | "source": features["source"], 24 | "source_length": features["source_length"], 25 | "context": features["context"], 26 | "context_length": features["context_length"], 27 | # [bos_id, ...] 
=> [..., 0] 28 | "target": tf.pad(inputs[:, 1:], [[0, 0], [0, 1]]), 29 | "target_length": tf.fill([tf.shape(inputs)[0]], 30 | tf.shape(inputs)[1]) 31 | } 32 | 33 | outputs = [] 34 | next_state = [] 35 | 36 | for (model_fn, model_state) in zip(model_fns, state): 37 | if model_state: 38 | output, new_state = model_fn(local_features, model_state) 39 | outputs.append(output) 40 | next_state.append(new_state) 41 | else: 42 | output = model_fn(local_features) 43 | outputs.append(output) 44 | next_state.append({}) 45 | 46 | # Ensemble 47 | log_prob = tf.add_n(outputs) / float(len(outputs)) 48 | 49 | return log_prob, next_state 50 | 51 | return inference_fn 52 | 53 | 54 | def _infer_shape(x): 55 | x = tf.convert_to_tensor(x) 56 | 57 | # If unknown rank, return dynamic shape 58 | if x.shape.dims is None: 59 | return tf.shape(x) 60 | 61 | static_shape = x.shape.as_list() 62 | dynamic_shape = tf.shape(x) 63 | 64 | ret = [] 65 | for i in range(len(static_shape)): 66 | dim = static_shape[i] 67 | if dim is None: 68 | dim = dynamic_shape[i] 69 | ret.append(dim) 70 | 71 | return ret 72 | 73 | 74 | def _infer_shape_invariants(tensor): 75 | shape = tensor.shape.as_list() 76 | for i in range(1, len(shape) - 1): 77 | shape[i] = None 78 | return tf.TensorShape(shape) 79 | 80 | 81 | def _merge_first_two_dims(tensor): 82 | shape = _infer_shape(tensor) 83 | shape[0] *= shape[1] 84 | shape.pop(1) 85 | return tf.reshape(tensor, shape) 86 | 87 | 88 | def _split_first_two_dims(tensor, dim_0, dim_1): 89 | shape = _infer_shape(tensor) 90 | new_shape = [dim_0] + [dim_1] + shape[1:] 91 | return tf.reshape(tensor, new_shape) 92 | 93 | 94 | def _tile_to_beam_size(tensor, beam_size): 95 | """Tiles a given tensor by beam_size. """ 96 | tensor = tf.expand_dims(tensor, axis=1) 97 | tile_dims = [1] * tensor.shape.ndims 98 | tile_dims[1] = beam_size 99 | 100 | return tf.tile(tensor, tile_dims) 101 | 102 | 103 | def _gather_2d(params, indices, name=None): 104 | """ Gather the 2nd dimension given indices 105 | :param params: A tensor with shape [batch_size, M, ...] 106 | :param indices: A tensor with shape [batch_size, N] 107 | :return: A tensor with shape [batch_size, N, ...] 
108 | """ 109 | batch_size = tf.shape(params)[0] 110 | range_size = tf.shape(indices)[1] 111 | batch_pos = tf.range(batch_size * range_size) // range_size 112 | batch_pos = tf.reshape(batch_pos, [batch_size, range_size]) 113 | indices = tf.stack([batch_pos, indices], axis=-1) 114 | output = tf.gather_nd(params, indices, name=name) 115 | 116 | return output 117 | 118 | 119 | def _beam_search_step(time, func, state, batch_size, beam_size, alpha, 120 | pad_id, eos_id): 121 | # Compute log probabilities 122 | print('st2', state) 123 | seqs, log_probs = state.inputs[:2] 124 | flat_seqs = _merge_first_two_dims(seqs) 125 | flat_state = nest.map_structure(lambda x: _merge_first_two_dims(x), 126 | state.state) 127 | print('st3', flat_state) 128 | step_log_probs, next_state = func(flat_seqs, flat_state) 129 | step_log_probs = _split_first_two_dims(step_log_probs, batch_size, 130 | beam_size) 131 | print('st4', next_state) 132 | next_state = nest.map_structure( 133 | lambda x: _split_first_two_dims(x, batch_size, beam_size), next_state) 134 | curr_log_probs = tf.expand_dims(log_probs, 2) + step_log_probs 135 | 136 | # Apply length penalty 137 | length_penalty = tf.pow((5.0 + tf.to_float(time + 1)) / 6.0, alpha) 138 | curr_scores = curr_log_probs / length_penalty 139 | vocab_size = curr_scores.shape[-1].value or tf.shape(curr_scores)[-1] 140 | 141 | # Select top-k candidates 142 | # [batch_size, beam_size * vocab_size] 143 | curr_scores = tf.reshape(curr_scores, [-1, beam_size * vocab_size]) 144 | # [batch_size, 2 * beam_size] 145 | top_scores, top_indices = tf.nn.top_k(curr_scores, k=2 * beam_size) 146 | # Shape: [batch_size, 2 * beam_size] 147 | beam_indices = top_indices // vocab_size 148 | symbol_indices = top_indices % vocab_size 149 | # Expand sequences 150 | # [batch_size, 2 * beam_size, time] 151 | candidate_seqs = _gather_2d(seqs, beam_indices) 152 | candidate_seqs = tf.concat([candidate_seqs, 153 | tf.expand_dims(symbol_indices, 2)], 2) 154 | 155 | # Expand sequences 156 | # Suppress finished sequences 157 | flags = tf.equal(symbol_indices, eos_id) 158 | # [batch, 2 * beam_size] 159 | alive_scores = top_scores + tf.to_float(flags) * tf.float32.min 160 | # [batch, beam_size] 161 | alive_scores, alive_indices = tf.nn.top_k(alive_scores, beam_size) 162 | alive_symbols = _gather_2d(symbol_indices, alive_indices) 163 | alive_indices = _gather_2d(beam_indices, alive_indices) 164 | alive_seqs = _gather_2d(seqs, alive_indices) 165 | # [batch_size, beam_size, time + 1] 166 | alive_seqs = tf.concat([alive_seqs, tf.expand_dims(alive_symbols, 2)], 2) 167 | alive_state = nest.map_structure(lambda x: _gather_2d(x, alive_indices), 168 | next_state) 169 | print('st5', alive_state) 170 | alive_log_probs = alive_scores * length_penalty 171 | 172 | # Select finished sequences 173 | prev_fin_flags, prev_fin_seqs, prev_fin_scores = state.finish 174 | # [batch, 2 * beam_size] 175 | step_fin_scores = top_scores + (1.0 - tf.to_float(flags)) * tf.float32.min 176 | # [batch, 3 * beam_size] 177 | fin_flags = tf.concat([prev_fin_flags, flags], axis=1) 178 | fin_scores = tf.concat([prev_fin_scores, step_fin_scores], axis=1) 179 | # [batch, beam_size] 180 | fin_scores, fin_indices = tf.nn.top_k(fin_scores, beam_size) 181 | fin_flags = _gather_2d(fin_flags, fin_indices) 182 | pad_seqs = tf.fill([batch_size, beam_size, 1], 183 | tf.constant(pad_id, tf.int32)) 184 | prev_fin_seqs = tf.concat([prev_fin_seqs, pad_seqs], axis=2) 185 | fin_seqs = tf.concat([prev_fin_seqs, candidate_seqs], axis=1) 186 | fin_seqs = 
187 | 
188 |     new_state = BeamSearchState(
189 |         inputs=(alive_seqs, alive_log_probs, alive_scores),
190 |         state=alive_state,
191 |         finish=(fin_flags, fin_seqs, fin_scores),
192 |     )
193 | 
194 |     return time + 1, new_state
195 | 
196 | 
197 | def beam_search(func, state, batch_size, beam_size, max_length, alpha,
198 |                 pad_id, bos_id, eos_id):
199 |     init_seqs = tf.fill([batch_size, beam_size, 1], bos_id)
200 |     init_log_probs = tf.constant([[0.] + [tf.float32.min] * (beam_size - 1)])
201 |     init_log_probs = tf.tile(init_log_probs, [batch_size, 1])
202 |     init_scores = tf.zeros_like(init_log_probs)
203 |     fin_seqs = tf.zeros([batch_size, beam_size, 1], tf.int32)
204 |     fin_scores = tf.fill([batch_size, beam_size], tf.float32.min)
205 |     fin_flags = tf.zeros([batch_size, beam_size], tf.bool)
206 | 
207 |     state = BeamSearchState(
208 |         inputs=(init_seqs, init_log_probs, init_scores),
209 |         state=state,
210 |         finish=(fin_flags, fin_seqs, fin_scores),
211 |     )
212 |     print('st1', state)
213 | 
214 |     max_step = tf.reduce_max(max_length)
215 | 
216 |     def _is_finished(t, s):
217 |         log_probs = s.inputs[1]
218 |         finished_flags = s.finish[0]
219 |         finished_scores = s.finish[2]
220 |         max_lp = tf.pow(((5.0 + tf.to_float(max_step)) / 6.0), alpha)
221 |         best_alive_score = log_probs[:, 0] / max_lp
222 |         worst_finished_score = tf.reduce_min(
223 |             finished_scores * tf.to_float(finished_flags), axis=1)
224 |         add_mask = 1.0 - tf.to_float(tf.reduce_any(finished_flags, 1))
225 |         worst_finished_score += tf.float32.min * add_mask
226 |         bound_is_met = tf.reduce_all(tf.greater(worst_finished_score,
227 |                                                 best_alive_score))
228 | 
229 |         cond = tf.logical_and(tf.less(t, max_step),
230 |                               tf.logical_not(bound_is_met))
231 | 
232 |         return cond
233 | 
234 |     def _loop_fn(t, s):
235 |         outs = _beam_search_step(t, func, s, batch_size, beam_size, alpha,
236 |                                  pad_id, eos_id)
237 |         return outs
238 | 
239 |     time = tf.constant(0, name="time")
240 |     shape_invariants = BeamSearchState(
241 |         inputs=(tf.TensorShape([None, None, None]),
242 |                 tf.TensorShape([None, None]),
243 |                 tf.TensorShape([None, None])),
244 |         state=nest.map_structure(_infer_shape_invariants, state.state),
245 |         finish=(tf.TensorShape([None, None]),
246 |                 tf.TensorShape([None, None, None]),
247 |                 tf.TensorShape([None, None]))
248 |     )
249 |     outputs = tf.while_loop(_is_finished, _loop_fn, [time, state],
250 |                             shape_invariants=[tf.TensorShape([]),
251 |                                               shape_invariants],
252 |                             parallel_iterations=1,
253 |                             back_prop=False)
254 | 
255 |     final_state = outputs[1]
256 |     alive_seqs = final_state.inputs[0]
257 |     alive_scores = final_state.inputs[2]
258 |     final_flags = final_state.finish[0]
259 |     final_seqs = final_state.finish[1]
260 |     final_scores = final_state.finish[2]
261 | 
262 |     alive_seqs.set_shape([None, beam_size, None])
263 |     final_seqs.set_shape((None, beam_size, None))
264 | 
265 |     final_seqs = tf.where(tf.reduce_any(final_flags, 1), final_seqs,
266 |                           alive_seqs)
267 |     final_scores = tf.where(tf.reduce_any(final_flags, 1), final_scores,
268 |                             alive_scores)
269 | 
270 |     return final_seqs, final_scores
271 | 
272 | 
273 | def create_inference_graph(model_fns, features, params):
274 |     if not isinstance(model_fns, (list, tuple)):
275 |         raise ValueError("model_fns must be a list or tuple")
276 | 
277 |     features = copy.copy(features)
278 | 
279 |     decode_length = params.decode_length
280 |     beam_size = params.beam_size
281 |     top_beams = params.top_beams
282 |     alpha = params.decode_alpha
283 | 
284 |     # Compute initial state if necessary
285 |     states = []
286 |     funcs = []
287 | 
288 |     for model_fn in model_fns:
289 |         if callable(model_fn):
290 |             # For non-incremental decoding
291 |             states.append({})
292 |             funcs.append(model_fn)
293 |         else:
294 |             # For incremental decoding where model_fn is a tuple:
295 |             # (encoding_fn, decoding_fn)
296 |             states.append(model_fn[0](features))
297 |             funcs.append(model_fn[1])
298 | 
299 |     batch_size = tf.shape(features["source"])[0]
300 |     pad_id = params.mapping["target"][params.pad]
301 |     bos_id = params.mapping["target"][params.bos]
302 |     eos_id = params.mapping["target"][params.eos]
303 | 
304 |     # Expand the inputs into the beam size
305 |     # [batch, length] => [batch, beam_size, length]
306 |     features["source"] = tf.expand_dims(features["source"], 1)
307 |     features["source"] = tf.tile(features["source"], [1, beam_size, 1])
308 |     shape = tf.shape(features["source"])
309 | 
310 |     # [batch, beam_size, length] => [batch * beam_size, length]
311 |     features["source"] = tf.reshape(features["source"],
312 |                                     [shape[0] * shape[1], shape[2]])
313 | 
314 |     # For source sequence length
315 |     features["source_length"] = tf.expand_dims(features["source_length"], 1)
316 |     features["source_length"] = tf.tile(features["source_length"],
317 |                                         [1, beam_size])
318 |     shape = tf.shape(features["source_length"])
319 | 
320 |     max_length = features["source_length"] + decode_length
321 | 
322 |     # [batch, beam_size, length] => [batch * beam_size, length]
323 |     features["source_length"] = tf.reshape(features["source_length"],
324 |                                            [shape[0] * shape[1]])
325 | 
326 |     ######
327 |     # Expand the inputs into the beam size
328 |     # [batch, length] => [batch, beam_size, length]
329 |     features["context"] = tf.expand_dims(features["context"], 1)
330 |     features["context"] = tf.tile(features["context"], [1, beam_size, 1])
331 |     shape = tf.shape(features["context"])
332 | 
333 |     # [batch, beam_size, length] => [batch * beam_size, length]
334 |     features["context"] = tf.reshape(features["context"],
335 |                                      [shape[0] * shape[1], shape[2]])
336 | 
337 |     # For context sequence length
338 |     features["context_length"] = tf.expand_dims(features["context_length"], 1)
339 |     features["context_length"] = tf.tile(features["context_length"],
340 |                                          [1, beam_size])
341 |     shape = tf.shape(features["context_length"])
342 | 
343 |     # [batch, beam_size, length] => [batch * beam_size, length]
344 |     features["context_length"] = tf.reshape(features["context_length"],
345 |                                             [shape[0] * shape[1]])
346 | 
347 | 
348 |     decoding_fn = _get_inference_fn(funcs, features)
349 |     states = nest.map_structure(lambda x: _tile_to_beam_size(x, beam_size),
350 |                                 states)
351 | 
352 |     seqs, scores = beam_search(decoding_fn, states, batch_size, beam_size,
353 |                                max_length, alpha, pad_id, bos_id, eos_id)
354 | 
355 |     return seqs[:, :top_beams, 1:], scores[:, :top_beams]
356 | 
--------------------------------------------------------------------------------
/thumt/utils/optimize.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The THUMT Authors
3 | 
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 | 
8 | import tensorflow as tf
9 | 
10 | 
11 | def _get_loss_variable(graph=None):
12 |     graph = graph or tf.get_default_graph()
13 |     loss_tensors = tf.get_collection("loss")
14 | 
15 |     if len(loss_tensors) == 1:
16 |         loss_tensor = loss_tensors[0]
17 |     elif not loss_tensors:
18 |         try:
19 |             loss_tensor = graph.get_tensor_by_name("loss_tensor:0")
20 |         except KeyError:
21 |             return None
22 |     else:
23 |         tf.logging.error("Multiple tensors in loss collection.")
24 |         return None
25 | 
26 |     return loss_tensor
27 | 
28 | 
29 | def _create_loss_variable(graph=None):
30 |     graph = graph or tf.get_default_graph()
31 |     if _get_loss_variable(graph) is not None:
32 |         raise ValueError("'loss' already exists.")
33 | 
34 |     # Create in proper graph and base name_scope.
35 |     with graph.as_default() as g, g.name_scope(None):
36 |         tensor = tf.get_variable("loss", shape=[], dtype=tf.float32,
37 |                                  initializer=tf.zeros_initializer(),
38 |                                  trainable=False,
39 |                                  collections=[tf.GraphKeys.GLOBAL_VARIABLES,
40 |                                               "loss"])
41 | 
42 |     return tensor
43 | 
44 | 
45 | def _get_or_create_loss_variable(graph=None):
46 |     graph = graph or tf.get_default_graph()
47 |     loss_tensor = _get_loss_variable(graph)
48 |     if loss_tensor is None:
49 |         loss_tensor = _create_loss_variable(graph)
50 |     return loss_tensor
51 | 
52 | 
53 | def _zero_variables(variables, name=None):
54 |     ops = []
55 | 
56 |     for var in variables:
57 |         with tf.device(var.device):
58 |             op = var.assign(tf.zeros(var.shape.as_list()))
59 |             ops.append(op)
60 | 
61 |     return tf.group(*ops, name=name or "zero_variables")
62 | 
63 | 
64 | def _replicate_variables(variables, device=None):
65 |     new_vars = []
66 | 
67 |     for var in variables:
68 |         device = device or var.device
69 |         with tf.device(device):
70 |             name = var.name.split(":")[0].rstrip("/") + "/replica"
71 |             new_vars.append(tf.Variable(tf.zeros(var.shape.as_list()),
72 |                                         name=name, trainable=False))
73 | 
74 |     return new_vars
75 | 
76 | 
77 | def _collect_gradients(gradients, variables):
78 |     ops = []
79 | 
80 |     for grad, var in zip(gradients, variables):
81 |         if isinstance(grad, tf.Tensor):
82 |             ops.append(tf.assign_add(var, grad))
83 |         else:
84 |             ops.append(tf.scatter_add(var, grad.indices, grad.values))
85 | 
86 |     return tf.group(*ops, name="collect_gradients")
87 | 
88 | 
89 | def _scale_variables(variables, scale):
90 |     if not isinstance(variables, (list, tuple)):
91 |         return tf.assign(variables, scale * variables)
92 | 
93 |     ops = []
94 | 
95 |     for var in variables:
96 |         ops.append(tf.assign(var, scale * var))
97 | 
98 |     return tf.group(*ops, name="scale_variables")
99 | 
100 | 
101 | def create_train_op(loss, optimizer, global_step, params):
102 |     with tf.name_scope("create_train_op"):
103 |         grads_and_vars = optimizer.compute_gradients(
104 |             loss, colocate_gradients_with_ops=True)
105 |         gradients = [item[0] for item in grads_and_vars]
106 |         variables = [item[1] for item in grads_and_vars]
107 | 
108 |         if params.update_cycle == 1:
109 |             zero_variables_op = tf.no_op("zero_variables")
110 |             collect_op = tf.no_op("collect_op")
111 |             scale_op = tf.no_op("scale_op")
112 |         else:
113 |             # collect
114 |             loss_tensor = _get_or_create_loss_variable()
115 |             slot_variables = _replicate_variables(variables)
116 |             zero_variables_op = _zero_variables(slot_variables + [loss_tensor])
117 |             collect_grads_op = _collect_gradients(gradients, slot_variables)
118 |             collect_loss_op = tf.assign_add(loss_tensor, loss)
119 |             collect_op = tf.group(collect_loss_op, collect_grads_op,
120 |                                   name="collect_op")
121 |             # scale
122 |             scale = 1.0 / params.update_cycle
123 |             scale_grads_op = _scale_variables(slot_variables, scale)
124 |             scale_loss_op = _scale_variables(loss_tensor, scale)
125 |             scale_op = tf.group(scale_grads_op, scale_loss_op, name="scale_op")
126 |             gradients = slot_variables
127 |             loss = tf.convert_to_tensor(loss_tensor)
128 | 
129 |         # Add summaries
130 |         tf.summary.scalar("loss", loss)
131 |         tf.summary.scalar("global_norm/gradient_norm",
132 |                           tf.global_norm(gradients))
133 | 
134 |         for gradient, variable in zip(gradients, variables):
135 |             if isinstance(gradient, tf.IndexedSlices):
136 |                 grad_values = gradient.values
137 |             else:
138 |                 grad_values = gradient
139 | 
140 |             if grad_values is not None:
141 |                 var_name = variable.name.replace(":", "_")
142 |                 tf.summary.histogram("gradients/%s" % var_name, grad_values)
143 |                 tf.summary.scalar("gradient_norm/%s" % var_name,
144 |                                   tf.global_norm([grad_values]))
145 | 
146 |         # Gradient clipping
147 |         if isinstance(params.clip_grad_norm or None, float):
148 |             gradients, _ = tf.clip_by_global_norm(gradients,
149 |                                                   params.clip_grad_norm)
150 | 
151 |         # Update variables
152 |         grads_and_vars = list(zip(gradients, tf.trainable_variables()))
153 |         train_op = optimizer.apply_gradients(grads_and_vars, global_step)
154 | 
155 |         ops = {
156 |             "zero_op": zero_variables_op,
157 |             "collect_op": collect_op,
158 |             "scale_op": scale_op,
159 |             "train_op": train_op
160 |         }
161 | 
162 |         return loss, ops
163 | 
--------------------------------------------------------------------------------
/thumt/utils/parallel.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The THUMT Authors
3 | 
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 | 
8 | import operator
9 | 
10 | import tensorflow as tf
11 | 
12 | 
13 | class GPUParamServerDeviceSetter(object):
14 | 
15 |     def __init__(self, worker_device, ps_devices):
16 |         self.ps_devices = ps_devices
17 |         self.worker_device = worker_device
18 |         self.ps_sizes = [0] * len(self.ps_devices)
19 | 
20 |     def __call__(self, op):
21 |         if op.device:
22 |             return op.device
23 |         if op.type not in ["Variable", "VariableV2", "VarHandleOp"]:
24 |             return self.worker_device
25 | 
26 |         # Gets the least loaded ps_device
27 |         device_index, _ = min(enumerate(self.ps_sizes),
28 |                               key=operator.itemgetter(1))
29 |         device_name = self.ps_devices[device_index]
30 |         var_size = op.outputs[0].get_shape().num_elements()
31 |         self.ps_sizes[device_index] += var_size
32 | 
33 |         return device_name
34 | 
35 | 
36 | def _maybe_repeat(x, n):
37 |     if isinstance(x, list):
38 |         assert len(x) == n
39 |         return x
40 |     else:
41 |         return [x] * n
42 | 
43 | 
44 | def _create_device_setter(is_cpu_ps, worker, num_gpus):
45 |     if is_cpu_ps:
46 |         # tf.train.replica_device_setter supports placing variables on the CPU,
47 |         # all on one GPU, or on ps_servers defined in a cluster_spec.
48 |         return tf.train.replica_device_setter(
49 |             worker_device=worker, ps_device="/cpu:0", ps_tasks=1)
50 |     else:
51 |         gpus = ["/gpu:%d" % i for i in range(num_gpus)]
52 |         return GPUParamServerDeviceSetter(worker, gpus)
53 | 
54 | 
55 | # Data-level parallelism
56 | def data_parallelism(devices, fn, *args, **kwargs):
57 |     num_worker = len(devices)
58 | 
59 |     # Replicate args and kwargs
60 |     if args:
61 |         new_args = [_maybe_repeat(arg, num_worker) for arg in args]
62 |         # Transpose
63 |         new_args = [list(x) for x in zip(*new_args)]
64 |     else:
65 |         new_args = [[] for _ in range(num_worker)]
66 | 
67 |     new_kwargs = [{} for _ in range(num_worker)]
68 | 
69 |     for k, v in kwargs.iteritems():
70 |         vals = _maybe_repeat(v, num_worker)
71 | 
72 |         for i in range(num_worker):
73 |             new_kwargs[i][k] = vals[i]
74 | 
75 |     fns = _maybe_repeat(fn, num_worker)
76 | 
77 |     # Now make the parallel call.
78 |     outputs = []
79 | 
80 |     for i in range(num_worker):
81 |         worker = "/gpu:%d" % i
82 |         device_setter = _create_device_setter(False, worker, len(devices))
83 |         with tf.variable_scope(tf.get_variable_scope(), reuse=(i != 0)):
84 |             with tf.name_scope("parallel_%d" % i):
85 |                 with tf.device(device_setter):
86 |                     outputs.append(fns[i](*new_args[i], **new_kwargs[i]))
87 | 
88 |     if isinstance(outputs[0], tuple):
89 |         outputs = list(zip(*outputs))
90 |         outputs = tuple([list(o) for o in outputs])
91 | 
92 |     return outputs
93 | 
94 | 
95 | def shard_features(features, device_list):
96 |     num_datashards = len(device_list)
97 | 
98 |     sharded_features = {}
99 | 
100 |     for k, v in features.iteritems():
101 |         v = tf.convert_to_tensor(v)
102 |         if not v.shape.as_list():
103 |             v = tf.expand_dims(v, axis=-1)
104 |             v = tf.tile(v, [num_datashards])
105 |         with tf.device(v.device):
106 |             sharded_features[k] = tf.split(v, num_datashards, 0)
107 | 
108 |     datashard_to_features = []
109 | 
110 |     for d in range(num_datashards):
111 |         feat = {
112 |             k: v[d] for k, v in sharded_features.iteritems()
113 |         }
114 |         datashard_to_features.append(feat)
115 | 
116 |     return datashard_to_features
117 | 
118 | 
119 | def parallel_model(model_fn, features, devices, use_cpu=False):
120 |     devices = ["gpu:%d" % d for d in devices]
121 | 
122 |     if use_cpu:
123 |         devices += ["cpu:0"]
124 | 
125 |     if len(devices) == 1:
126 |         return [model_fn(features)]
127 | 
128 |     features = shard_features(features, devices)
129 | 
130 |     outputs = data_parallelism(devices, model_fn, features)
131 |     return outputs
132 | 
--------------------------------------------------------------------------------
/thumt/utils/sample.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The THUMT Authors
3 | 
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 | 
--------------------------------------------------------------------------------
/thumt/utils/utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The THUMT Authors
3 | 
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 | 
8 | import tensorflow as tf
9 | 
10 | 
11 | def session_run(monitored_session, args):
12 |     # Call raw TF session directly
13 |     return monitored_session._tf_sess().run(args)
14 | 
--------------------------------------------------------------------------------