├── LICENSE
├── README.md
├── UserManual.pdf
├── docs
│   ├── UserManual.tex
│   ├── everb.sty
│   └── thumt.bib
└── thumt
    ├── __init__.py
    ├── bin
    │   ├── scorer.py
    │   ├── trainer.py
    │   ├── trainer_ctx.py
    │   ├── translator.py
    │   └── translator_ctx.py
    ├── data
    │   ├── __init__.py
    │   ├── cache.py
    │   ├── dataset.py
    │   ├── record.py
    │   └── vocab.py
    ├── interface
    │   ├── __init__.py
    │   └── model.py
    ├── layers
    │   ├── __init__.py
    │   ├── attention.py
    │   ├── nn.py
    │   └── rnn_cell.py
    ├── models
    │   ├── __init__.py
    │   ├── contextual_transformer.py
    │   ├── rnnsearch.py
    │   ├── seq2seq.py
    │   └── transformer.py
    ├── scripts
    │   ├── build_vocab.py
    │   ├── change.py
    │   ├── check_param.py
    │   ├── checkpoint_averaging.py
    │   ├── combine.py
    │   ├── combine_add.py
    │   ├── compare.py
    │   ├── convert_old_model.py
    │   ├── convert_vocab.py
    │   ├── input_converter.py
    │   └── shuffle_corpus.py
    └── utils
        ├── __init__.py
        ├── bleu.py
        ├── hooks.py
        ├── inference.py
        ├── inference_ctx.py
        ├── optimize.py
        ├── parallel.py
        ├── sample.py
        └── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2018, Natural Language Processing Lab at Tsinghua University
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without modification,
5 | are permitted provided that the following conditions are met:
6 |
7 | * Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | * Redistributions in binary form must reproduce the above copyright notice, this
11 | list of conditions and the following disclaimer in the documentation and/or
12 | other materials provided with the distribution.
13 |
14 | * Neither the name of the copyright holder nor the names of its
15 | contributors may be used to endorse or promote products derived from this
16 | software without specific prior written permission.
17 |
18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
22 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
23 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
24 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
25 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
27 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Improving the Transformer Translation Model with Document-Level Context
2 | ## Contents
3 | * [Introduction](#introduction)
4 | * [Usage](#usage)
5 | * [Citation](#citation)
6 | * [FAQ](#faq)
7 |
8 | ## Introduction
9 |
10 | This is the implementation of our work, which extends the Transformer model to integrate document-level context \[[paper](https://arxiv.org/abs/1810.03581)\]. The implementation is built on top of [THUMT](https://github.com/thumt/THUMT).
11 |
12 | ## Usage
13 |
14 | Note: the usage is not yet user-friendly; we may improve it later.
15 |
16 | 1. Train a standard Transformer model first; please refer to the user manual of [THUMT](https://github.com/thumt/THUMT). Suppose that model_baseline/model.ckpt-30000 performs best on the validation set.
17 |
18 | 2. Generate a dummy improved Transformer model with the following command:
19 |
20 |     python THUMT/thumt/bin/trainer_ctx.py --inputs [source corpus] [target corpus] \
21 | --context [context corpus] \
22 | --vocabulary [source vocabulary] [target vocabulary] \
23 | --output model_dummy --model contextual_transformer \
24 | --parameters train_steps=1
25 |
26 |
27 | 3. Generate the initial model by merging the standard Transformer model into the dummy model, then create a checkpoint file:
28 |
29 | python THUMT/thumt/scripts/combine_add.py --model model_dummy/model.ckpt-0 \
30 | --part model_baseline/model.ckpt-30000 --output train
31 | printf 'model_checkpoint_path: "new-0"\nall_model_checkpoint_paths: "new-0"' > train/checkpoint
32 |
33 |
34 |
35 | 4. Train the improved Transformer model with the following command:
36 |
37 | python THUMT/thumt/bin/trainer_ctx.py --inputs [source corpus] [target corpus] \
38 | --context [context corpus] \
39 | --vocabulary [source vocabulary] [target vocabulary] \
40 | --output train --model contextual_transformer \
41 | --parameters start_steps=30000,num_context_layers=1
42 |
43 |
44 | 5. Translate with the improved Transformer model:
45 |
46 | python THUMT/thumt/bin/translator_ctx.py --inputs [source corpus] --context [context corpus] \
47 | --output [translation result] \
48 | --vocabulary [source vocabulary] [target vocabulary] \
49 | --model contextual_transformer --checkpoints [model path] \
50 | --parameters num_context_layers=1
51 |
52 |
53 | ## Citation
54 |
55 | Please cite the following paper if you use the code:
56 |
57 | @InProceedings{Zhang:18,
58 | author = {Zhang, Jiacheng and Luan, Huanbo and Sun, Maosong and Zhai, Feifei and Xu, Jingfang and Zhang, Min and Liu, Yang},
59 | title = {Improving the Transformer Translation Model with Document-Level Context},
60 | booktitle = {Proceedings of EMNLP},
61 | year = {2018},
62 | }
63 |
64 |
65 |
66 | ## FAQ
67 |
68 | 1. What is the context corpus?
69 |
70 | The context corpus file contains one context sentence per line. Normally, the context consists of several preceding source sentences within the same document. For example, if the original document-level corpus is:
71 |
72 | ==== source ====
73 | <document id=XXX>
74 | <seg id=1>source sentence #1</seg>
75 | <seg id=2>source sentence #2</seg>
76 | <seg id=3>source sentence #3</seg>
77 | <seg id=4>source sentence #4</seg>
78 | </document>
79 |
80 | ==== target ====
81 | <document id=XXX>
82 | <seg id=1>target sentence #1</seg>
83 | <seg id=2>target sentence #2</seg>
84 | <seg id=3>target sentence #3</seg>
85 | <seg id=4>target sentence #4</seg>
86 | </document>
87 |
88 | The inputs to our system should then be processed as follows (supposing that the 2 preceding source sentences are used as context):
89 |
90 | ==== train.src ==== (source corpus)
91 | source sentence #1
92 | source sentence #2
93 | source sentence #3
94 | source sentence #4
95 |
96 | ==== train.ctx ==== (context corpus)
97 | (the first line is empty)
98 | source sentence #1
99 | source sentence #1 source sentence #2 (there is only a space between the two sentences)
100 | source sentence #2 source sentence #3
101 |
102 | ==== train.trg ==== (target corpus)
103 | target sentence #1
104 | target sentence #2
105 | target sentence #3
106 | target sentence #4
107 |
108 |
109 |
110 |
--------------------------------------------------------------------------------
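The FAQ above describes the context corpus only by example. Below is a minimal sketch of one way to generate such a file; the blank-line document separator, the two-sentence window, and the file names ("documents.src", "train.ctx") are assumptions for illustration, not part of this repository:

    # Illustrative only: build a context corpus in the format described in
    # the FAQ (empty first line per document, preceding sentences joined by
    # single spaces).
    def build_contexts(sentences, window=2):
        """Return one context string per sentence within a document."""
        return [" ".join(sentences[max(0, i - window):i])
                for i in range(len(sentences))]

    documents, current = [], []
    with open("documents.src") as fd:   # one sentence per line; documents
        for line in fd:                 # separated by a blank line (assumed)
            line = line.strip()
            if line:
                current.append(line)
            elif current:
                documents.append(current)
                current = []
    if current:
        documents.append(current)

    with open("train.ctx", "w") as fd:
        for doc in documents:
            for ctx in build_contexts(doc):
                fd.write(ctx + "\n")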
/UserManual.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUNLP-MT/Document-Transformer/5bcc7f43cc948240fa0e3a400bffdc178f841fcd/UserManual.pdf
--------------------------------------------------------------------------------
/docs/thumt.bib:
--------------------------------------------------------------------------------
1 | @inproceedings{Bahdanau:15,
2 | author = {Bahdanau, Dzmitry and Cho, KyungHyun and Bengio, Yoshua},
3 | title = {Neural Machine Translation by Jointly Learning to Align and Translate},
4 | booktitle = {Proceedings of ICLR},
5 | year = {2015},
6 | }
7 |
8 | @article{Brown:93,
9 | author = "Brown, Peter F. and Della Pietra, Stephen A. and Della Pietra, Vincent J. and Mercer, Robert L.",
10 | title = "The mathematics of statistical machine translation: Parameter estimation",
11 | journal = "Computational Linguistics",
12 | year = "1993",
13 | }
14 |
15 | @InProceedings{Cheng:16,
16 | author = {Cheng, Yong and Xu, Wei and He, Zhongjun and He, Wei and Wu, Hua and Sun, Maosong and Liu, Yang},
17 | title = {Semi-Supervised Learning for Neural Machine Translation},
18 | booktitle = {Proceedings of ACL},
19 | year = {2016},
20 | }
21 |
22 | @InProceedings{Chiang:05,
23 | author = {Chiang, David},
24 | title = {A Hierarchical Phrase-based Model for Statistical Machine Translation},
25 | booktitle = {Proceedings of ACL},
26 | year = {2005},
27 | }
28 |
29 |
30 | @InProceedings{Ding:17,
31 | author = {Ding, Yanzhuo and Liu, Yang and Luan, Huanbo and Sun, Maosong},
32 | title = {Visualizing and Understanding Neural Machine Translation},
33 | booktitle = {Proceedings of ACL},
34 | year = {2017},
35 | }
36 |
37 | @misc{Kingma:14,
38 | author = {Kingma, Diederik P. and Ba, Jimmy},
39 | title = {Adam: A Method for Stochastic Optimization},
40 | howpublished = {arXiv:1412.6980},
41 | year = {2014},
42 | }
43 |
44 | @InProceedings{Koehn:03,
45 | author = {Koehn, Philipp and Och, Franz J. and Marcu, Daniel},
46 | title = {Statistical Phrase-based Translation},
47 | booktitle = {Proceedings of NAACL},
48 | year = {2003},
49 | }
50 |
51 | @InProceedings{Luong:15,
52 | author = {Luong, Thang and Sutskever, Ilya and Le, Quoc and Vinyals, Oriol and Zaremba, Wojciech},
53 | title = {Addressing the Rare Word Problem in Neural Machine Translation},
54 | booktitle = {Proceedings of ACL},
55 | year = {2015},
56 | }
57 |
58 | @inproceedings{Papineni:02,
59 | author = {Papineni, Kishore and Roukos, Salim and Ward, Todd and Zhu, Wei-Jing},
60 | title = {BLEU: A Method for Automatic Evaluation of Machine Translation},
61 | booktitle = {Proceedings of ACL},
62 | year = {2002},
63 | }
64 |
65 | @InProceedings{Sennrich:16,
66 | author = {Sennrich, Rico and Haddow, Barry and Birch, Alexandra},
67 | title = {Neural Machine Translation of Rare Words with Subword Units},
68 | booktitle = {Proceedings of ACL},
69 | year = {2016},
70 | }
71 |
72 | @inproceedings{Shen:16,
73 | author = {Shen, Shiqi and Cheng, Yong and He, Zhongjun and He, Wei and Wu, Hua and Sun, Maosong and Liu, Yang},
74 | title = {Minimum Risk Training for Neural Machine Translation},
75 | booktitle = {Proceedings of ACL},
76 | year = {2016},
77 | }
78 |
79 | @inproceedings{Sutskever:14,
80 | author = {Sutskever, Ilya and Vinyals, Oriol and Le, Quoc V.},
81 | title = {Sequence to Sequence Learning with Neural Networks},
82 | booktitle = {Proceedings of NIPS},
83 | year = {2014},
84 | }
85 |
86 | @inproceedings{Vaswani:17,
87 | title={Attention Is All You Need},
88 | author={Vaswani, Ashish and Shazeer, Noam and Parmar, Niki and Uszkoreit, Jakob and Jones, Llion and Gomez, Aidan N and Kaiser, Lukasz and Polosukhin, Illia},
89 | booktitle={Proceedings of NIPS},
90 | year={2017}
91 | }
92 |
93 | @misc{Wu:16,
94 | author= {Yonghui Wu and
95 | Mike Schuster and
96 | Zhifeng Chen and
97 | Quoc V. Le and
98 | Mohammad Norouzi and
99 | Wolfgang Macherey and
100 | Maxim Krikun and
101 | Yuan Cao and
102 | Qin Gao and
103 | Klaus Macherey and
104 | Jeff Klingner and
105 | Apurva Shah and
106 | Melvin Johnson and
107 | Xiaobing Liu and
108 | Lukasz Kaiser and
109 | Stephan Gouws and
110 | Yoshikiyo Kato and
111 | Taku Kudo and
112 | Hideto Kazawa and
113 | Keith Stevens and
114 | George Kurian and
115 | Nishant Patil and
116 | Wei Wang and
117 | Cliff Young and
118 | Jason Smith and
119 | Jason Riesa and
120 | Alex Rudnick and
121 | Oriol Vinyals and
122 | Greg Corrado and
123 | Macduff Hughes and
124 | Jeffrey Dean},
125 | title = {Google's Neural Machine Translation System: Bridging the Gap between Human and Machine Translation},
126 | howpublished = {arXiv:1609.08144v2},
127 | year = {2016},
128 | }
129 |
130 | @misc{Zeiler:12,
131 | author = {Zeiler, Matthew D.},
132 | title = {AdaDelta: An Adaptive Learning Rate Method},
133 | howpublished = {arXiv:1212.5701v1},
134 | year = {2012},
135 | }
--------------------------------------------------------------------------------
/thumt/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The THUMT Authors
3 |
--------------------------------------------------------------------------------
/thumt/bin/scorer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2018 The THUMT Authors
4 |
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import argparse
10 | import itertools
11 | import os
12 |
13 | import tensorflow as tf
14 | import thumt.data.vocab as vocabulary
15 | import thumt.models as models
16 |
17 |
18 | def parse_args():
19 | parser = argparse.ArgumentParser(
20 |         description="Score translations using existing NMT models",
21 |         usage="scorer.py [<args>] [-h | --help]"
22 | )
23 |
24 | # input files
25 | parser.add_argument("--input", type=str, required=True, nargs=2,
26 | help="Path of input file")
27 | parser.add_argument("--output", type=str, required=True,
28 | help="Path of output file")
29 | parser.add_argument("--checkpoint", type=str, required=True,
30 | help="Path of trained models")
31 | parser.add_argument("--vocabulary", type=str, nargs=2, required=True,
32 | help="Path of source and target vocabulary")
33 |
34 | # model and configuration
35 | parser.add_argument("--model", type=str, required=True,
36 | help="Name of the model")
37 | parser.add_argument("--parameters", type=str,
38 | help="Additional hyper parameters")
39 |
40 | return parser.parse_args()
41 |
42 |
43 | def default_parameters():
44 | params = tf.contrib.training.HParams(
45 | input=None,
46 | output=None,
47 | vocabulary=None,
48 | model=None,
49 | # vocabulary specific
50 |         pad="<pad>",
51 |         bos="<bos>",
52 |         eos="<eos>",
53 |         unk="<unk>",
54 | mapping=None,
55 | append_eos=False,
56 | device_list=[0],
57 | num_threads=6,
58 | eval_batch_size=32
59 | )
60 |
61 | return params
62 |
63 |
64 | def merge_parameters(params1, params2):
65 | params = tf.contrib.training.HParams()
66 |
67 |     for (k, v) in params1.values().items():
68 | params.add_hparam(k, v)
69 |
70 | params_dict = params.values()
71 |
72 |     for (k, v) in params2.values().items():
73 | if k in params_dict:
74 | # Override
75 | setattr(params, k, v)
76 | else:
77 | params.add_hparam(k, v)
78 |
79 | return params
80 |
81 |
82 | def import_params(model_dir, model_name, params):
83 | model_dir = os.path.abspath(model_dir)
84 | m_name = os.path.join(model_dir, model_name + ".json")
85 |
86 | if not tf.gfile.Exists(m_name):
87 | return params
88 |
89 | with tf.gfile.Open(m_name) as fd:
90 | tf.logging.info("Restoring model parameters from %s" % m_name)
91 | json_str = fd.readline()
92 | params.parse_json(json_str)
93 |
94 | return params
95 |
96 |
97 | def override_parameters(params, args):
98 | if args.parameters:
99 | params.parse(args.parameters)
100 |
101 | params.vocabulary = {
102 | "source": vocabulary.load_vocabulary(args.vocabulary[0]),
103 | "target": vocabulary.load_vocabulary(args.vocabulary[1])
104 | }
105 | params.vocabulary["source"] = vocabulary.process_vocabulary(
106 | params.vocabulary["source"], params
107 | )
108 | params.vocabulary["target"] = vocabulary.process_vocabulary(
109 | params.vocabulary["target"], params
110 | )
111 |
112 | control_symbols = [params.pad, params.bos, params.eos, params.unk]
113 |
114 | params.mapping = {
115 | "source": vocabulary.get_control_mapping(
116 | params.vocabulary["source"],
117 | control_symbols
118 | ),
119 | "target": vocabulary.get_control_mapping(
120 | params.vocabulary["target"],
121 | control_symbols
122 | )
123 | }
124 |
125 | return params
126 |
127 |
128 | def session_config(params):
129 | optimizer_options = tf.OptimizerOptions(opt_level=tf.OptimizerOptions.L1,
130 | do_function_inlining=False)
131 | graph_options = tf.GraphOptions(optimizer_options=optimizer_options)
132 | config = tf.ConfigProto(allow_soft_placement=True,
133 | graph_options=graph_options)
134 | if params.device_list:
135 | device_str = ",".join([str(i) for i in params.device_list])
136 | config.gpu_options.visible_device_list = device_str
137 |
138 | return config
139 |
140 |
141 | def set_variables(var_list, value_dict, prefix):
142 | ops = []
143 | for var in var_list:
144 | for name in value_dict:
145 | var_name = "/".join([prefix] + list(name.split("/")[1:]))
146 |
147 | if var.name[:-2] == var_name:
148 | tf.logging.debug("restoring %s -> %s" % (name, var.name))
149 | with tf.device("/cpu:0"):
150 | op = tf.assign(var, value_dict[name])
151 | ops.append(op)
152 | break
153 |
154 | return ops
155 |
156 |
157 | def read_files(names):
158 | inputs = [[] for _ in range(len(names))]
159 | files = [tf.gfile.GFile(name) for name in names]
160 |
161 | count = 0
162 |
163 | for lines in zip(*files):
164 | lines = [line.strip() for line in lines]
165 |
166 | for i, line in enumerate(lines):
167 | inputs[i].append(line)
168 |
169 | count += 1
170 |
171 | # Close files
172 | for fd in files:
173 | fd.close()
174 |
175 | return inputs
176 |
177 |
178 | def get_features(inputs, params):
179 | with tf.device("/cpu:0"):
180 | # Create datasets
181 | datasets = []
182 |
183 | for data in inputs:
184 | dataset = tf.data.Dataset.from_tensor_slices(data)
185 | # Split string
186 | dataset = dataset.map(lambda x: tf.string_split([x]).values,
187 | num_parallel_calls=params.num_threads)
188 | # Append
189 | dataset = dataset.map(
190 | lambda x: tf.concat([x, [tf.constant(params.eos)]], axis=0),
191 | num_parallel_calls=params.num_threads
192 | )
193 | datasets.append(dataset)
194 |
195 | dataset = tf.data.Dataset.zip(tuple(datasets))
196 |
197 | # Convert tuple to dictionary
198 | dataset = dataset.map(
199 | lambda *x: {
200 | "source": x[0],
201 | "source_length": tf.shape(x[0])[0],
202 | "target": x[1],
203 | "target_length": tf.shape(x[1])[0]
204 | },
205 | num_parallel_calls=params.num_threads
206 | )
207 |
208 | dataset = dataset.padded_batch(
209 | params.eval_batch_size,
210 | {
211 | "source": [tf.Dimension(None)],
212 | "source_length": [],
213 | "target": [tf.Dimension(None)],
214 | "target_length": []
215 | },
216 | {
217 | "source": params.pad,
218 | "source_length": 0,
219 | "target": params.pad,
220 | "target_length": 0
221 | }
222 | )
223 |
224 | iterator = dataset.make_one_shot_iterator()
225 | features = iterator.get_next()
226 |
227 | src_table = tf.contrib.lookup.index_table_from_tensor(
228 | tf.constant(params.vocabulary["source"]),
229 | default_value=params.mapping["source"][params.unk]
230 | )
231 | tgt_table = tf.contrib.lookup.index_table_from_tensor(
232 | tf.constant(params.vocabulary["target"]),
233 | default_value=params.mapping["target"][params.unk]
234 | )
235 | features["source"] = src_table.lookup(features["source"])
236 | features["target"] = tgt_table.lookup(features["target"])
237 |
238 | return features
239 |
240 |
241 | def main(args):
242 | tf.logging.set_verbosity(tf.logging.INFO)
243 | model_cls = models.get_model(args.model)
244 | params = default_parameters()
245 |
246 | # Import and override parameters
247 | # Priorities (low -> high):
248 | # default -> saved -> command
249 | params = merge_parameters(params, model_cls.get_parameters())
250 | params = import_params(args.checkpoint, args.model, params)
251 | override_parameters(params, args)
252 |
253 | # Build Graph
254 | with tf.Graph().as_default():
255 | model = model_cls(params)
256 | inputs = read_files(args.input)
257 | features = get_features(inputs, params)
258 | score_fn = model.get_evaluation_func()
259 | scores = score_fn(features, params)
260 |
261 | sess_creator = tf.train.ChiefSessionCreator(
262 | config=session_config(params)
263 | )
264 |
265 | # Load checkpoint
266 | tf.logging.info("Loading %s" % args.checkpoint)
267 | var_list = tf.train.list_variables(args.checkpoint)
268 | values = {}
269 | reader = tf.train.load_checkpoint(args.checkpoint)
270 |
271 | for (name, shape) in var_list:
272 | if not name.startswith(model_cls.get_name()):
273 | continue
274 |
275 | tensor = reader.get_tensor(name)
276 | values[name] = tensor
277 |
278 | ops = set_variables(tf.trainable_variables(), values,
279 | model_cls.get_name())
280 | assign_op = tf.group(*ops)
281 |
282 | # Create session
283 | with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
284 | # Restore variables
285 | sess.run(assign_op)
286 | fd = tf.gfile.Open(args.output, "w")
287 |
288 | while not sess.should_stop():
289 | results = sess.run(scores)
290 | for value in results:
291 | fd.write("%f\n" % value)
292 |
293 | fd.close()
294 |
295 |
296 | if __name__ == "__main__":
297 | main(parse_args())
298 |
--------------------------------------------------------------------------------
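scorer.py resolves its configuration through merge_parameters, import_params, and override_parameters with the priority chain noted in its main() (default -> saved JSON -> command-line --parameters); the other bin/ scripts below do the same. A dict-based sketch of that behavior, with illustrative values only (merge is a hypothetical helper, not a THUMT function):

    def merge(base, override):
        """Later settings win; unknown keys are added, known keys replaced."""
        merged = dict(base)
        merged.update(override)
        return merged

    defaults = {"eval_batch_size": 32, "num_threads": 6}
    saved = {"eval_batch_size": 64}   # from <model>.json next to the checkpoint
    command = {"num_threads": 2}      # parsed from --parameters
    params = merge(merge(defaults, saved), command)
    assert params == {"eval_batch_size": 64, "num_threads": 2}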
/thumt/bin/trainer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2018 The THUMT Authors
4 |
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import argparse
10 | import os
11 | import six
12 |
13 | import numpy as np
14 | import tensorflow as tf
15 | import thumt.data.cache as cache
16 | import thumt.data.dataset as dataset
17 | import thumt.data.record as record
18 | import thumt.data.vocab as vocabulary
19 | import thumt.models as models
20 | import thumt.utils.hooks as hooks
21 | import thumt.utils.inference as inference
22 | import thumt.utils.optimize as optimize
23 | import thumt.utils.parallel as parallel
24 | import thumt.utils.utils as utils
25 |
26 |
27 | def parse_args(args=None):
28 | parser = argparse.ArgumentParser(
29 | description="Training neural machine translation models",
30 |         usage="trainer.py [<args>] [-h | --help]"
31 | )
32 |
33 | # input files
34 | parser.add_argument("--input", type=str, nargs=2,
35 | help="Path of source and target corpus")
36 | parser.add_argument("--record", type=str,
37 | help="Path to tf.Record data")
38 | parser.add_argument("--output", type=str, default="train",
39 | help="Path to saved models")
40 | parser.add_argument("--vocabulary", type=str, nargs=2,
41 | help="Path of source and target vocabulary")
42 | parser.add_argument("--validation", type=str,
43 | help="Path of validation file")
44 | parser.add_argument("--references", type=str, nargs="+",
45 | help="Path of reference files")
46 |
47 | # model and configuration
48 | parser.add_argument("--model", type=str, required=True,
49 | help="Name of the model")
50 | parser.add_argument("--parameters", type=str, default="",
51 | help="Additional hyper parameters")
52 |
53 | return parser.parse_args(args)
54 |
55 |
56 | def default_parameters():
57 | params = tf.contrib.training.HParams(
58 | input=["", ""],
59 | output="",
60 | record="",
61 | model="transformer",
62 | vocab=["", ""],
63 | # Default training hyper parameters
64 | num_threads=6,
65 | batch_size=4096,
66 | max_length=256,
67 | length_multiplier=1,
68 | mantissa_bits=2,
69 | warmup_steps=4000,
70 | train_steps=100000,
71 | buffer_size=10000,
72 | constant_batch_size=False,
73 | device_list=[0],
74 | update_cycle=1,
75 | initializer="uniform_unit_scaling",
76 | initializer_gain=1.0,
77 | optimizer="Adam",
78 | adam_beta1=0.9,
79 | adam_beta2=0.999,
80 | adam_epsilon=1e-8,
81 | clip_grad_norm=5.0,
82 | learning_rate=1.0,
83 | learning_rate_decay="linear_warmup_rsqrt_decay",
84 | learning_rate_boundaries=[0],
85 | learning_rate_values=[0.0],
86 | keep_checkpoint_max=20,
87 | keep_top_checkpoint_max=5,
88 | # Validation
89 | eval_steps=2000,
90 | eval_secs=0,
91 | eval_batch_size=32,
92 | top_beams=1,
93 | beam_size=4,
94 | decode_alpha=0.6,
95 | decode_length=50,
96 | validation="",
97 | references=[""],
98 | save_checkpoint_secs=0,
99 | save_checkpoint_steps=1000,
100 | # Setting this to True can save disk spaces, but cannot restore
101 | # training using the saved checkpoint
102 | only_save_trainable=False
103 | )
104 |
105 | return params
106 |
107 |
108 | def import_params(model_dir, model_name, params):
109 | model_dir = os.path.abspath(model_dir)
110 | p_name = os.path.join(model_dir, "params.json")
111 | m_name = os.path.join(model_dir, model_name + ".json")
112 |
113 | if not tf.gfile.Exists(p_name) or not tf.gfile.Exists(m_name):
114 | return params
115 |
116 | with tf.gfile.Open(p_name) as fd:
117 | tf.logging.info("Restoring hyper parameters from %s" % p_name)
118 | json_str = fd.readline()
119 | params.parse_json(json_str)
120 |
121 | with tf.gfile.Open(m_name) as fd:
122 | tf.logging.info("Restoring model parameters from %s" % m_name)
123 | json_str = fd.readline()
124 | params.parse_json(json_str)
125 |
126 | return params
127 |
128 |
129 | def export_params(output_dir, name, params):
130 | if not tf.gfile.Exists(output_dir):
131 | tf.gfile.MkDir(output_dir)
132 |
133 | # Save params as params.json
134 | filename = os.path.join(output_dir, name)
135 | with tf.gfile.Open(filename, "w") as fd:
136 | fd.write(params.to_json())
137 |
138 |
139 | def collect_params(all_params, params):
140 | collected = tf.contrib.training.HParams()
141 |
142 |     for k in params.values().keys():
143 | collected.add_hparam(k, getattr(all_params, k))
144 |
145 | return collected
146 |
147 |
148 | def merge_parameters(params1, params2):
149 | params = tf.contrib.training.HParams()
150 |
151 |     for (k, v) in params1.values().items():
152 | params.add_hparam(k, v)
153 |
154 | params_dict = params.values()
155 |
156 |     for (k, v) in params2.values().items():
157 | if k in params_dict:
158 | # Override
159 | setattr(params, k, v)
160 | else:
161 | params.add_hparam(k, v)
162 |
163 | return params
164 |
165 |
166 | def override_parameters(params, args):
167 | params.model = args.model
168 | params.input = args.input or params.input
169 | params.output = args.output or params.output
170 | params.record = args.record or params.record
171 | params.vocab = args.vocabulary or params.vocab
172 | params.validation = args.validation or params.validation
173 | params.references = args.references or params.references
174 | params.parse(args.parameters)
175 |
176 | params.vocabulary = {
177 | "source": vocabulary.load_vocabulary(params.vocab[0]),
178 | "target": vocabulary.load_vocabulary(params.vocab[1])
179 | }
180 | params.vocabulary["source"] = vocabulary.process_vocabulary(
181 | params.vocabulary["source"], params
182 | )
183 | params.vocabulary["target"] = vocabulary.process_vocabulary(
184 | params.vocabulary["target"], params
185 | )
186 |
187 | control_symbols = [params.pad, params.bos, params.eos, params.unk]
188 |
189 | params.mapping = {
190 | "source": vocabulary.get_control_mapping(
191 | params.vocabulary["source"],
192 | control_symbols
193 | ),
194 | "target": vocabulary.get_control_mapping(
195 | params.vocabulary["target"],
196 | control_symbols
197 | )
198 | }
199 |
200 | return params
201 |
202 |
203 | def get_initializer(params):
204 | if params.initializer == "uniform":
205 | max_val = params.initializer_gain
206 | return tf.random_uniform_initializer(-max_val, max_val)
207 | elif params.initializer == "normal":
208 | return tf.random_normal_initializer(0.0, params.initializer_gain)
209 | elif params.initializer == "normal_unit_scaling":
210 | return tf.variance_scaling_initializer(params.initializer_gain,
211 | mode="fan_avg",
212 | distribution="normal")
213 | elif params.initializer == "uniform_unit_scaling":
214 | return tf.variance_scaling_initializer(params.initializer_gain,
215 | mode="fan_avg",
216 | distribution="uniform")
217 | else:
218 | raise ValueError("Unrecognized initializer: %s" % params.initializer)
219 |
220 |
221 | def get_learning_rate_decay(learning_rate, global_step, params):
222 | if params.learning_rate_decay in ["linear_warmup_rsqrt_decay", "noam"]:
223 | step = tf.to_float(global_step)
224 | warmup_steps = tf.to_float(params.warmup_steps)
225 | multiplier = params.hidden_size ** -0.5
226 | decay = multiplier * tf.minimum((step + 1) * (warmup_steps ** -1.5),
227 | (step + 1) ** -0.5)
228 |
229 | return learning_rate * decay
230 | elif params.learning_rate_decay == "piecewise_constant":
231 | return tf.train.piecewise_constant(tf.to_int32(global_step),
232 | params.learning_rate_boundaries,
233 | params.learning_rate_values)
234 | elif params.learning_rate_decay == "none":
235 | return learning_rate
236 | else:
237 | raise ValueError("Unknown learning_rate_decay")
238 |
239 |
240 | def session_config(params):
241 | optimizer_options = tf.OptimizerOptions(opt_level=tf.OptimizerOptions.L1,
242 | do_function_inlining=True)
243 | graph_options = tf.GraphOptions(optimizer_options=optimizer_options)
244 | config = tf.ConfigProto(allow_soft_placement=True,
245 | graph_options=graph_options)
246 | if params.device_list:
247 | device_str = ",".join([str(i) for i in params.device_list])
248 | config.gpu_options.visible_device_list = device_str
249 |
250 | return config
251 |
252 |
253 | def decode_target_ids(inputs, params):
254 | decoded = []
255 | vocab = params.vocabulary["target"]
256 |
257 | for item in inputs:
258 | syms = []
259 | for idx in item:
260 | if isinstance(idx, six.integer_types):
261 | sym = vocab[idx]
262 | else:
263 | sym = idx
264 |
265 | if sym == params.eos:
266 | break
267 |
268 | if sym == params.pad:
269 | break
270 |
271 | syms.append(sym)
272 | decoded.append(syms)
273 |
274 | return decoded
275 |
276 |
277 | def main(args):
278 | tf.logging.set_verbosity(tf.logging.INFO)
279 | model_cls = models.get_model(args.model)
280 | params = default_parameters()
281 |
282 | # Import and override parameters
283 | # Priorities (low -> high):
284 | # default -> saved -> command
285 | params = merge_parameters(params, model_cls.get_parameters())
286 | params = import_params(args.output, args.model, params)
287 | override_parameters(params, args)
288 |
289 | # Export all parameters and model specific parameters
290 | export_params(params.output, "params.json", params)
291 | export_params(
292 | params.output,
293 | "%s.json" % args.model,
294 | collect_params(params, model_cls.get_parameters())
295 | )
296 |
297 | # Build Graph
298 | with tf.Graph().as_default():
299 | if not params.record:
300 | # Build input queue
301 | features = dataset.get_training_input(params.input, params)
302 | else:
303 | features = record.get_input_features(
304 | os.path.join(params.record, "*train*"), "train", params
305 | )
306 |
307 | features, init_op = cache.cache_features(features,
308 | params.update_cycle)
309 |
310 | # Build model
311 | initializer = get_initializer(params)
312 | model = model_cls(params)
313 |
314 | # Multi-GPU setting
315 | sharded_losses = parallel.parallel_model(
316 | model.get_training_func(initializer),
317 | features,
318 | params.device_list
319 | )
320 | loss = tf.add_n(sharded_losses) / len(sharded_losses)
321 |
322 | # Create global step
323 | global_step = tf.train.get_or_create_global_step()
324 |
325 | # Print parameters
326 | all_weights = {v.name: v for v in tf.trainable_variables()}
327 | total_size = 0
328 |
329 | for v_name in sorted(list(all_weights)):
330 | v = all_weights[v_name]
331 | tf.logging.info("%s\tshape %s", v.name[:-2].ljust(80),
332 | str(v.shape).ljust(20))
333 | v_size = np.prod(np.array(v.shape.as_list())).tolist()
334 | total_size += v_size
335 | tf.logging.info("Total trainable variables size: %d", total_size)
336 |
337 | learning_rate = get_learning_rate_decay(params.learning_rate,
338 | global_step, params)
339 | learning_rate = tf.convert_to_tensor(learning_rate, dtype=tf.float32)
340 | tf.summary.scalar("learning_rate", learning_rate)
341 |
342 | # Create optimizer
343 | if params.optimizer == "Adam":
344 | opt = tf.train.AdamOptimizer(learning_rate,
345 | beta1=params.adam_beta1,
346 | beta2=params.adam_beta2,
347 | epsilon=params.adam_epsilon)
348 | elif params.optimizer == "LazyAdam":
349 | opt = tf.contrib.opt.LazyAdamOptimizer(learning_rate,
350 | beta1=params.adam_beta1,
351 | beta2=params.adam_beta2,
352 | epsilon=params.adam_epsilon)
353 | else:
354 | raise RuntimeError("Optimizer %s not supported" % params.optimizer)
355 |
356 | loss, ops = optimize.create_train_op(loss, opt, global_step, params)
357 |
358 | # Validation
359 | if params.validation and params.references[0]:
360 | files = [params.validation] + list(params.references)
361 | eval_inputs = dataset.sort_and_zip_files(files)
362 | eval_input_fn = dataset.get_evaluation_input
363 | else:
364 | eval_input_fn = None
365 |
366 | # Add hooks
367 | save_vars = tf.trainable_variables() + [global_step]
368 | saver = tf.train.Saver(
369 | var_list=save_vars if params.only_save_trainable else None,
370 | max_to_keep=params.keep_checkpoint_max,
371 | sharded=False
372 | )
373 | tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
374 |
375 | train_hooks = [
376 | tf.train.StopAtStepHook(last_step=params.train_steps),
377 | tf.train.NanTensorHook(loss),
378 | tf.train.LoggingTensorHook(
379 | {
380 | "step": global_step,
381 | "loss": loss,
382 | },
383 | every_n_iter=1
384 | ),
385 | tf.train.CheckpointSaverHook(
386 | checkpoint_dir=params.output,
387 | save_secs=params.save_checkpoint_secs or None,
388 | save_steps=params.save_checkpoint_steps or None,
389 | saver=saver
390 | )
391 | ]
392 |
393 | config = session_config(params)
394 |
395 | if eval_input_fn is not None:
396 | train_hooks.append(
397 | hooks.EvaluationHook(
398 | lambda f: inference.create_inference_graph(
399 | [model.get_inference_func()], f, params
400 | ),
401 | lambda: eval_input_fn(eval_inputs, params),
402 | lambda x: decode_target_ids(x, params),
403 | params.output,
404 | config,
405 | params.keep_top_checkpoint_max,
406 | eval_secs=params.eval_secs,
407 | eval_steps=params.eval_steps
408 | )
409 | )
410 |
411 | # Create session, do not use default CheckpointSaverHook
412 | with tf.train.MonitoredTrainingSession(
413 | checkpoint_dir=params.output, hooks=train_hooks,
414 | save_checkpoint_secs=None, config=config) as sess:
415 | while not sess.should_stop():
416 | # Bypass hook calls
417 | utils.session_run(sess, [init_op, ops["zero_op"]])
418 | for i in range(params.update_cycle):
419 | utils.session_run(sess, ops["collect_op"])
420 | utils.session_run(sess, ops["scale_op"])
421 | sess.run(ops["train_op"])
422 |
423 |
424 | if __name__ == "__main__":
425 | main(parse_args())
426 |
--------------------------------------------------------------------------------
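For reference, the "linear_warmup_rsqrt_decay" (noam) branch of get_learning_rate_decay above computes lr * hidden_size**-0.5 * min((step + 1) * warmup**-1.5, (step + 1)**-0.5). A self-contained sketch; warmup_steps=4000 is the trainer default above, while hidden_size=512 is an assumed model setting:

    def noam_schedule(step, hidden_size=512, warmup_steps=4000,
                      learning_rate=1.0):
        """Linear warmup for warmup_steps, then inverse-sqrt decay."""
        multiplier = hidden_size ** -0.5
        decay = multiplier * min((step + 1) * warmup_steps ** -1.5,
                                 (step + 1) ** -0.5)
        return learning_rate * decay

    # Rises roughly linearly until step 4000, then decays as 1/sqrt(step):
    assert noam_schedule(0) < noam_schedule(3999) > noam_schedule(100000)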
/thumt/bin/translator.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2018 The THUMT Authors
4 |
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import argparse
10 | import itertools
11 | import os
12 |
13 | import numpy as np
14 | import tensorflow as tf
15 | import thumt.data.dataset as dataset
16 | import thumt.data.vocab as vocabulary
17 | import thumt.models as models
18 | import thumt.utils.inference as inference
19 | import thumt.utils.parallel as parallel
20 |
21 |
22 | def parse_args():
23 | parser = argparse.ArgumentParser(
24 | description="Translate using existing NMT models",
25 |         usage="translator.py [<args>] [-h | --help]"
26 | )
27 |
28 | # input files
29 | parser.add_argument("--input", type=str, required=True,
30 | help="Path of input file")
31 | parser.add_argument("--output", type=str, required=True,
32 | help="Path of output file")
33 | parser.add_argument("--checkpoints", type=str, nargs="+", required=True,
34 | help="Path of trained models")
35 | parser.add_argument("--vocabulary", type=str, nargs=2, required=True,
36 | help="Path of source and target vocabulary")
37 |
38 | # model and configuration
39 | parser.add_argument("--models", type=str, required=True, nargs="+",
40 | help="Name of the model")
41 | parser.add_argument("--parameters", type=str,
42 | help="Additional hyper parameters")
43 | parser.add_argument("--verbose", action="store_true",
44 | help="Enable verbose output")
45 |
46 | return parser.parse_args()
47 |
48 |
49 | def default_parameters():
50 | params = tf.contrib.training.HParams(
51 | input=None,
52 | output=None,
53 | vocabulary=None,
54 | # vocabulary specific
55 |         pad="<pad>",
56 |         bos="<bos>",
57 |         eos="<eos>",
58 |         unk="<unk>",
59 | mapping=None,
60 | append_eos=False,
61 | # decoding
62 | top_beams=1,
63 | beam_size=4,
64 | decode_alpha=0.6,
65 | decode_length=50,
66 | decode_batch_size=32,
67 | device_list=[0],
68 | num_threads=1
69 | )
70 |
71 | return params
72 |
73 |
74 | def merge_parameters(params1, params2):
75 | params = tf.contrib.training.HParams()
76 |
77 |     for (k, v) in params1.values().items():
78 | params.add_hparam(k, v)
79 |
80 | params_dict = params.values()
81 |
82 |     for (k, v) in params2.values().items():
83 | if k in params_dict:
84 | # Override
85 | setattr(params, k, v)
86 | else:
87 | params.add_hparam(k, v)
88 |
89 | return params
90 |
91 |
92 | def import_params(model_dir, model_name, params):
93 | if model_name.startswith("experimental_"):
94 | model_name = model_name[13:]
95 |
96 | model_dir = os.path.abspath(model_dir)
97 | m_name = os.path.join(model_dir, model_name + ".json")
98 |
99 | if not tf.gfile.Exists(m_name):
100 | return params
101 |
102 | with tf.gfile.Open(m_name) as fd:
103 | tf.logging.info("Restoring model parameters from %s" % m_name)
104 | json_str = fd.readline()
105 | params.parse_json(json_str)
106 |
107 | return params
108 |
109 |
110 | def override_parameters(params, args):
111 | if args.parameters:
112 | params.parse(args.parameters)
113 |
114 | params.vocabulary = {
115 | "source": vocabulary.load_vocabulary(args.vocabulary[0]),
116 | "target": vocabulary.load_vocabulary(args.vocabulary[1])
117 | }
118 | params.vocabulary["source"] = vocabulary.process_vocabulary(
119 | params.vocabulary["source"], params
120 | )
121 | params.vocabulary["target"] = vocabulary.process_vocabulary(
122 | params.vocabulary["target"], params
123 | )
124 |
125 | control_symbols = [params.pad, params.bos, params.eos, params.unk]
126 |
127 | params.mapping = {
128 | "source": vocabulary.get_control_mapping(
129 | params.vocabulary["source"],
130 | control_symbols
131 | ),
132 | "target": vocabulary.get_control_mapping(
133 | params.vocabulary["target"],
134 | control_symbols
135 | )
136 | }
137 |
138 | return params
139 |
140 |
141 | def session_config(params):
142 | optimizer_options = tf.OptimizerOptions(opt_level=tf.OptimizerOptions.L1,
143 | do_function_inlining=False)
144 | graph_options = tf.GraphOptions(optimizer_options=optimizer_options)
145 | config = tf.ConfigProto(allow_soft_placement=True,
146 | graph_options=graph_options)
147 | if params.device_list:
148 | device_str = ",".join([str(i) for i in params.device_list])
149 | config.gpu_options.visible_device_list = device_str
150 |
151 | return config
152 |
153 |
154 | def set_variables(var_list, value_dict, prefix):
155 | ops = []
156 | for var in var_list:
157 | for name in value_dict:
158 | var_name = "/".join([prefix] + list(name.split("/")[1:]))
159 |
160 | if var.name[:-2] == var_name:
161 | tf.logging.debug("restoring %s -> %s" % (name, var.name))
162 | with tf.device("/cpu:0"):
163 | op = tf.assign(var, value_dict[name])
164 | ops.append(op)
165 | break
166 |
167 | return ops
168 |
169 |
170 | def shard_features(features, placeholders, predictions):
171 | num_shards = len(placeholders)
172 | feed_dict = {}
173 | n = 0
174 |
175 | for name in features:
176 | feat = features[name]
177 | batch = feat.shape[0]
178 |
179 | if batch < num_shards:
180 | feed_dict[placeholders[0][name]] = feat
181 | n = 1
182 | else:
183 | shard_size = (batch + num_shards - 1) // num_shards
184 |
185 | for i in range(num_shards):
186 | shard_feat = feat[i * shard_size:(i + 1) * shard_size]
187 | feed_dict[placeholders[i][name]] = shard_feat
188 | n = num_shards
189 |
190 | return predictions[:n], feed_dict
191 |
192 |
193 | def main(args):
194 | tf.logging.set_verbosity(tf.logging.INFO)
195 | # Load configs
196 | model_cls_list = [models.get_model(model) for model in args.models]
197 | params_list = [default_parameters() for _ in range(len(model_cls_list))]
198 | params_list = [
199 | merge_parameters(params, model_cls.get_parameters())
200 | for params, model_cls in zip(params_list, model_cls_list)
201 | ]
202 | params_list = [
203 | import_params(args.checkpoints[i], args.models[i], params_list[i])
204 | for i in range(len(args.checkpoints))
205 | ]
206 | params_list = [
207 | override_parameters(params_list[i], args)
208 | for i in range(len(model_cls_list))
209 | ]
210 |
211 | # Build Graph
212 | with tf.Graph().as_default():
213 | model_var_lists = []
214 |
215 | # Load checkpoints
216 | for i, checkpoint in enumerate(args.checkpoints):
217 | tf.logging.info("Loading %s" % checkpoint)
218 | var_list = tf.train.list_variables(checkpoint)
219 | values = {}
220 | reader = tf.train.load_checkpoint(checkpoint)
221 |
222 | for (name, shape) in var_list:
223 | if not name.startswith(model_cls_list[i].get_name()):
224 | continue
225 |
226 | if name.find("losses_avg") >= 0:
227 | continue
228 |
229 | tensor = reader.get_tensor(name)
230 | values[name] = tensor
231 |
232 | model_var_lists.append(values)
233 |
234 | # Build models
235 | model_fns = []
236 |
237 | for i in range(len(args.checkpoints)):
238 | name = model_cls_list[i].get_name()
239 | model = model_cls_list[i](params_list[i], name + "_%d" % i)
240 | model_fn = model.get_inference_func()
241 | model_fns.append(model_fn)
242 |
243 | params = params_list[0]
244 | # Read input file
245 | sorted_keys, sorted_inputs = dataset.sort_input_file(args.input)
246 | # Build input queue
247 | features = dataset.get_inference_input(sorted_inputs, params)
248 | # Create placeholders
249 | placeholders = []
250 |
251 | for i in range(len(params.device_list)):
252 | placeholders.append({
253 | "source": tf.placeholder(tf.int32, [None, None],
254 | "source_%d" % i),
255 | "source_length": tf.placeholder(tf.int32, [None],
256 | "source_length_%d" % i)
257 | })
258 |
259 | # A list of outputs
260 | predictions = parallel.data_parallelism(
261 | params.device_list,
262 | lambda f: inference.create_inference_graph(model_fns, f, params),
263 | placeholders)
264 |
265 | # Create assign ops
266 | assign_ops = []
267 |
268 | all_var_list = tf.trainable_variables()
269 |
270 | for i in range(len(args.checkpoints)):
271 | un_init_var_list = []
272 | name = model_cls_list[i].get_name()
273 |
274 | for v in all_var_list:
275 | if v.name.startswith(name + "_%d" % i):
276 | un_init_var_list.append(v)
277 |
278 | ops = set_variables(un_init_var_list, model_var_lists[i],
279 | name + "_%d" % i)
280 | assign_ops.extend(ops)
281 |
282 | assign_op = tf.group(*assign_ops)
283 | results = []
284 |
285 | # Create session
286 | with tf.Session(config=session_config(params)) as sess:
287 | # Restore variables
288 | sess.run(assign_op)
289 | sess.run(tf.tables_initializer())
290 |
291 | while True:
292 | try:
293 | feats = sess.run(features)
294 | op, feed_dict = shard_features(feats, placeholders,
295 | predictions)
296 | results.append(sess.run(predictions, feed_dict=feed_dict))
297 | message = "Finished batch %d" % len(results)
298 | tf.logging.log(tf.logging.INFO, message)
299 | except tf.errors.OutOfRangeError:
300 | break
301 |
302 | # Convert to plain text
303 | vocab = params.vocabulary["target"]
304 | outputs = []
305 | scores = []
306 |
307 | for result in results:
308 | for item in result[0]:
309 | outputs.append(item.tolist())
310 | for item in result[1]:
311 | scores.append(item.tolist())
312 |
313 | outputs = list(itertools.chain(*outputs))
314 | scores = list(itertools.chain(*scores))
315 |
316 | restored_inputs = []
317 | restored_outputs = []
318 | restored_scores = []
319 |
320 | for index in range(len(sorted_inputs)):
321 | restored_inputs.append(sorted_inputs[sorted_keys[index]])
322 | restored_outputs.append(outputs[sorted_keys[index]])
323 | restored_scores.append(scores[sorted_keys[index]])
324 |
325 | # Write to file
326 | with open(args.output, "w") as outfile:
327 | count = 0
328 | for outputs, scores in zip(restored_outputs, restored_scores):
329 | for output, score in zip(outputs, scores):
330 | decoded = []
331 | for idx in output:
332 | if idx == params.mapping["target"][params.eos]:
333 | break
334 | decoded.append(vocab[idx])
335 |
336 | decoded = " ".join(decoded)
337 |
338 | if not args.verbose:
339 | outfile.write("%s\n" % decoded)
340 | break
341 | else:
342 | pattern = "%d ||| %s ||| %s ||| %f\n"
343 | source = restored_inputs[count]
344 | values = (count, source, decoded, score)
345 | outfile.write(pattern % values)
346 |
347 | count += 1
348 |
349 |
350 | if __name__ == "__main__":
351 | main(parse_args())
352 |
--------------------------------------------------------------------------------
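translator.py sorts the input by length before batching (dataset.sort_input_file, not shown in this dump) and uses sorted_keys to restore the original order when writing results. A pure-Python sketch of that bookkeeping; sorting longest-first is an assumption here:

    def sort_input(lines):
        """Sort lines by token count and return the inverse permutation."""
        order = sorted(range(len(lines)),
                       key=lambda i: len(lines[i].split()), reverse=True)
        sorted_lines = [lines[i] for i in order]
        inverse = [0] * len(lines)
        for position, original_index in enumerate(order):
            inverse[original_index] = position
        return inverse, sorted_lines

    keys, sorted_lines = sort_input(["a b c", "a", "a b"])
    restored = [sorted_lines[keys[i]] for i in range(len(sorted_lines))]
    assert restored == ["a b c", "a", "a b"]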
/thumt/bin/translator_ctx.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2018 The THUMT Authors
4 |
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import argparse
10 | import itertools
11 | import os
12 |
13 | import numpy as np
14 | import tensorflow as tf
15 | import thumt.data.dataset as dataset
16 | import thumt.data.vocab as vocabulary
17 | import thumt.models as models
18 | import thumt.utils.inference_ctx as inference
19 | import thumt.utils.parallel as parallel
20 |
21 |
22 | def parse_args():
23 | parser = argparse.ArgumentParser(
24 | description="Translate using existing NMT models",
25 |         usage="translator_ctx.py [<args>] [-h | --help]"
26 | )
27 |
28 | # input files
29 | parser.add_argument("--input", type=str, required=True,
30 | help="Path of input file")
31 | parser.add_argument("--context", type=str, required=True,
32 | help="Path of context file")
33 | parser.add_argument("--output", type=str, required=True,
34 | help="Path of output file")
35 | parser.add_argument("--checkpoints", type=str, nargs="+", required=True,
36 | help="Path of trained models")
37 | parser.add_argument("--vocabulary", type=str, nargs=2, required=True,
38 | help="Path of source and target vocabulary")
39 |
40 | # model and configuration
41 | parser.add_argument("--models", type=str, required=True, nargs="+",
42 | help="Name of the model")
43 | parser.add_argument("--parameters", type=str,
44 | help="Additional hyper parameters")
45 | parser.add_argument("--verbose", action="store_true",
46 | help="Enable verbose output")
47 |
48 | return parser.parse_args()
49 |
50 |
51 | def default_parameters():
52 | params = tf.contrib.training.HParams(
53 | input=None,
54 | output=None,
55 | vocabulary=None,
56 | # vocabulary specific
57 |         pad="<pad>",
58 |         bos="<bos>",
59 |         eos="<eos>",
60 |         unk="<unk>",
61 | mapping=None,
62 | append_eos=False,
63 | # decoding
64 | top_beams=1,
65 | beam_size=4,
66 | decode_alpha=0.6,
67 | decode_length=50,
68 | decode_batch_size=32,
69 | device_list=[0],
70 | num_threads=1
71 | )
72 |
73 | return params
74 |
75 |
76 | def merge_parameters(params1, params2):
77 | params = tf.contrib.training.HParams()
78 |
79 |     for (k, v) in params1.values().items():
80 | params.add_hparam(k, v)
81 |
82 | params_dict = params.values()
83 |
84 |     for (k, v) in params2.values().items():
85 | if k in params_dict:
86 | # Override
87 | setattr(params, k, v)
88 | else:
89 | params.add_hparam(k, v)
90 |
91 | return params
92 |
93 |
94 | def import_params(model_dir, model_name, params):
95 | if model_name.startswith("experimental_"):
96 | model_name = model_name[13:]
97 |
98 | model_dir = os.path.abspath(model_dir)
99 | m_name = os.path.join(model_dir, model_name + ".json")
100 |
101 | if not tf.gfile.Exists(m_name):
102 | return params
103 |
104 | with tf.gfile.Open(m_name) as fd:
105 | tf.logging.info("Restoring model parameters from %s" % m_name)
106 | json_str = fd.readline()
107 | params.parse_json(json_str)
108 |
109 | return params
110 |
111 |
112 | def override_parameters(params, args):
113 | if args.parameters:
114 | params.parse(args.parameters)
115 |
116 | params.vocabulary = {
117 | "source": vocabulary.load_vocabulary(args.vocabulary[0]),
118 | "target": vocabulary.load_vocabulary(args.vocabulary[1])
119 | }
120 | params.vocabulary["source"] = vocabulary.process_vocabulary(
121 | params.vocabulary["source"], params
122 | )
123 | params.vocabulary["target"] = vocabulary.process_vocabulary(
124 | params.vocabulary["target"], params
125 | )
126 |
127 | control_symbols = [params.pad, params.bos, params.eos, params.unk]
128 |
129 | params.mapping = {
130 | "source": vocabulary.get_control_mapping(
131 | params.vocabulary["source"],
132 | control_symbols
133 | ),
134 | "target": vocabulary.get_control_mapping(
135 | params.vocabulary["target"],
136 | control_symbols
137 | )
138 | }
139 |
140 | return params
141 |
142 |
143 | def session_config(params):
144 | optimizer_options = tf.OptimizerOptions(opt_level=tf.OptimizerOptions.L1,
145 | do_function_inlining=False)
146 | graph_options = tf.GraphOptions(optimizer_options=optimizer_options)
147 | config = tf.ConfigProto(allow_soft_placement=True,
148 | graph_options=graph_options)
149 | if params.device_list:
150 | device_str = ",".join([str(i) for i in params.device_list])
151 | config.gpu_options.visible_device_list = device_str
152 |
153 | return config
154 |
155 |
156 | def set_variables(var_list, value_dict, prefix):
157 | ops = []
158 | for var in var_list:
159 | for name in value_dict:
160 | var_name = "/".join([prefix] + list(name.split("/")[1:]))
161 |
162 | if var.name[:-2] == var_name:
163 | tf.logging.debug("restoring %s -> %s" % (name, var.name))
164 | with tf.device("/cpu:0"):
165 | op = tf.assign(var, value_dict[name])
166 | ops.append(op)
167 | break
168 |
169 | return ops
170 |
171 |
172 | def shard_features(features, placeholders, predictions):
173 | num_shards = len(placeholders)
174 | feed_dict = {}
175 | n = 0
176 |
177 | for name in features:
178 | feat = features[name]
179 | batch = feat.shape[0]
180 |
181 | if batch < num_shards:
182 | feed_dict[placeholders[0][name]] = feat
183 | n = 1
184 | else:
185 | shard_size = (batch + num_shards - 1) // num_shards
186 |
187 | for i in range(num_shards):
188 | shard_feat = feat[i * shard_size:(i + 1) * shard_size]
189 | feed_dict[placeholders[i][name]] = shard_feat
190 | n = num_shards
191 |
192 | return predictions[:n], feed_dict
193 |
194 |
195 | def main(args):
196 | tf.logging.set_verbosity(tf.logging.INFO)
197 | # Load configs
198 | model_cls_list = [models.get_model(model) for model in args.models]
199 | params_list = [default_parameters() for _ in range(len(model_cls_list))]
200 | params_list = [
201 | merge_parameters(params, model_cls.get_parameters())
202 | for params, model_cls in zip(params_list, model_cls_list)
203 | ]
204 | params_list = [
205 | import_params(args.checkpoints[i], args.models[i], params_list[i])
206 | for i in range(len(args.checkpoints))
207 | ]
208 | params_list = [
209 | override_parameters(params_list[i], args)
210 | for i in range(len(model_cls_list))
211 | ]
212 |
213 | # Build Graph
214 | with tf.Graph().as_default():
215 | model_var_lists = []
216 |
217 | # Load checkpoints
218 | for i, checkpoint in enumerate(args.checkpoints):
219 | tf.logging.info("Loading %s" % checkpoint)
220 | var_list = tf.train.list_variables(checkpoint)
221 | values = {}
222 | reader = tf.train.load_checkpoint(checkpoint)
223 |
224 | for (name, shape) in var_list:
225 | if not name.startswith(model_cls_list[i].get_name()):
226 | continue
227 |
228 | if name.find("losses_avg") >= 0:
229 | continue
230 |
231 | tensor = reader.get_tensor(name)
232 | values[name] = tensor
233 |
234 | model_var_lists.append(values)
235 |
236 | # Build models
237 | model_fns = []
238 |
239 | for i in range(len(args.checkpoints)):
240 | name = model_cls_list[i].get_name()
241 | model = model_cls_list[i](params_list[i], name + "_%d" % i)
242 | model_fn = model.get_inference_func()
243 | model_fns.append(model_fn)
244 |
245 | params = params_list[0]
246 | # Read input file
247 | sorted_keys, sorted_inputs, sorted_ctxs = dataset.sort_input_file_ctx(args.input, args.context)
248 | # Build input queue
249 | features = dataset.get_inference_input(sorted_inputs, params)
250 | features_ctx = dataset.get_inference_input(sorted_ctxs, params)
251 | features["context"] = features_ctx["source"]
252 | features["context_length"] = features_ctx["source_length"]
253 | # Create placeholders
254 | placeholders = []
255 |
256 | for i in range(len(params.device_list)):
257 | placeholders.append({
258 | "source": tf.placeholder(tf.int32, [None, None],
259 | "source_%d" % i),
260 | "source_length": tf.placeholder(tf.int32, [None],
261 | "source_length_%d" % i),
262 | "context": tf.placeholder(tf.int32, [None, None],
263 | "context_%d" % i),
264 | "context_length": tf.placeholder(tf.int32, [None],
265 | "context_length_%d" % i)
266 | })
267 |
268 | # A list of outputs
269 | predictions = parallel.data_parallelism(
270 | params.device_list,
271 | lambda f: inference.create_inference_graph(model_fns, f, params),
272 | placeholders)
273 |
274 | # Create assign ops
275 | assign_ops = []
276 |
277 | all_var_list = tf.all_variables()
278 |
279 | for i in range(len(args.checkpoints)):
280 | un_init_var_list = []
281 | name = model_cls_list[i].get_name()
282 |
283 | for v in all_var_list:
284 | if v.name.startswith(name + "_%d" % i):
285 | un_init_var_list.append(v)
286 |
287 | ops = set_variables(un_init_var_list, model_var_lists[i],
288 | name + "_%d" % i)
289 | assign_ops.extend(ops)
290 |
291 | assign_op = tf.group(*assign_ops)
292 | results = []
293 |
294 | # Create session
295 | with tf.Session(config=session_config(params)) as sess:
296 | # Restore variables
297 | sess.run(assign_op)
298 | sess.run(tf.tables_initializer())
299 |
300 | while True:
301 | try:
302 | feats = sess.run(features)
303 | op, feed_dict = shard_features(feats, placeholders,
304 | predictions)
305 | results.append(sess.run(predictions, feed_dict=feed_dict))
306 | message = "Finished batch %d" % len(results)
307 | tf.logging.log(tf.logging.INFO, message)
308 | except tf.errors.OutOfRangeError:
309 | break
310 |
311 | # Convert to plain text
312 | vocab = params.vocabulary["target"]
313 | outputs = []
314 | scores = []
315 |
316 | for result in results:
317 | for item in result[0]:
318 | outputs.append(item.tolist())
319 | for item in result[1]:
320 | scores.append(item.tolist())
321 |
322 | outputs = list(itertools.chain(*outputs))
323 | scores = list(itertools.chain(*scores))
324 |
325 | restored_inputs = []
326 | restored_outputs = []
327 | restored_scores = []
328 |
329 | for index in range(len(sorted_inputs)):
330 | restored_inputs.append(sorted_inputs[sorted_keys[index]])
331 | restored_outputs.append(outputs[sorted_keys[index]])
332 | restored_scores.append(scores[sorted_keys[index]])
333 |
334 | # Write to file
335 | with open(args.output, "w") as outfile:
336 | count = 0
337 | for outputs, scores in zip(restored_outputs, restored_scores):
338 | for output, score in zip(outputs, scores):
339 | decoded = []
340 | for idx in output:
341 | if idx == params.mapping["target"][params.eos]:
342 | break
343 | decoded.append(vocab[idx])
344 |
345 | decoded = " ".join(decoded)
346 |
347 | if not args.verbose:
348 | outfile.write("%s\n" % decoded)
349 | break
350 | else:
351 | pattern = "%d ||| %s ||| %s ||| %f\n"
352 | source = restored_inputs[count]
353 | values = (count, source, decoded, score)
354 | outfile.write(pattern % values)
355 |
356 | count += 1
357 |
358 |
359 | if __name__ == "__main__":
360 | main(parse_args())
361 |
--------------------------------------------------------------------------------
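Both translators split each batch across devices with shard_features: ceil-sized contiguous slices, one per placeholder set, with batches smaller than the device count sent to device 0 only. A list-based sketch of that slicing:

    def shard(batch, num_shards):
        """Mirror shard_features' slicing on a plain list."""
        if len(batch) < num_shards:
            return [batch]  # everything goes to device 0
        size = (len(batch) + num_shards - 1) // num_shards
        return [batch[i * size:(i + 1) * size] for i in range(num_shards)]

    assert shard([1, 2, 3, 4, 5], 2) == [[1, 2, 3], [4, 5]]
    assert shard([1], 2) == [[1]]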
/thumt/data/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The THUMT Authors
3 |
--------------------------------------------------------------------------------
/thumt/data/cache.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import tensorflow as tf
9 |
10 |
11 | def cache_features(features, num_shards):
12 | if num_shards == 1:
13 | return features, tf.no_op(name="init_queue")
14 |
15 |     flat_features = list(features.values())
16 | queue = tf.FIFOQueue(num_shards, dtypes=[v.dtype for v in flat_features])
17 | flat_features = [tf.split(v, num_shards, axis=0) for v in flat_features]
18 | flat_features = list(zip(*flat_features))
19 | init_ops = [queue.enqueue(v, name="enqueue_%d" % i)
20 | for i, v in enumerate(flat_features)]
21 | flat_feature = queue.dequeue()
22 | new_features = {}
23 |
24 |     for k, v in zip(features.keys(), flat_feature):
25 | v.set_shape(features[k].shape)
26 | new_features[k] = v
27 |
28 | return new_features, tf.group(*init_ops)
29 |
--------------------------------------------------------------------------------
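
cache_features splits one large batch into num_shards smaller ones and hands them out through a FIFOQueue, one dequeue per shard. A minimal sketch of the intended call pattern (TF1 graph mode; the toy placeholder and feed values are ours):

    import tensorflow as tf
    from thumt.data.cache import cache_features

    source = tf.placeholder(tf.int32, shape=[None, 2], name="source")
    shard, init_op = cache_features({"source": source}, num_shards=2)

    with tf.Session() as sess:
        # enqueue both shards once, then dequeue them one at a time
        sess.run(init_op, {source: [[1, 2], [3, 4], [5, 6], [7, 8]]})
        print(sess.run(shard["source"]))  # [[1 2] [3 4]]
        print(sess.run(shard["source"]))  # [[5 6] [7 8]]
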
/thumt/data/record.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Code modified from Tensor2Tensor library
3 | # Copyright 2018 The THUMT Authors
4 |
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import math
10 |
11 | import numpy as np
12 | import six
13 | import tensorflow as tf
14 | from tensorflow.contrib.slim import parallel_reader, tfexample_decoder
15 |
16 |
17 | def input_pipeline(file_pattern, mode, capacity=64):
18 | keys_to_features = {
19 | "source": tf.VarLenFeature(tf.int64),
20 | "target": tf.VarLenFeature(tf.int64),
21 | "source_length": tf.FixedLenFeature([1], tf.int64),
22 | "target_length": tf.FixedLenFeature([1], tf.int64)
23 | }
24 |
25 | items_to_handlers = {
26 | "source": tfexample_decoder.Tensor("source"),
27 | "target": tfexample_decoder.Tensor("target"),
28 | "source_length": tfexample_decoder.Tensor("source_length"),
29 | "target_length": tfexample_decoder.Tensor("target_length")
30 | }
31 |
32 | # Now the non-trivial case construction.
33 | with tf.name_scope("examples_queue"):
34 | training = (mode == "train")
35 | # Read serialized examples using slim parallel_reader.
36 | num_epochs = None if training else 1
37 | data_files = parallel_reader.get_data_files(file_pattern)
38 | num_readers = min(4 if training else 1, len(data_files))
39 | _, examples = parallel_reader.parallel_read([file_pattern],
40 | tf.TFRecordReader,
41 | num_epochs=num_epochs,
42 | shuffle=training,
43 | capacity=2 * capacity,
44 | min_after_dequeue=capacity,
45 | num_readers=num_readers)
46 |
47 | decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
48 | items_to_handlers)
49 |
50 | decoded = decoder.decode(examples, items=list(items_to_handlers))
51 | examples = {}
52 |
53 | for (field, tensor) in zip(keys_to_features, decoded):
54 | examples[field] = tensor
55 |
56 | # We do not want int64s as they are not supported on GPUs.
57 | return {k: tf.to_int32(v) for (k, v) in six.iteritems(examples)}
58 |
59 |
60 | def batch_examples(examples, batch_size, max_length, mantissa_bits,
61 | shard_multiplier=1, length_multiplier=1, scheme="token",
62 | drop_long_sequences=True):
63 | with tf.name_scope("batch_examples"):
64 | max_length = max_length or batch_size
65 | min_length = 8
66 | # mantissa_bits controls how fine-grained the length buckets are
67 |
68 | # compute boundaries
69 | x = min_length
70 | boundaries = []
71 |
72 | while x < max_length:
73 | boundaries.append(x)
74 | x += 2 ** max(0, int(math.log(x, 2)) - mantissa_bits)
75 |
76 | if scheme == "token":
77 | batch_sizes = [max(1, batch_size // length)
78 | for length in boundaries + [max_length]]
79 | batch_sizes = [b * shard_multiplier for b in batch_sizes]
80 | bucket_capacities = [2 * b for b in batch_sizes]
81 | else:
82 | batch_sizes = batch_size * shard_multiplier
83 | bucket_capacities = [2 * n for n in boundaries + [max_length]]
84 |
85 | max_length *= length_multiplier
86 | boundaries = [boundary * length_multiplier for boundary in boundaries]
87 | max_length = max_length if drop_long_sequences else 10 ** 9
88 |
89 | # The queue to bucket on will be chosen based on maximum length.
90 | max_example_length = 0
91 | for v in examples.values():
92 | seq_length = tf.shape(v)[0]
93 | max_example_length = tf.maximum(max_example_length, seq_length)
94 |
95 | (_, outputs) = tf.contrib.training.bucket_by_sequence_length(
96 | max_example_length,
97 | examples,
98 | batch_sizes,
99 | [b + 1 for b in boundaries],
100 | capacity=2,
101 | bucket_capacities=bucket_capacities,
102 | dynamic_pad=True,
103 | keep_input=(max_example_length <= max_length)
104 | )
105 |
106 | return outputs
107 |
108 |
109 | def get_input_features(file_patterns, mode, params):
110 | with tf.name_scope("input_queues"):
111 | with tf.device("/cpu:0"):
112 | if mode != "train":
113 | num_datashards = 1
114 | batch_size = params.eval_batch_size
115 | else:
116 | num_datashards = len(params.device_list)
117 | batch_size = params.batch_size
118 |
119 | batch_size_multiplier = 1
120 | capacity = 64 * num_datashards
121 | examples = input_pipeline(file_patterns, mode, capacity)
122 | drop_long_sequences = (mode == "train")
123 |
124 | feature_map = batch_examples(
125 | examples,
126 | batch_size,
127 | params.max_length,
128 | params.mantissa_bits,
129 | num_datashards,
130 | batch_size_multiplier,
131 | "token" if not params.constant_batch_size else "constant",
132 | drop_long_sequences
133 | )
134 |
135 | features = {
136 | "source": feature_map["source"],
137 | "target": feature_map["target"],
138 | "source_length": tf.squeeze(feature_map["source_length"], axis=1),
139 | "target_length": tf.squeeze(feature_map["target_length"], axis=1)
140 | }
141 |
142 | return features
143 |
--------------------------------------------------------------------------------
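
For intuition on batch_examples: the bucket boundaries grow roughly geometrically, with mantissa_bits controlling how fine-grained they are, and under the "token" scheme each bucket's batch size shrinks with its length so that every batch carries about batch_size tokens. A pure-Python sketch of that arithmetic (parameter values are illustrative):

    import math

    def token_buckets(batch_size, max_length, mantissa_bits, min_length=8):
        x, boundaries = min_length, []
        while x < max_length:
            boundaries.append(x)
            x += 2 ** max(0, int(math.log(x, 2)) - mantissa_bits)
        batch_sizes = [max(1, batch_size // length)
                       for length in boundaries + [max_length]]
        return boundaries, batch_sizes

    bounds, sizes = token_buckets(batch_size=4096, max_length=256,
                                  mantissa_bits=2)
    print(bounds[:6])  # [8, 10, 12, 14, 16, 20]
    print(sizes[:6])   # [512, 409, 341, 292, 256, 204]
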
/thumt/data/vocab.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import tensorflow as tf
9 |
10 |
11 | def load_vocabulary(filename):
12 | vocab = []
13 | with tf.gfile.GFile(filename) as fd:
14 | for line in fd:
15 | word = line.strip()
16 | vocab.append(word)
17 |
18 | return vocab
19 |
20 |
21 | def process_vocabulary(vocab, params):
22 | if params.append_eos:
23 | vocab.append(params.eos)
24 |
25 | return vocab
26 |
27 |
28 | def get_control_mapping(vocab, symbols):
29 | mapping = {}
30 |
31 | for i, token in enumerate(vocab):
32 | for symbol in symbols:
33 | if symbol.decode("utf-8") == token.decode("utf-8"):
34 | mapping[symbol] = i
35 |
36 | return mapping
37 |
--------------------------------------------------------------------------------
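
Together these helpers load a one-token-per-line vocabulary and map each control symbol to its row index in the embedding table. A minimal sketch (Python 2, matching the decode calls above; the toy file is created on the spot):

    import tensorflow as tf
    from thumt.data.vocab import load_vocabulary, get_control_mapping

    with tf.gfile.GFile("toy_vocab.txt", "w") as fd:
        fd.write("<pad>\n<eos>\n<unk>\nthe\n")

    vocab = load_vocabulary("toy_vocab.txt")
    print(get_control_mapping(vocab, ["<pad>", "<eos>", "<unk>"]))
    # {'<pad>': 0, '<eos>': 1, '<unk>': 2}
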
/thumt/interface/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | from thumt.interface.model import NMTModel
9 |
--------------------------------------------------------------------------------
/thumt/interface/model.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 |
9 | class NMTModel(object):
10 | """ Abstract object representing an NMT model """
11 |
12 | def __init__(self, params, scope):
13 | self._scope = scope
14 | self._params = params
15 |
16 | def get_training_func(self, initializer):
17 | """
18 | :param initializer: the initializer used to initialize the model
19 | :return: a function with the following signature:
20 | (features, params, reuse) -> loss
21 | """
22 | raise NotImplementedError("Not implemented")
23 |
24 | def get_evaluation_func(self):
25 | """
26 | :return: a function with the following signature:
27 | (features, params) -> score
28 | """
29 | raise NotImplementedError("Not implemented")
30 |
31 | def get_inference_func(self):
32 | """
33 | :returns:
34 | If a model implements incremental decoding, this function should
35 | return a tuple of (encoding_fn, decoding_fn), with the following
36 | requirements:
37 | encoding_fn: (features, params) -> initial_state
38 | decoding_fn: (feature, state, params) -> log_prob, next_state
39 |
40 | If a model does not implement incremental decoding (slower
41 | decoding, but the code is easier to write), then this
42 | function should return a single function with the following
43 | signature:
44 | (features, params) -> log_prob
45 |
46 | See models/transformer.py and models/rnnsearch.py
47 | for comparison.
48 | """
49 | raise NotImplementedError("Not implemented")
50 |
51 | @staticmethod
52 | def get_name():
53 | raise NotImplementedError("Not implemented")
54 |
55 | @staticmethod
56 | def get_parameters():
57 | raise NotImplementedError("Not implemented")
58 |
59 | @property
60 | def parameters(self):
61 | return self._params
62 |
--------------------------------------------------------------------------------
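
Concrete models implement this interface by overriding the factory methods; models/rnnsearch.py below is a complete example. A skeletal sketch of the contract (ToyModel and its constant loss are ours, not part of the codebase):

    import tensorflow as tf
    from thumt.interface import NMTModel

    class ToyModel(NMTModel):
        def __init__(self, params, scope="toy"):
            super(ToyModel, self).__init__(params=params, scope=scope)

        def get_training_func(self, initializer):
            def training_fn(features, params=None, reuse=None):
                params = params or self.parameters
                with tf.variable_scope(self._scope, initializer=initializer,
                                       reuse=reuse):
                    return tf.constant(0.0)  # stand-in for a real loss graph
            return training_fn

        @staticmethod
        def get_name():
            return "toy"
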
/thumt/layers/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The THUMT Authors
3 |
4 | import thumt.layers.attention
5 | import thumt.layers.nn
6 | import thumt.layers.rnn_cell
7 |
--------------------------------------------------------------------------------
/thumt/layers/attention.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import math
9 | import tensorflow as tf
10 |
11 | from thumt.layers.nn import linear
12 |
13 |
14 | def add_timing_signal(x, min_timescale=1.0, max_timescale=1.0e4, name=None):
15 | """
16 | This function adds a bunch of sinusoids of different frequencies to a
17 | Tensor. See paper: `Attention is all you need'
18 |
19 | :param x: A tensor with shape [batch, length, channels]
20 | :param min_timescale: A floating point number
21 | :param max_timescale: A floating point number
22 | :param name: An optional string
23 |
24 | :returns: a Tensor the same shape as x.
25 | """
26 |
27 | with tf.name_scope(name, default_name="add_timing_signal", values=[x]):
28 | length = tf.shape(x)[1]
29 | channels = tf.shape(x)[2]
30 | position = tf.to_float(tf.range(length))
31 | num_timescales = channels // 2
32 |
33 | log_timescale_increment = (
34 | math.log(float(max_timescale) / float(min_timescale)) /
35 | (tf.to_float(num_timescales) - 1)
36 | )
37 | inv_timescales = min_timescale * tf.exp(
38 | tf.to_float(tf.range(num_timescales)) * -log_timescale_increment
39 | )
40 |
41 | scaled_time = (tf.expand_dims(position, 1) *
42 | tf.expand_dims(inv_timescales, 0))
43 | signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
44 | signal = tf.pad(signal, [[0, 0], [0, tf.mod(channels, 2)]])
45 | signal = tf.reshape(signal, [1, length, channels])
46 |
47 | return x + signal
48 |
49 |
50 | def split_heads(inputs, num_heads, name=None):
51 | """ Split heads
52 | :param inputs: A tensor with shape [batch, ..., channels]
53 | :param num_heads: An integer
54 | :param name: An optional string
55 | :returns: A tensor with shape [batch, heads, ..., channels / heads]
56 | """
57 |
58 | with tf.name_scope(name, default_name="split_heads", values=[inputs]):
59 | x = inputs
60 | n = num_heads
61 | old_shape = x.get_shape().dims
62 | ndims = x.shape.ndims
63 |
64 | last = old_shape[-1]
65 | new_shape = old_shape[:-1] + [n] + [last // n if last else None]
66 | ret = tf.reshape(x, tf.concat([tf.shape(x)[:-1], [n, -1]], 0))
67 | ret.set_shape(new_shape)
68 | perm = [0, ndims - 1] + [i for i in range(1, ndims - 1)] + [ndims]
69 | return tf.transpose(ret, perm)
70 |
71 |
72 | def combine_heads(inputs, name=None):
73 | """ Combine heads
74 | :param inputs: A tensor with shape [batch, heads, length, channels]
75 | :param name: An optional string
76 | :returns: A tensor with shape [batch, length, heads * channels]
77 | """
78 |
79 | with tf.name_scope(name, default_name="combine_heads", values=[inputs]):
80 | x = inputs
81 | x = tf.transpose(x, [0, 2, 1, 3])
82 | old_shape = x.get_shape().dims
83 | a, b = old_shape[-2:]
84 | new_shape = old_shape[:-2] + [a * b if a and b else None]
85 | x = tf.reshape(x, tf.concat([tf.shape(x)[:-2], [-1]], 0))
86 | x.set_shape(new_shape)
87 |
88 | return x
89 |
90 |
91 | def attention_bias(inputs, mode, inf=-1e9, name=None):
92 | """ A bias tensor used in attention mechanism
93 | :param inputs: A tensor
94 | :param mode: one of "causal", "masking", "proximal" or "distance"
95 | :param inf: A floating value
96 | :param name: optional string
97 | :returns: A 4D tensor with shape [batch, heads, queries, memories]
98 | """
99 |
100 | with tf.name_scope(name, default_name="attention_bias", values=[inputs]):
101 | if mode == "causal":
102 | length = inputs
103 | lower_triangle = tf.matrix_band_part(
104 | tf.ones([length, length]), -1, 0
105 | )
106 | ret = inf * (1.0 - lower_triangle)
107 | return tf.reshape(ret, [1, 1, length, length])
108 | elif mode == "masking":
109 | mask = inputs
110 | ret = (1.0 - mask) * inf
111 | return tf.expand_dims(tf.expand_dims(ret, 1), 1)
112 | elif mode == "proximal":
113 | length = inputs
114 | r = tf.to_float(tf.range(length))
115 | diff = tf.expand_dims(r, 0) - tf.expand_dims(r, 1)
116 | m = tf.expand_dims(tf.expand_dims(-tf.log(1 + tf.abs(diff)), 0), 0)
117 | return m
118 | elif mode == "distance":
119 | length, distance = inputs
120 | distance = tf.where(distance > length, 0, distance)
121 | distance = tf.cast(distance, tf.int64)
122 | lower_triangle = tf.matrix_band_part(
123 | tf.ones([length, length]), -1, 0
124 | )
125 | mask_triangle = 1.0 - tf.matrix_band_part(
126 | tf.ones([length, length]), distance - 1, 0
127 | )
128 | ret = inf * (1.0 - lower_triangle + mask_triangle)
129 | return tf.reshape(ret, [1, 1, length, length])
130 | else:
131 | raise ValueError("Unknown mode %s" % mode)
132 |
133 |
134 | def attention(query, memories, bias, hidden_size, cache=None, reuse=None,
135 | dtype=None, scope=None):
136 | """ Standard attention layer
137 |
138 | :param query: A tensor with shape [batch, key_size]
139 | :param memories: A tensor with shape [batch, memory_size, key_size]
140 | :param bias: A tensor with shape [batch, memory_size]
141 | :param hidden_size: An integer
142 | :param cache: A dictionary of precomputed values
143 | :param reuse: A boolean value, whether to reuse the scope
144 | :param dtype: An optional instance of tf.DType
145 | :param scope: An optional string, the scope of this layer
146 | :return: A dict with key ``value'' (shape [batch, value_size])
147 | and key ``weight'' (shape [batch, memory_size])
148 | """
149 |
150 | with tf.variable_scope(scope or "attention", reuse=reuse,
151 | values=[query, memories, bias], dtype=dtype):
152 | mem_shape = tf.shape(memories)
153 | key_size = memories.get_shape().as_list()[-1]
154 |
155 | if cache is None:
156 | k = tf.reshape(memories, [-1, key_size])
157 | k = linear(k, hidden_size, False, False, scope="k_transform")
158 |
159 | if query is None:
160 | return {"key": k}
161 | else:
162 | k = cache["key"]
163 |
164 | q = linear(query, hidden_size, False, False, scope="q_transform")
165 | k = tf.reshape(k, [mem_shape[0], mem_shape[1], hidden_size])
166 |
167 | hidden = tf.tanh(q[:, None, :] + k)
168 | hidden = tf.reshape(hidden, [-1, hidden_size])
169 |
170 | # Shape: [batch, mem_size, 1]
171 | logits = linear(hidden, 1, False, False, scope="logits")
172 | logits = tf.reshape(logits, [-1, mem_shape[1]])
173 |
174 | if bias is not None:
175 | logits = logits + bias
176 |
177 | alpha = tf.nn.softmax(logits)
178 |
179 | outputs = {
180 | "value": tf.reduce_sum(alpha[:, :, None] * memories, axis=1),
181 | "weight": alpha
182 | }
183 |
184 | return outputs
185 |
186 |
187 | def additive_attention(queries, keys, values, bias, hidden_size, concat=False,
188 | keep_prob=None, dtype=None, scope=None):
189 | """ Additive attention mechanism. This layer is implemented using a
190 | one layer feed forward neural network
191 |
192 | :param queries: A tensor with shape [batch, heads, length_q, depth_k]
193 | :param keys: A tensor with shape [batch, heads, length_kv, depth_k]
194 | :param values: A tensor with shape [batch, heads, length_kv, depth_v]
195 | :param bias: A tensor
196 | :param hidden_size: An integer
197 | :param concat: A boolean value. If ``concat'' is True, the
198 | attention scores are computed as $\tanh(W[q, k])$;
199 | if ``concat'' is False, they are computed as
200 | $\tanh(Wq + Vk)$
201 | :param keep_prob: a scalar in [0, 1]
202 | :param dtype: An optional instance of tf.DType
203 | :param scope: An optional string, the scope of this layer
204 |
205 | :returns: A dict with the following keys:
206 | weights: A tensor with shape [batch, heads, length_q, length_kv]
207 | outputs: A tensor with shape [batch, heads, length_q, depth_v]
208 | """
209 |
210 | with tf.variable_scope(scope, default_name="additive_attention",
211 | values=[queries, keys, values, bias], dtype=dtype):
212 | length_q = tf.shape(queries)[2]
213 | length_kv = tf.shape(keys)[2]
214 | q = tf.tile(tf.expand_dims(queries, 3), [1, 1, 1, length_kv, 1])
215 | k = tf.tile(tf.expand_dims(keys, 2), [1, 1, length_q, 1, 1])
216 |
217 | if concat:
218 | combined = tf.tanh(linear(tf.concat([q, k], axis=-1), hidden_size,
219 | True, True, scope="qk_transform"))
220 | else:
221 | q = linear(queries, hidden_size, True, True, scope="q_transform")
222 | k = linear(keys, hidden_size, True, True, scope="key_transform")
223 | combined = tf.tanh(q + k)
224 |
225 | # shape: [batch, heads, length_q, length_kv]
226 | logits = tf.squeeze(linear(combined, 1, True, True, scope="logits"),
227 | axis=-1)
228 |
229 | if bias is not None:
230 | logits += bias
231 |
232 | weights = tf.nn.softmax(logits, name="attention_weights")
233 |
234 | if keep_prob is not None and keep_prob < 1.0:
235 | weights = tf.nn.dropout(weights, keep_prob)
236 |
237 | outputs = tf.matmul(weights, values)
238 |
239 | return {"weights": weights, "outputs": outputs}
240 |
241 |
242 | def multiplicative_attention(queries, keys, values, bias, keep_prob=None,
243 | name=None):
244 | """ Multiplicative attention mechanism. This layer is implemented using
245 | dot-product operation.
246 |
247 | :param queries: A tensor with shape [batch, heads, length_q, depth_k]
248 | :param keys: A tensor with shape [batch, heads, length_kv, depth_k]
249 | :param values: A tensor with shape [batch, heads, length_kv, depth_v]
250 | :param bias: A tensor
251 | :param keep_prob: a scalar in (0, 1]
252 | :param name: the name of this operation
253 |
254 | :returns: A dict with the following keys:
255 | weights: A tensor with shape [batch, heads, length_q, length_kv]
256 | outputs: A tensor with shape [batch, heads, length_q, depth_v]
257 | """
258 |
259 | with tf.name_scope(name, default_name="multiplicative_attention",
260 | values=[queries, keys, values, bias]):
261 | # shape: [batch, heads, length_q, length_kv]
262 | logits = tf.matmul(queries, keys, transpose_b=True)
263 |
264 | if bias is not None:
265 | logits += bias
266 |
267 | weights = tf.nn.softmax(logits, name="attention_weights")
268 |
269 | if keep_prob is not None and keep_prob < 1.0:
270 | weights = tf.nn.dropout(weights, keep_prob)
271 |
272 | outputs = tf.matmul(weights, values)
273 |
274 | return {"weights": weights, "outputs": outputs}
275 |
276 |
277 | def multihead_attention(queries, memories, bias, num_heads, key_size,
278 | value_size, output_size, keep_prob=None, output=True,
279 | state=None, dtype=None, scope=None, trainable=True):
280 | """ Multi-head scaled-dot-product attention with input/output
281 | transformations.
282 |
283 | :param queries: A tensor with shape [batch, length_q, depth_q]
284 | :param memories: A tensor with shape [batch, length_m, depth_m]
285 | :param bias: A tensor (see attention_bias)
286 | :param num_heads: An integer dividing key_size and value_size
287 | :param key_size: An integer
288 | :param value_size: An integer
289 | :param output_size: An integer
290 | :param keep_prob: A floating point number in (0, 1]
291 | :param output: Whether to use output transformation
292 | :param state: An optional dictionary used for incremental decoding
293 | :param dtype: An optional instance of tf.DType
294 | :param scope: An optional string
295 |
296 | :returns: A dict with the following keys:
297 | weights: A tensor with shape [batch, heads, length_q, length_kv]
298 | outputs: A tensor with shape [batch, length_q, depth_v]
299 | """
300 |
301 | if key_size % num_heads != 0:
302 | raise ValueError("Key size (%d) must be divisible by the number of "
303 | "attention heads (%d)." % (key_size, num_heads))
304 |
305 | if value_size % num_heads != 0:
306 | raise ValueError("Value size (%d) must be divisible by the number of "
307 | "attention heads (%d)." % (value_size, num_heads))
308 |
309 | with tf.variable_scope(scope, default_name="multihead_attention",
310 | values=[queries, memories], dtype=dtype):
311 | next_state = {}
312 |
313 | if memories is None:
314 | # self attention
315 | size = key_size * 2 + value_size
316 | combined = linear(queries, size, True, True, scope="qkv_transform", trainable=trainable)
317 | q, k, v = tf.split(combined, [key_size, key_size, value_size],
318 | axis=-1)
319 |
320 | if state is not None:
321 | k = tf.concat([state["key"], k], axis=1)
322 | v = tf.concat([state["value"], v], axis=1)
323 | next_state["key"] = k
324 | next_state["value"] = v
325 | else:
326 | q = linear(queries, key_size, True, True, scope="q_transform", trainable=trainable)
327 | combined = linear(memories, key_size + value_size, True,
328 | scope="kv_transform", trainable=trainable)
329 | k, v = tf.split(combined, [key_size, value_size], axis=-1)
330 |
331 | # split heads
332 | q = split_heads(q, num_heads)
333 | k = split_heads(k, num_heads)
334 | v = split_heads(v, num_heads)
335 |
336 | # scale query
337 | key_depth_per_head = key_size // num_heads
338 | q *= key_depth_per_head ** -0.5
339 |
340 | # attention
341 | results = multiplicative_attention(q, k, v, bias, keep_prob)
342 |
343 | # combine heads
344 | weights = results["weights"]
345 | x = combine_heads(results["outputs"])
346 |
347 | if output:
348 | outputs = linear(x, output_size, True, True,
349 | scope="output_transform", trainable=trainable)
350 | else:
351 | outputs = x
352 |
353 | outputs = {"weights": weights, "outputs": outputs}
354 |
355 | if state is not None:
356 | outputs["state"] = next_state
357 |
358 | return outputs
359 |
--------------------------------------------------------------------------------
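
A typical decoder-side use of the layers above: build a causal bias, then run masked multi-head self-attention against it. A minimal sketch (TF1 graph mode; the sizes are illustrative):

    import tensorflow as tf
    from thumt.layers.attention import attention_bias, multihead_attention

    x = tf.random_normal([2, 5, 64])    # [batch, length, channels]
    bias = attention_bias(5, "causal")  # forbid attending to future positions
    result = multihead_attention(x, None, bias, num_heads=8, key_size=64,
                                 value_size=64, output_size=64, keep_prob=0.9)
    print(result["outputs"].shape)  # (2, 5, 64)
    print(result["weights"].shape)  # (2, 8, 5, 5)
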
/thumt/layers/nn.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import tensorflow as tf
9 |
10 |
11 | def linear(inputs, output_size, bias, concat=True, dtype=None, scope=None, trainable=True):
12 | """
13 | Linear layer
14 | :param inputs: A Tensor or a list of Tensors with shape [batch, input_size]
15 | :param output_size: An integer specifying the output size
16 | :param bias: a boolean value indicating whether to use a bias term
17 | :param concat: a boolean value indicating whether to concatenate all inputs
18 | :param dtype: an instance of tf.DType, the default value is ``tf.float32''
19 | :param scope: the scope of this layer, the default value is ``linear''
20 | :returns: a Tensor with shape [batch, output_size]
21 | :raises RuntimeError: raises ``RuntimeError'' when input sizes are not
22 | compatible with each other
23 | """
24 |
25 | with tf.variable_scope(scope, default_name="linear", values=[inputs]):
26 | if not isinstance(inputs, (list, tuple)):
27 | inputs = [inputs]
28 |
29 | input_size = [item.get_shape()[-1].value for item in inputs]
30 |
31 | if len(inputs) != len(input_size):
32 | raise RuntimeError("inputs and input_size unmatched!")
33 |
34 | output_shape = tf.concat([tf.shape(inputs[0])[:-1], [output_size]],
35 | axis=0)
36 | # Flatten to 2D
37 | inputs = [tf.reshape(inp, [-1, inp.shape[-1].value]) for inp in inputs]
38 |
39 | results = []
40 |
41 | if concat:
42 | input_size = sum(input_size)
43 | inputs = tf.concat(inputs, 1)
44 |
45 | shape = [input_size, output_size]
46 | matrix = tf.get_variable("matrix", shape, dtype=dtype, trainable=trainable)
47 | results.append(tf.matmul(inputs, matrix))
48 | else:
49 | for i in range(len(input_size)):
50 | shape = [input_size[i], output_size]
51 | name = "matrix_%d" % i
52 | matrix = tf.get_variable(name, shape, dtype=dtype, trainable=trainable)
53 | results.append(tf.matmul(inputs[i], matrix))
54 |
55 | output = tf.add_n(results)
56 |
57 | if bias:
58 | shape = [output_size]
59 | bias = tf.get_variable("bias", shape, dtype=dtype, trainable=trainable)
60 | output = tf.nn.bias_add(output, bias)
61 |
62 | output = tf.reshape(output, output_shape)
63 |
64 | return output
65 |
66 |
67 | def maxout(inputs, output_size, maxpart=2, use_bias=True, concat=True,
68 | dtype=None, scope=None):
69 | """
70 | Maxout layer
71 | :param inputs: see the corresponding description of ``linear''
72 | :param output_size: see the corresponding description of ``linear''
73 | :param maxpart: an integer, the default value is 2
74 | :param use_bias: a boolean value indicating whether to use a bias term
75 | :param concat: concat all tensors if inputs is a list of tensors
76 | :param dtype: an optional instance of tf.Dtype
77 | :param scope: the scope of this layer, the default value is ``maxout''
78 | :returns: a Tensor with shape [batch, output_size]
79 | :raises RuntimeError: see the corresponding description of ``linear''
80 | """
81 |
82 | candidate = linear(inputs, output_size * maxpart, use_bias, concat,
83 | dtype=dtype, scope=scope or "maxout")
84 | shape = tf.concat([tf.shape(candidate)[:-1], [output_size, maxpart]],
85 | axis=0)
86 | value = tf.reshape(candidate, shape)
87 | output = tf.reduce_max(value, -1)
88 |
89 | return output
90 |
91 |
92 | def layer_norm(inputs, epsilon=1e-6, dtype=None, scope=None, trainable=True):
93 | """
94 | Layer Normalization
95 | :param inputs: A Tensor of shape [..., channel_size]
96 | :param epsilon: A floating number
97 | :param dtype: An optional instance of tf.DType
98 | :param scope: An optional string
99 | :returns: A Tensor with the same shape as inputs
100 | """
101 | with tf.variable_scope(scope, default_name="layer_norm", values=[inputs],
102 | dtype=dtype):
103 | channel_size = inputs.get_shape().as_list()[-1]
104 |
105 | scale = tf.get_variable("scale", shape=[channel_size],
106 | initializer=tf.ones_initializer(), trainable=trainable)
107 |
108 | offset = tf.get_variable("offset", shape=[channel_size],
109 | initializer=tf.zeros_initializer(), trainable=trainable)
110 |
111 | mean = tf.reduce_mean(inputs, -1, True)
112 | variance = tf.reduce_mean(tf.square(inputs - mean), -1, True)
113 |
114 | norm_inputs = (inputs - mean) * tf.rsqrt(variance + epsilon)
115 |
116 | return norm_inputs * scale + offset
117 |
118 |
119 | def smoothed_softmax_cross_entropy_with_logits(**kwargs):
120 | logits = kwargs.get("logits")
121 | labels = kwargs.get("labels")
122 | smoothing = kwargs.get("smoothing") or 0.0
123 | normalize = kwargs.get("normalize")
124 | scope = kwargs.get("scope")
125 |
126 | if logits is None or labels is None:
127 | raise ValueError("Both logits and labels must be provided")
128 |
129 | with tf.name_scope(scope or "smoothed_softmax_cross_entropy_with_logits",
130 | values=[logits, labels]):
131 |
132 | labels = tf.reshape(labels, [-1])
133 |
134 | if not smoothing:
135 | ce = tf.nn.sparse_softmax_cross_entropy_with_logits(
136 | logits=logits,
137 | labels=labels
138 | )
139 | return ce
140 |
141 | # label smoothing
142 | vocab_size = tf.shape(logits)[1]
143 |
144 | n = tf.to_float(vocab_size - 1)
145 | p = 1.0 - smoothing
146 | q = smoothing / n
147 |
148 | soft_targets = tf.one_hot(tf.cast(labels, tf.int32), depth=vocab_size,
149 | on_value=p, off_value=q)
150 | xentropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits,
151 | labels=soft_targets)
152 |
153 | if normalize is False:
154 | return xentropy
155 |
156 | # Normalizing constant is the best cross-entropy value with soft
157 | # targets. We subtract it just for readability, makes no difference on
158 | # learning
159 | normalizing = -(p * tf.log(p) + n * q * tf.log(q + 1e-20))
160 |
161 | return xentropy - normalizing
162 |
--------------------------------------------------------------------------------
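
For intuition on the smoothing branch: probability p goes to the gold token and q to each of the n other tokens, and the entropy of that soft target is subtracted so a perfect model scores exactly zero. A quick numpy check of the constants (vocab_size and smoothing chosen for illustration):

    import numpy as np

    vocab_size, smoothing = 5, 0.1
    n = float(vocab_size - 1)
    p, q = 1.0 - smoothing, smoothing / n
    normalizing = -(p * np.log(p) + n * q * np.log(q + 1e-20))
    print(p, q)         # 0.9 0.025
    print(normalizing)  # ~0.464, the best attainable soft-target cross-entropy
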
/thumt/layers/rnn_cell.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import tensorflow as tf
9 |
10 | from .nn import linear
11 |
12 |
13 | class LegacyGRUCell(tf.nn.rnn_cell.RNNCell):
14 | """ Groundhog's implementation of GRUCell
15 |
16 | :param num_units: int, The number of units in the RNN cell.
17 | :param reuse: (optional) Python boolean describing whether to reuse
18 | variables in an existing scope. If not `True`, and the existing
19 | scope already has the given variables, an error is raised.
20 | """
21 |
22 | def __init__(self, num_units, reuse=None):
23 | super(LegacyGRUCell, self).__init__(_reuse=reuse)
24 | self._num_units = num_units
25 |
26 | def __call__(self, inputs, state, scope=None):
27 | with tf.variable_scope(scope, default_name="gru_cell",
28 | values=[inputs, state]):
29 | if not isinstance(inputs, (list, tuple)):
30 | inputs = [inputs]
31 |
32 | all_inputs = list(inputs) + [state]
33 | r = tf.nn.sigmoid(linear(all_inputs, self._num_units, False, False,
34 | scope="reset_gate"))
35 | u = tf.nn.sigmoid(linear(all_inputs, self._num_units, False, False,
36 | scope="update_gate"))
37 | all_inputs = list(inputs) + [r * state]
38 | c = linear(all_inputs, self._num_units, True, False,
39 | scope="candidate")
40 |
41 | new_state = (1.0 - u) * state + u * tf.tanh(c)
42 |
43 | return new_state, new_state
44 |
45 | @property
46 | def state_size(self):
47 | return self._num_units
48 |
49 | @property
50 | def output_size(self):
51 | return self._num_units
52 |
53 |
54 | class StateToOutputWrapper(tf.nn.rnn_cell.RNNCell):
55 | """ Copy state to the output of RNNCell so that all states can be obtained
56 | when using tf.nn.dynamic_rnn
57 |
58 | :param cell: An instance of tf.nn.rnn_cell.RNNCell
59 | :param reuse: (optional) Python boolean describing whether to reuse
60 | variables in an existing scope. If not `True`, and the existing
61 | scope already has the given variables, an error is raised.
62 | """
63 |
64 | def __init__(self, cell, reuse=None):
65 | super(StateToOutputWrapper, self).__init__(_reuse=reuse)
66 | self._cell = cell
67 |
68 | def __call__(self, inputs, state, scope=None):
69 | output, new_state = self._cell(inputs, state, scope=scope)
70 |
71 | return (output, new_state), new_state
72 |
73 | @property
74 | def state_size(self):
75 | return self._cell.state_size
76 |
77 | @property
78 | def output_size(self):
79 | return tuple([self._cell.output_size, self.state_size])
80 |
81 |
82 | class AttentionWrapper(tf.nn.rnn_cell.RNNCell):
83 | """ Wrap an RNNCell with attention mechanism
84 |
85 | :param cell: An instance of tf.nn.rnn_cell.RNNCell
86 | :param memory: A tensor with shape [batch, mem_size, mem_dim]
87 | :param bias: A tensor with shape [batch, mem_size]
88 | :param attention_fn: A callable function with signature
89 | (inputs, state, memory, bias) -> (output, state, weight, value)
90 | :param output_weight: Whether to output attention weights
91 | :param output_value: Whether to output attention values
92 | :param reuse: (optional) Python boolean describing whether to reuse
93 | variables in an existing scope. If not `True`, and the existing
94 | scope already has the given variables, an error is raised.
95 | """
96 |
97 | def __init__(self, cell, memory, bias, attention_fn, output_weight=False,
98 | output_value=False, reuse=None):
99 | super(AttentionWrapper, self).__init__(_reuse=reuse)
100 | memory.shape.assert_has_rank(3)
101 | self._cell = cell
102 | self._memory = memory
103 | self._bias = bias
104 | self._attention_fn = attention_fn
105 | self._output_weight = output_weight
106 | self._output_value = output_value
107 |
108 | def __call__(self, inputs, state, scope=None):
109 | outputs = self._attention_fn(inputs, state, self._memory, self._bias)
110 | cell_inputs, cell_state, weight, value = outputs
111 | cell_output, new_state = self._cell(cell_inputs, cell_state,
112 | scope=scope)
113 |
114 | if not self._output_weight and not self._output_value:
115 | return cell_output, new_state
116 |
117 | new_output = [cell_output]
118 |
119 | if self._output_weight:
120 | new_output.append(weight)
121 |
122 | if self._output_value:
123 | new_output.append(value)
124 |
125 | return tuple(new_output), new_state
126 |
127 | @property
128 | def state_size(self):
129 | return self._cell.state_size
130 |
131 | @property
132 | def output_size(self):
133 | if not self._output_weight and not self._output_value:
134 | return self._cell.output_size
135 |
136 | new_output_size = [self._cell.output_size]
137 |
138 | if self._output_weight:
139 | new_output_size.append(self._memory.shape[1])
140 |
141 | if self._output_value:
142 | new_output_size.append(self._memory.shape[2].value)
143 |
144 | return tuple(new_output_size)
145 |
--------------------------------------------------------------------------------
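
These cells plug into the standard TF1 RNN machinery; StateToOutputWrapper additionally surfaces every intermediate state alongside the outputs. A minimal sketch:

    import tensorflow as tf
    from thumt.layers.rnn_cell import LegacyGRUCell, StateToOutputWrapper

    cell = StateToOutputWrapper(LegacyGRUCell(32))
    inputs = tf.random_normal([4, 7, 16])  # [batch, time, input_dim]
    (outputs, states), final_state = tf.nn.dynamic_rnn(cell, inputs,
                                                       dtype=tf.float32)
    print(outputs.shape, states.shape)  # (4, 7, 32) (4, 7, 32)
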
/thumt/models/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import thumt.models.seq2seq
9 | import thumt.models.rnnsearch
10 | import thumt.models.transformer
11 | import thumt.models.contextual_transformer
12 |
13 |
14 | def get_model(name):
15 | name = name.lower()
16 |
17 | if name == "rnnsearch":
18 | return thumt.models.rnnsearch.RNNsearch
19 | elif name == "seq2seq":
20 | return thumt.models.seq2seq.Seq2Seq
21 | elif name == "transformer":
22 | return thumt.models.transformer.Transformer
23 | elif name == "contextual_transformer":
24 | return thumt.models.contextual_transformer.Contextual_Transformer
25 | else:
26 | raise LookupError("Unknown model %s" % name)
27 |
--------------------------------------------------------------------------------
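
This registry is what the bin/ entry points consult for their --model flag. A minimal sketch of direct use (assuming the Transformer constructor takes (params, scope) like the models shown below):

    from thumt.models import get_model

    model_cls = get_model("transformer")
    print(model_cls.get_name())  # transformer
    model = model_cls(model_cls.get_parameters())
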
/thumt/models/rnnsearch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import copy
9 |
10 | import tensorflow as tf
11 | import thumt.interface as interface
12 | import thumt.layers as layers
13 |
14 |
15 | def _copy_through(time, length, output, new_output):
16 | copy_cond = (time >= length)
17 | return tf.where(copy_cond, output, new_output)
18 |
19 |
20 | def _gru_encoder(cell, inputs, sequence_length, initial_state, dtype=None):
21 | # Assume that the underlying cell is GRUCell-like
22 | output_size = cell.output_size
23 | dtype = dtype or inputs.dtype
24 |
25 | batch = tf.shape(inputs)[0]
26 | time_steps = tf.shape(inputs)[1]
27 |
28 | zero_output = tf.zeros([batch, output_size], dtype)
29 |
30 | if initial_state is None:
31 | initial_state = cell.zero_state(batch, dtype)
32 |
33 | input_ta = tf.TensorArray(dtype, time_steps,
34 | tensor_array_name="input_array")
35 | output_ta = tf.TensorArray(dtype, time_steps,
36 | tensor_array_name="output_array")
37 | input_ta = input_ta.unstack(tf.transpose(inputs, [1, 0, 2]))
38 |
39 | def loop_func(t, out_ta, state):
40 | inp_t = input_ta.read(t)
41 | cell_output, new_state = cell(inp_t, state)
42 | cell_output = _copy_through(t, sequence_length, zero_output,
43 | cell_output)
44 | new_state = _copy_through(t, sequence_length, state, new_state)
45 | out_ta = out_ta.write(t, cell_output)
46 | return t + 1, out_ta, new_state
47 |
48 | time = tf.constant(0, dtype=tf.int32, name="time")
49 | loop_vars = (time, output_ta, initial_state)
50 |
51 | outputs = tf.while_loop(lambda t, *_: t < time_steps, loop_func,
52 | loop_vars, parallel_iterations=32,
53 | swap_memory=True)
54 |
55 | output_final_ta = outputs[1]
56 | final_state = outputs[2]
57 |
58 | all_output = output_final_ta.stack()
59 | all_output.set_shape([None, None, output_size])
60 | all_output = tf.transpose(all_output, [1, 0, 2])
61 |
62 | return all_output, final_state
63 |
64 |
65 | def _encoder(cell_fw, cell_bw, inputs, sequence_length, dtype=None,
66 | scope=None):
67 | with tf.variable_scope(scope or "encoder",
68 | values=[inputs, sequence_length]):
69 | inputs_fw = inputs
70 | inputs_bw = tf.reverse_sequence(inputs, sequence_length,
71 | batch_axis=0, seq_axis=1)
72 |
73 | with tf.variable_scope("forward"):
74 | output_fw, state_fw = _gru_encoder(cell_fw, inputs_fw,
75 | sequence_length, None,
76 | dtype=dtype)
77 |
78 | with tf.variable_scope("backward"):
79 | output_bw, state_bw = _gru_encoder(cell_bw, inputs_bw,
80 | sequence_length, None,
81 | dtype=dtype)
82 | output_bw = tf.reverse_sequence(output_bw, sequence_length,
83 | batch_axis=0, seq_axis=1)
84 |
85 | results = {
86 | "annotation": tf.concat([output_fw, output_bw], axis=2),
87 | "outputs": {
88 | "forward": output_fw,
89 | "backward": output_bw
90 | },
91 | "final_states": {
92 | "forward": state_fw,
93 | "backward": state_bw
94 | }
95 | }
96 |
97 | return results
98 |
99 |
100 | def _decoder(cell, inputs, memory, sequence_length, initial_state, dtype=None,
101 | scope=None):
102 | # Assume that the underlying cell is GRUCell-like
103 | batch = tf.shape(inputs)[0]
104 | time_steps = tf.shape(inputs)[1]
105 | dtype = dtype or inputs.dtype
106 | output_size = cell.output_size
107 | zero_output = tf.zeros([batch, output_size], dtype)
108 | zero_value = tf.zeros([batch, memory.shape[-1].value], dtype)
109 |
110 | with tf.variable_scope(scope or "decoder", dtype=dtype):
111 | inputs = tf.transpose(inputs, [1, 0, 2])
112 | mem_mask = tf.sequence_mask(sequence_length["source"],
113 | maxlen=tf.shape(memory)[1],
114 | dtype=tf.float32)
115 | bias = layers.attention.attention_bias(mem_mask, "masking")
116 | bias = tf.squeeze(bias, axis=[1, 2])
117 | cache = layers.attention.attention(None, memory, None, output_size)
118 |
119 | input_ta = tf.TensorArray(tf.float32, time_steps,
120 | tensor_array_name="input_array")
121 | output_ta = tf.TensorArray(tf.float32, time_steps,
122 | tensor_array_name="output_array")
123 | value_ta = tf.TensorArray(tf.float32, time_steps,
124 | tensor_array_name="value_array")
125 | alpha_ta = tf.TensorArray(tf.float32, time_steps,
126 | tensor_array_name="alpha_array")
127 | input_ta = input_ta.unstack(inputs)
128 | initial_state = layers.nn.linear(initial_state, output_size, True,
129 | False, scope="s_transform")
130 | initial_state = tf.tanh(initial_state)
131 |
132 | def loop_func(t, out_ta, att_ta, val_ta, state, cache_key):
133 | inp_t = input_ta.read(t)
134 | results = layers.attention.attention(state, memory, bias,
135 | output_size,
136 | cache={"key": cache_key})
137 | alpha = results["weight"]
138 | context = results["value"]
139 | cell_input = [inp_t, context]
140 | cell_output, new_state = cell(cell_input, state)
141 | cell_output = _copy_through(t, sequence_length["target"],
142 | zero_output, cell_output)
143 | new_state = _copy_through(t, sequence_length["target"], state,
144 | new_state)
145 | new_value = _copy_through(t, sequence_length["target"], zero_value,
146 | context)
147 |
148 | out_ta = out_ta.write(t, cell_output)
149 | att_ta = att_ta.write(t, alpha)
150 | val_ta = val_ta.write(t, new_value)
151 | cache_key = tf.identity(cache_key)
152 | return t + 1, out_ta, att_ta, val_ta, new_state, cache_key
153 |
154 | time = tf.constant(0, dtype=tf.int32, name="time")
155 | loop_vars = (time, output_ta, alpha_ta, value_ta, initial_state,
156 | cache["key"])
157 |
158 | outputs = tf.while_loop(lambda t, *_: t < time_steps,
159 | loop_func, loop_vars,
160 | parallel_iterations=32,
161 | swap_memory=True)
162 |
163 | output_final_ta = outputs[1]
164 | value_final_ta = outputs[3]
165 |
166 | final_output = output_final_ta.stack()
167 | final_output.set_shape([None, None, output_size])
168 | final_output = tf.transpose(final_output, [1, 0, 2])
169 |
170 | final_value = value_final_ta.stack()
171 | final_value.set_shape([None, None, memory.shape[-1].value])
172 | final_value = tf.transpose(final_value, [1, 0, 2])
173 |
174 | result = {
175 | "outputs": final_output,
176 | "values": final_value,
177 | "initial_state": initial_state
178 | }
179 |
180 | return result
181 |
182 |
183 | def model_graph(features, mode, params):
184 | src_vocab_size = len(params.vocabulary["source"])
185 | tgt_vocab_size = len(params.vocabulary["target"])
186 |
187 | with tf.variable_scope("source_embedding"):
188 | src_emb = tf.get_variable("embedding",
189 | [src_vocab_size, params.embedding_size])
190 | src_bias = tf.get_variable("bias", [params.embedding_size])
191 | src_inputs = tf.nn.embedding_lookup(src_emb, features["source"])
192 |
193 | with tf.variable_scope("target_embedding"):
194 | tgt_emb = tf.get_variable("embedding",
195 | [tgt_vocab_size, params.embedding_size])
196 | tgt_bias = tf.get_variable("bias", [params.embedding_size])
197 | tgt_inputs = tf.nn.embedding_lookup(tgt_emb, features["target"])
198 |
199 | src_inputs = tf.nn.bias_add(src_inputs, src_bias)
200 | tgt_inputs = tf.nn.bias_add(tgt_inputs, tgt_bias)
201 |
202 | if params.dropout and not params.use_variational_dropout:
203 | src_inputs = tf.nn.dropout(src_inputs, 1.0 - params.dropout)
204 | tgt_inputs = tf.nn.dropout(tgt_inputs, 1.0 - params.dropout)
205 |
206 | # encoder
207 | cell_fw = layers.rnn_cell.LegacyGRUCell(params.hidden_size)
208 | cell_bw = layers.rnn_cell.LegacyGRUCell(params.hidden_size)
209 |
210 | if params.use_variational_dropout:
211 | cell_fw = tf.nn.rnn_cell.DropoutWrapper(
212 | cell_fw,
213 | input_keep_prob=1.0 - params.dropout,
214 | output_keep_prob=1.0 - params.dropout,
215 | state_keep_prob=1.0 - params.dropout,
216 | variational_recurrent=True,
217 | input_size=params.embedding_size,
218 | dtype=tf.float32
219 | )
220 | cell_bw = tf.nn.rnn_cell.DropoutWrapper(
221 | cell_bw,
222 | input_keep_prob=1.0 - params.dropout,
223 | output_keep_prob=1.0 - params.dropout,
224 | state_keep_prob=1.0 - params.dropout,
225 | variational_recurrent=True,
226 | input_size=params.embedding_size,
227 | dtype=tf.float32
228 | )
229 |
230 | encoder_output = _encoder(cell_fw, cell_bw, src_inputs,
231 | features["source_length"])
232 |
233 | # decoder
234 | cell = layers.rnn_cell.LegacyGRUCell(params.hidden_size)
235 |
236 | if params.use_variational_dropout:
237 | cell = tf.nn.rnn_cell.DropoutWrapper(
238 | cell,
239 | input_keep_prob=1.0 - params.dropout,
240 | output_keep_prob=1.0 - params.dropout,
241 | state_keep_prob=1.0 - params.dropout,
242 | variational_recurrent=True,
243 | # input + context
244 | input_size=params.embedding_size + 2 * params.hidden_size,
245 | dtype=tf.float32
246 | )
247 |
248 | length = {
249 | "source": features["source_length"],
250 | "target": features["target_length"]
251 | }
252 | initial_state = encoder_output["final_states"]["backward"]
253 | decoder_output = _decoder(cell, tgt_inputs, encoder_output["annotation"],
254 | length, initial_state)
255 |
256 | # Shift left
257 | shifted_tgt_inputs = tf.pad(tgt_inputs, [[0, 0], [1, 0], [0, 0]])
258 | shifted_tgt_inputs = shifted_tgt_inputs[:, :-1, :]
259 |
260 | all_outputs = tf.concat(
261 | [
262 | tf.expand_dims(decoder_output["initial_state"], axis=1),
263 | decoder_output["outputs"],
264 | ],
265 | axis=1
266 | )
267 | shifted_outputs = all_outputs[:, :-1, :]
268 |
269 | maxout_features = [
270 | shifted_tgt_inputs,
271 | shifted_outputs,
272 | decoder_output["values"]
273 | ]
274 | maxout_size = params.hidden_size // params.maxnum
275 |
276 | if mode == "infer":
277 | # Special case for non-incremental decoding
278 | maxout_features = [
279 | shifted_tgt_inputs[:, -1, :],
280 | shifted_outputs[:, -1, :],
281 | decoder_output["values"][:, -1, :]
282 | ]
283 | maxhid = layers.nn.maxout(maxout_features, maxout_size, params.maxnum,
284 | concat=False)
285 | readout = layers.nn.linear(maxhid, params.embedding_size, False,
286 | False, scope="deepout")
287 |
288 | # Prediction
289 | logits = layers.nn.linear(readout, tgt_vocab_size, True, False,
290 | scope="softmax")
291 |
292 | return tf.nn.log_softmax(logits)
293 |
294 | maxhid = layers.nn.maxout(maxout_features, maxout_size, params.maxnum,
295 | concat=False)
296 | readout = layers.nn.linear(maxhid, params.embedding_size, False, False,
297 | scope="deepout")
298 |
299 | if params.dropout and not params.use_variational_dropout:
300 | readout = tf.nn.dropout(readout, 1.0 - params.dropout)
301 |
302 | # Prediction
303 | logits = layers.nn.linear(readout, tgt_vocab_size, True, False,
304 | scope="softmax")
305 | logits = tf.reshape(logits, [-1, tgt_vocab_size])
306 | labels = features["target"]
307 |
308 | ce = layers.nn.smoothed_softmax_cross_entropy_with_logits(
309 | logits=logits,
310 | labels=labels,
311 | smoothing=params.label_smoothing,
312 | normalize=True
313 | )
314 |
315 | ce = tf.reshape(ce, tf.shape(labels))
316 | tgt_mask = tf.to_float(
317 | tf.sequence_mask(
318 | features["target_length"],
319 | maxlen=tf.shape(features["target"])[1]
320 | )
321 | )
322 |
323 | if mode == "eval":
324 | return -tf.reduce_sum(ce * tgt_mask, axis=1)
325 |
326 | loss = tf.reduce_sum(ce * tgt_mask) / tf.reduce_sum(tgt_mask)
327 |
328 | return loss
329 |
330 |
331 | class RNNsearch(interface.NMTModel):
332 |
333 | def __init__(self, params, scope="rnnsearch"):
334 | super(RNNsearch, self).__init__(params=params, scope=scope)
335 |
336 | def get_training_func(self, initializer):
337 | def training_fn(features, params=None, reuse=None):
338 | if params is None:
339 | params = self.parameters
340 | with tf.variable_scope(self._scope, initializer=initializer,
341 | reuse=reuse):
342 | loss = model_graph(features, "train", params)
343 | return loss
344 |
345 | return training_fn
346 |
347 | def get_evaluation_func(self):
348 | def evaluation_fn(features, params=None):
349 | if params is None:
350 | params = copy.copy(self.parameters)
351 | else:
352 | params = copy.copy(params)
353 |
354 | params.dropout = 0.0
355 | params.use_variational_dropout = False
356 | params.label_smoothing = 0.0
357 |
358 | with tf.variable_scope(self._scope):
359 | score = model_graph(features, "eval", params)
360 |
361 | return score
362 |
363 | return evaluation_fn
364 |
365 | def get_inference_func(self):
366 | def inference_fn(features, params=None):
367 | if params is None:
368 | params = copy.copy(self.parameters)
369 | else:
370 | params = copy.copy(params)
371 |
372 | params.dropout = 0.0
373 | params.use_variational_dropout = False
374 | params.label_smoothing = 0.0
375 |
376 | with tf.variable_scope(self._scope):
377 | log_prob = model_graph(features, "infer", params)
378 |
379 | return log_prob
380 |
381 | return inference_fn
382 |
383 | @staticmethod
384 | def get_name():
385 | return "rnnsearch"
386 |
387 | @staticmethod
388 | def get_parameters():
389 | params = tf.contrib.training.HParams(
390 | # vocabulary
391 | pad="<pad>",
392 | unk="<unk>",
393 | eos="<eos>",
394 | bos="<bos>",
395 | append_eos=False,
396 | # model
397 | rnn_cell="LegacyGRUCell",
398 | embedding_size=620,
399 | hidden_size=1000,
400 | maxnum=2,
401 | # regularization
402 | dropout=0.2,
403 | use_variational_dropout=False,
404 | label_smoothing=0.1,
405 | constant_batch_size=True,
406 | batch_size=128,
407 | max_length=60,
408 | clip_grad_norm=5.0
409 | )
410 |
411 | return params
412 |
--------------------------------------------------------------------------------
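
How trainer-side code consumes this interface: fetch the training function, then call it on a feature dict shaped like the output of thumt.data.record.get_input_features. A minimal sketch with toy tensors (the three-token vocabularies are made up):

    import tensorflow as tf
    from thumt.models.rnnsearch import RNNsearch

    params = RNNsearch.get_parameters()
    params.add_hparam("vocabulary", {"source": ["<pad>", "<eos>", "a"],
                                     "target": ["<pad>", "<eos>", "b"]})
    features = {
        "source": tf.constant([[2, 1]]), "source_length": tf.constant([2]),
        "target": tf.constant([[2, 1]]), "target_length": tf.constant([2]),
    }
    initializer = tf.random_uniform_initializer(-0.08, 0.08)
    loss = RNNsearch(params).get_training_func(initializer)(features)
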
/thumt/models/seq2seq.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import copy
9 |
10 | import tensorflow as tf
11 | import thumt.interface as interface
12 | import thumt.layers as layers
13 |
14 |
15 | def model_graph(features, mode, params):
16 | src_vocab_size = len(params.vocabulary["source"])
17 | tgt_vocab_size = len(params.vocabulary["target"])
18 |
19 | src_seq = features["source"]
20 | tgt_seq = features["target"]
21 |
22 | if params.reverse_source:
23 | src_seq = tf.reverse_sequence(src_seq, seq_dim=1,
24 | seq_lengths=features["source_length"])
25 |
26 | with tf.device("/cpu:0"):
27 | with tf.variable_scope("source_embedding"):
28 | src_emb = tf.get_variable("embedding",
29 | [src_vocab_size, params.embedding_size])
30 | src_bias = tf.get_variable("bias", [params.embedding_size])
31 | src_inputs = tf.nn.embedding_lookup(src_emb, src_seq)
32 |
33 | with tf.variable_scope("target_embedding"):
34 | tgt_emb = tf.get_variable("embedding",
35 | [tgt_vocab_size, params.embedding_size])
36 | tgt_bias = tf.get_variable("bias", [params.embedding_size])
37 | tgt_inputs = tf.nn.embedding_lookup(tgt_emb, tgt_seq)
38 |
39 | src_inputs = tf.nn.bias_add(src_inputs, src_bias)
40 | tgt_inputs = tf.nn.bias_add(tgt_inputs, tgt_bias)
41 |
42 | if params.dropout and not params.use_variational_dropout:
43 | src_inputs = tf.nn.dropout(src_inputs, 1.0 - params.dropout)
44 | tgt_inputs = tf.nn.dropout(tgt_inputs, 1.0 - params.dropout)
45 |
46 | cell_enc = []
47 | cell_dec = []
48 |
49 | for _ in range(params.num_hidden_layers):
50 | if params.rnn_cell == "LSTMCell":
51 | cell_e = tf.nn.rnn_cell.BasicLSTMCell(params.hidden_size)
52 | cell_d = tf.nn.rnn_cell.BasicLSTMCell(params.hidden_size)
53 | elif params.rnn_cell == "GRUCell":
54 | cell_e = tf.nn.rnn_cell.GRUCell(params.hidden_size)
55 | cell_d = tf.nn.rnn_cell.GRUCell(params.hidden_size)
56 | else:
57 | raise ValueError("%s not supported" % params.rnn_cell)
58 |
59 | cell_e = tf.nn.rnn_cell.DropoutWrapper(
60 | cell_e,
61 | output_keep_prob=1.0 - params.dropout,
62 | variational_recurrent=params.use_variational_dropout,
63 | input_size=params.embedding_size,
64 | dtype=tf.float32
65 | )
66 | cell_d = tf.nn.rnn_cell.DropoutWrapper(
67 | cell_d,
68 | output_keep_prob=1.0 - params.dropout,
69 | variational_recurrent=params.use_variational_dropout,
70 | input_size=params.embedding_size,
71 | dtype=tf.float32
72 | )
73 |
74 | if params.use_residual:
75 | cell_e = tf.nn.rnn_cell.ResidualWrapper(cell_e)
76 | cell_d = tf.nn.rnn_cell.ResidualWrapper(cell_d)
77 |
78 | cell_enc.append(cell_e)
79 | cell_dec.append(cell_d)
80 |
81 | cell_enc = tf.nn.rnn_cell.MultiRNNCell(cell_enc)
82 | cell_dec = tf.nn.rnn_cell.MultiRNNCell(cell_dec)
83 |
84 | with tf.variable_scope("encoder"):
85 | _, final_state = tf.nn.dynamic_rnn(cell_enc, src_inputs,
86 | features["source_length"],
87 | dtype=tf.float32)
88 | # Shift left
89 | shifted_tgt_inputs = tf.pad(tgt_inputs, [[0, 0], [1, 0], [0, 0]])
90 | shifted_tgt_inputs = shifted_tgt_inputs[:, :-1, :]
91 |
92 | with tf.variable_scope("decoder"):
93 | outputs, _ = tf.nn.dynamic_rnn(cell_dec, shifted_tgt_inputs,
94 | features["target_length"],
95 | initial_state=final_state)
96 |
97 | if params.dropout:
98 | outputs = tf.nn.dropout(outputs, 1.0 - params.dropout)
99 |
100 | if mode == "infer":
101 | # Prediction
102 | logits = layers.nn.linear(outputs[:, -1, :], tgt_vocab_size, True,
103 | scope="softmax")
104 |
105 | return tf.nn.log_softmax(logits)
106 |
107 | # Prediction
108 | logits = layers.nn.linear(outputs, tgt_vocab_size, True, scope="softmax")
109 | logits = tf.reshape(logits, [-1, tgt_vocab_size])
110 | labels = features["target"]
111 |
112 | ce = layers.nn.smoothed_softmax_cross_entropy_with_logits(
113 | logits=logits,
114 | labels=labels,
115 | smoothing=params.label_smoothing,
116 | normalize=True
117 | )
118 |
119 | ce = tf.reshape(ce, tf.shape(labels))
120 | tgt_mask = tf.to_float(
121 | tf.sequence_mask(
122 | features["target_length"],
123 | maxlen=tf.shape(features["target"])[1]
124 | )
125 | )
126 |
127 | if mode == "eval":
128 | return -tf.reduce_sum(ce * tgt_mask, axis=1)
129 |
130 | loss = tf.reduce_sum(ce * tgt_mask) / tf.reduce_sum(tgt_mask)
131 |
132 | return loss
133 |
134 |
135 | class Seq2Seq(interface.NMTModel):
136 |
137 | def __init__(self, params, scope="seq2seq"):
138 | super(Seq2Seq, self).__init__(params=params, scope=scope)
139 |
140 | def get_training_func(self, initializer):
141 | def training_fn(features, params=None, reuse=None):
142 | if params is None:
143 | params = self.parameters
144 | with tf.variable_scope(self._scope, initializer=initializer,
145 | reuse=reuse):
146 | loss = model_graph(features, "train", params)
147 | return loss
148 |
149 | return training_fn
150 |
151 | def get_evaluation_func(self):
152 | def evaluation_fn(features, params=None):
153 | if params is None:
154 | params = copy.copy(self.parameters)
155 | else:
156 | params = copy.copy(params)
157 | params.dropout = 0.0
158 | params.use_variational_dropout = False
159 | params.label_smoothing = 0.0
160 |
161 | with tf.variable_scope(self._scope):
162 | score = model_graph(features, "eval", params)
163 |
164 | return score
165 |
166 | return evaluation_fn
167 |
168 | def get_inference_func(self):
169 | def inference_fn(features, params=None):
170 | if params is None:
171 | params = copy.copy(self.parameters)
172 | else:
173 | params = copy.copy(params)
174 | params.dropout = 0.0
175 | params.use_variational_dropout = False
176 | params.label_smoothing = 0.0
177 |
178 | with tf.variable_scope(self._scope):
179 | logits = model_graph(features, "infer", params)
180 |
181 | return logits
182 |
183 | return inference_fn
184 |
185 | @staticmethod
186 | def get_name():
187 | return "seq2seq"
188 |
189 | @staticmethod
190 | def get_parameters():
191 | params = tf.contrib.training.HParams(
192 | # vocabulary
193 | pad="<pad>",
194 | bos="<bos>",
195 | eos="<eos>",
196 | unk="<unk>",
197 | append_eos=False,
198 | # model
199 | rnn_cell="LSTMCell",
200 | embedding_size=1000,
201 | hidden_size=1000,
202 | num_hidden_layers=4,
203 | # regularization
204 | dropout=0.2,
205 | use_variational_dropout=False,
206 | label_smoothing=0.1,
207 | constant_batch_size=True,
208 | batch_size=128,
209 | max_length=80,
210 | reverse_source=True,
211 | use_residual=True,
212 | clip_grad_norm=5.0
213 | )
214 |
215 | return params
216 |
--------------------------------------------------------------------------------
/thumt/scripts/build_vocab.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2018 The THUMT Authors
4 |
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import argparse
10 | import collections
11 |
12 |
13 | def count_words(filename):
14 | counter = collections.Counter()
15 |
16 | with open(filename, "r") as fd:
17 | for line in fd:
18 | words = line.strip().split()
19 | counter.update(words)
20 |
21 | count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
22 | words, counts = list(zip(*count_pairs))
23 |
24 | return words, counts
25 |
26 |
27 | def control_symbols(string):
28 | if not string:
29 | return []
30 | else:
31 | return string.strip().split(",")
32 |
33 |
34 | def save_vocab(name, vocab):
35 | if name.split(".")[-1] != "txt":
36 | name = name + ".txt"
37 |
38 | pairs = sorted(vocab.items(), key=lambda x: (x[1], x[0]))
39 | words, ids = list(zip(*pairs))
40 |
41 | with open(name, "w") as f:
42 | for word in words:
43 | f.write(word + "\n")
44 |
45 |
46 | def parse_args():
47 | parser = argparse.ArgumentParser(description="Create vocabulary")
48 |
49 | parser.add_argument("corpus", help="input corpus")
50 | parser.add_argument("output", default="vocab.txt",
51 | help="Output vocabulary name")
52 | parser.add_argument("--limit", default=0, type=int, help="Vocabulary size")
53 | parser.add_argument("--control", type=str, default=",,",
54 | help="Add control symbols to vocabulary. "
55 | "Control symbols are separated by comma.")
56 |
57 | return parser.parse_args()
58 |
59 |
60 | def main(args):
61 | vocab = {}
62 | limit = args.limit
63 | count = 0
64 |
65 | words, counts = count_words(args.corpus)
66 | ctrl_symbols = control_symbols(args.control)
67 |
68 | for sym in ctrl_symbols:
69 | vocab[sym] = len(vocab)
70 |
71 | for word, freq in zip(words, counts):
72 | if limit and len(vocab) >= limit:
73 | break
74 |
75 | if word in vocab:
76 | print("Warning: found duplicate token %s, ignored" % word)
77 | continue
78 |
79 | vocab[word] = len(vocab)
80 | count += freq
81 |
82 | save_vocab(args.output, vocab)
83 |
84 | print("Total words: %d" % sum(counts))
85 | print("Unique words: %d" % len(words))
86 | print("Vocabulary coverage: %4.2f%%" % (100.0 * count / sum(counts)))
87 |
88 |
89 | if __name__ == "__main__":
90 | main(parse_args())
91 |
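92 | # Example invocation (a sketch; file names are hypothetical):
93 | #   python build_vocab.py corpus.en vocab.en.txt --limit 30000 \
94 | #       --control "<pad>,<eos>,<unk>"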
--------------------------------------------------------------------------------
/thumt/scripts/change.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2018 The THUMT Authors
4 |
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import argparse
10 | import operator
11 | import os
12 |
13 | import numpy as np
14 | import tensorflow as tf
15 |
16 |
17 | def parseargs():
18 |     msg = "Convert a checkpoint by renaming variable scopes"
19 |     usage = "change.py [<args>] [-h | --help]"
20 | parser = argparse.ArgumentParser(description=msg, usage=usage)
21 |
22 | parser.add_argument("--model", type=str, required=True,
23 | help="checkpoint dir")
24 | parser.add_argument("--output", type=str, help="output path")
25 |
26 | return parser.parse_args()
27 |
28 | def main(_):
29 | tf.logging.set_verbosity(tf.logging.INFO)
30 |
31 | var_list = tf.contrib.framework.list_variables(FLAGS.model)
32 | var_values, var_dtypes = {}, {}
33 | model_from = "transformer"
34 | model_to = "contextual_transformer"
35 |
36 | for (name, shape) in var_list:
37 |         # Keep every variable; the global_step/Adam name filter is disabled
38 |         name = name.replace(model_from, model_to)
39 |         var_values[name] = np.zeros(shape)
40 |         print(name)
41 |
42 | reader = tf.contrib.framework.load_checkpoint(FLAGS.model)
43 | for name in var_values:
44 | name_ori = name.replace(model_to, model_from)
45 | tensor = reader.get_tensor(name_ori)
46 | var_dtypes[name] = tensor.dtype
47 | var_values[name] += tensor
48 | tf.logging.info("Read from %s", FLAGS.model)
49 |
50 | tf_vars = [
51 | tf.get_variable(name, shape=var_values[name].shape,
52 | dtype=var_dtypes[name]) for name in var_values
53 | ]
54 | placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
55 | assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
56 | global_step = tf.Variable(0, name="global_step", trainable=False,
57 | dtype=tf.int64)
58 | saver = tf.train.Saver(tf.global_variables())
59 |
60 | with tf.Session() as sess:
61 | sess.run(tf.global_variables_initializer())
62 | for p, assign_op, (name, value) in zip(placeholders, assign_ops,
63 |                                                var_values.items()):
64 | sess.run(assign_op, {p: value})
65 | saved_name = os.path.join(FLAGS.output, "new")
66 | saver.save(sess, saved_name, global_step=global_step)
67 |
68 |     tf.logging.info("Converted checkpoint saved in %s", saved_name)
69 |
70 | if __name__ == "__main__":
71 | FLAGS = parseargs()
72 | tf.app.run()
73 |
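74 | # Example invocation (a sketch; paths are hypothetical). This rewrites variable
75 | # names from the "transformer" scope to "contextual_transformer":
76 | #   python change.py --model model_baseline --output model_ctx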
--------------------------------------------------------------------------------
/thumt/scripts/check_param.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2018 The THUMT Authors
4 |
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import argparse
10 | import operator
11 | import os
12 |
13 | import numpy as np
14 | import tensorflow as tf
15 |
16 |
17 | def parseargs():
18 |     msg = "List the variables stored in a checkpoint"
19 |     usage = "check_param.py [<args>] [-h | --help]"
20 | parser = argparse.ArgumentParser(description=msg, usage=usage)
21 |
22 | parser.add_argument("--model", type=str, required=True,
23 | help="checkpoint dir")
24 |
25 | return parser.parse_args()
26 |
27 | def main(_):
28 | tf.logging.set_verbosity(tf.logging.INFO)
29 |
30 | var_list = tf.contrib.framework.list_variables(FLAGS.model)
31 | var_values, var_dtypes = {}, {}
32 | model_from = "transformer_cov"
33 | model_to = "transformer_lrp"
34 |
35 | count = 0
36 | for (name, shape) in var_list:
37 |         # Keep every variable; the global_step/Adam name filter is disabled
38 |         count += 1
39 |         print(name, shape)
40 |         name = name.replace(model_from, model_to)
41 |         var_values[name] = np.zeros(shape)
42 | print(len(var_list))
43 | print(count)
44 |
45 |
46 | if __name__ == "__main__":
47 | FLAGS = parseargs()
48 | tf.app.run()
49 |
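50 | # Example invocation (a sketch; the path is hypothetical):
51 | #   python check_param.py --model model_baseline
52 | # Prints every variable name and shape in the checkpoint, plus the counts.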
--------------------------------------------------------------------------------
/thumt/scripts/checkpoint_averaging.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2018 The THUMT Authors
4 |
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import argparse
10 | import operator
11 | import os
12 |
13 | import numpy as np
14 | import tensorflow as tf
15 |
16 |
17 | def parseargs():
18 | msg = "Average checkpoints"
19 |     usage = "checkpoint_averaging.py [<args>] [-h | --help]"
20 | parser = argparse.ArgumentParser(description=msg, usage=usage)
21 |
22 | parser.add_argument("--path", type=str, required=True,
23 | help="checkpoint dir")
24 | parser.add_argument("--checkpoints", type=int, required=True,
25 | help="number of checkpoints to use")
26 | parser.add_argument("--output", type=str, help="output path")
27 |
28 | return parser.parse_args()
29 |
30 |
31 | def get_checkpoints(path):
32 | if not tf.gfile.Exists(os.path.join(path, "checkpoint")):
33 | raise ValueError("Cannot find checkpoints in %s" % path)
34 |
35 | checkpoint_names = []
36 |
37 | with tf.gfile.GFile(os.path.join(path, "checkpoint")) as fd:
38 | # Skip the first line
39 | fd.readline()
40 | for line in fd:
41 | name = line.strip().split(":")[-1].strip()[1:-1]
42 | key = int(name.split("-")[-1])
43 | checkpoint_names.append((key, os.path.join(path, name)))
44 |
45 | sorted_names = sorted(checkpoint_names, key=operator.itemgetter(0),
46 | reverse=True)
47 |
48 | return [item[-1] for item in sorted_names]
49 |
50 |
51 | def checkpoint_exists(path):
52 | return (tf.gfile.Exists(path) or tf.gfile.Exists(path + ".meta") or
53 | tf.gfile.Exists(path + ".index"))
54 |
55 |
56 | def main(_):
57 | tf.logging.set_verbosity(tf.logging.INFO)
58 | checkpoints = get_checkpoints(FLAGS.path)
59 | checkpoints = checkpoints[:FLAGS.checkpoints]
60 |
61 | if not checkpoints:
62 | raise ValueError("No checkpoints provided for averaging.")
63 |
64 | checkpoints = [c for c in checkpoints if checkpoint_exists(c)]
65 |
66 | if not checkpoints:
67 | raise ValueError(
68 |             "None of the provided checkpoints exist in %s" % FLAGS.path
69 | )
70 |
71 | var_list = tf.contrib.framework.list_variables(checkpoints[0])
72 | var_values, var_dtypes = {}, {}
73 |
74 | for (name, shape) in var_list:
75 | if not name.startswith("global_step"):
76 | var_values[name] = np.zeros(shape)
77 |
78 | for checkpoint in checkpoints:
79 | reader = tf.contrib.framework.load_checkpoint(checkpoint)
80 | for name in var_values:
81 | tensor = reader.get_tensor(name)
82 | var_dtypes[name] = tensor.dtype
83 | var_values[name] += tensor
84 | tf.logging.info("Read from checkpoint %s", checkpoint)
85 |
86 | # Average checkpoints
87 | for name in var_values:
88 | var_values[name] /= len(checkpoints)
89 |
90 | tf_vars = [
91 | tf.get_variable(name, shape=var_values[name].shape,
92 | dtype=var_dtypes[name]) for name in var_values
93 | ]
94 | placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
95 | assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
96 | global_step = tf.Variable(0, name="global_step", trainable=False,
97 | dtype=tf.int64)
98 | saver = tf.train.Saver(tf.global_variables())
99 |
100 | with tf.Session() as sess:
101 | sess.run(tf.global_variables_initializer())
102 | for p, assign_op, (name, value) in zip(placeholders, assign_ops,
103 |                                                var_values.items()):
104 | sess.run(assign_op, {p: value})
105 | saved_name = os.path.join(FLAGS.output, "average")
106 | saver.save(sess, saved_name, global_step=global_step)
107 |
108 | tf.logging.info("Averaged checkpoints saved in %s", saved_name)
109 |
110 | params_pattern = os.path.join(FLAGS.path, "*.json")
111 | params_files = tf.gfile.Glob(params_pattern)
112 |
113 | for name in params_files:
114 | new_name = name.replace(FLAGS.path.rstrip("/"),
115 | FLAGS.output.rstrip("/"))
116 | tf.gfile.Copy(name, new_name, overwrite=True)
117 |
118 |
119 | if __name__ == "__main__":
120 | FLAGS = parseargs()
121 | tf.app.run()
122 |
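123 | # Example invocation (a sketch; paths are hypothetical): average the 5 newest
124 | # checkpoints under train/ and write the result to average/:
125 | #   python checkpoint_averaging.py --path train --checkpoints 5 --output average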
--------------------------------------------------------------------------------
/thumt/scripts/combine.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2018 The THUMT Authors
4 |
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import argparse
10 | import operator
11 | import os
12 |
13 | import numpy as np
14 | import tensorflow as tf
15 |
16 |
17 | def parseargs():
18 |     msg = "Combine two checkpoints"
19 |     usage = "combine.py [<args>] [-h | --help]"
20 | parser = argparse.ArgumentParser(description=msg, usage=usage)
21 |
22 | parser.add_argument("--model", type=str, required=True,
23 | help="checkpoint dir")
24 | parser.add_argument("--part", type=str, required=True,
25 | help="partial model dir")
26 | parser.add_argument("--output", type=str, help="output path")
27 |
28 | return parser.parse_args()
29 |
30 | def main(_):
31 | tf.logging.set_verbosity(tf.logging.INFO)
32 |
33 | var_list = tf.contrib.framework.list_variables(FLAGS.model)
34 | var_part = tf.contrib.framework.list_variables(FLAGS.part)
35 | var_values, var_dtypes = {}, {}
36 | var_values_part = {}
37 |
38 | for (name, shape) in var_list:
39 |         # Keep every variable; the global_step/Adam name filter is disabled
40 |         var_values[name] = np.zeros(shape)
41 | for (name, shape) in var_part:
42 | var_values_part[name] = np.zeros(shape)
43 |
44 | reader = tf.contrib.framework.load_checkpoint(FLAGS.model)
45 | reader_part = tf.contrib.framework.load_checkpoint(FLAGS.part)
46 | for name in var_values:
47 | if name in var_values_part:
48 | tensor = reader_part.get_tensor(name)
49 | var_dtypes[name] = tensor.dtype
50 | var_values[name] += tensor
51 | print(name+' in part')
52 | else:
53 | tensor = reader.get_tensor(name)
54 | var_dtypes[name] = tensor.dtype
55 | var_values[name] += tensor
56 |             print(name + ' from model')
57 | tf.logging.info("Read from %s and %s", FLAGS.model, FLAGS.part)
58 |
59 | tf_vars = [
60 | tf.get_variable(name, shape=var_values[name].shape,
61 | dtype=var_dtypes[name]) for name in var_values
62 | ]
63 | placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
64 | assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
65 | global_step = tf.Variable(0, name="global_step", trainable=False,
66 | dtype=tf.int64)
67 | saver = tf.train.Saver(tf.global_variables())
68 |
69 | with tf.Session() as sess:
70 | sess.run(tf.global_variables_initializer())
71 | for p, assign_op, (name, value) in zip(placeholders, assign_ops,
72 |                                                var_values.items()):
73 | sess.run(assign_op, {p: value})
74 | saved_name = os.path.join(FLAGS.output, "new")
75 | saver.save(sess, saved_name, global_step=global_step)
76 |
77 |     tf.logging.info("Combined checkpoint saved in %s", saved_name)
78 |
79 | if __name__ == "__main__":
80 | FLAGS = parseargs()
81 | tf.app.run()
82 |
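83 | # Example invocation (a sketch; paths are hypothetical). Variables present in
84 | # --part are read from the partial checkpoint; all others come from --model:
85 | #   python combine.py --model model_full --part model_part --output model_out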
--------------------------------------------------------------------------------
/thumt/scripts/combine_add.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2018 The THUMT Authors
4 |
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import argparse
10 | import operator
11 | import os
12 |
13 | import numpy as np
14 | import tensorflow as tf
15 |
16 |
17 | def parseargs():
18 |     msg = "Combine two checkpoints"
19 |     usage = "combine_add.py [<args>] [-h | --help]"
20 | parser = argparse.ArgumentParser(description=msg, usage=usage)
21 |
22 | parser.add_argument("--model", type=str, required=True,
23 | help="checkpoint dir")
24 | parser.add_argument("--part", type=str, required=True,
25 | help="partial model dir")
26 | parser.add_argument("--output", type=str, help="output path")
27 |
28 | return parser.parse_args()
29 |
30 | def main(_):
31 | tf.logging.set_verbosity(tf.logging.INFO)
32 |
33 | var_list = tf.contrib.framework.list_variables(FLAGS.model)
34 | var_part = tf.contrib.framework.list_variables(FLAGS.part)
35 | var_values, var_dtypes = {}, {}
36 | var_values_part = {}
37 |
38 | for (name, shape) in var_list:
39 |         # Keep every variable; the global_step/Adam name filter is disabled
40 |         var_values[name] = np.zeros(shape)
41 | for (name, shape) in var_part:
42 | var_values[name] = np.zeros(shape)
43 | var_values_part[name] = np.zeros(shape)
44 |
45 | reader = tf.contrib.framework.load_checkpoint(FLAGS.model)
46 | reader_part = tf.contrib.framework.load_checkpoint(FLAGS.part)
47 | for name in var_values:
48 | if name in var_values_part:
49 | tensor = reader_part.get_tensor(name)
50 | var_dtypes[name] = tensor.dtype
51 | var_values[name] += tensor
52 | print(name+' in part')
53 | else:
54 | tensor = reader.get_tensor(name)
55 | var_dtypes[name] = tensor.dtype
56 | var_values[name] += tensor
57 |             print(name + ' from model')
58 | tf.logging.info("Read from %s and %s", FLAGS.model, FLAGS.part)
59 |
60 | tf_vars = [
61 | tf.get_variable(name, shape=var_values[name].shape,
62 | dtype=var_dtypes[name]) for name in var_values
63 | ]
64 | placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
65 | assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
66 | global_step = tf.Variable(0, name="global_step", trainable=False,
67 | dtype=tf.int64)
68 | saver = tf.train.Saver(tf.global_variables())
69 |
70 | with tf.Session() as sess:
71 | sess.run(tf.global_variables_initializer())
72 | for p, assign_op, (name, value) in zip(placeholders, assign_ops,
73 |                                                var_values.items()):
74 | sess.run(assign_op, {p: value})
75 | saved_name = os.path.join(FLAGS.output, "new")
76 | saver.save(sess, saved_name, global_step=global_step)
77 |
78 |     tf.logging.info("Combined checkpoint saved in %s", saved_name)
79 |
80 | if __name__ == "__main__":
81 | FLAGS = parseargs()
82 | tf.app.run()
83 |
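84 | # Example invocation (a sketch; paths are hypothetical). Unlike combine.py,
85 | # variables that exist only in --part are also added to the output:
86 | #   python combine_add.py --model model_full --part model_part --output model_out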
--------------------------------------------------------------------------------
/thumt/scripts/compare.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2018 The THUMT Authors
4 |
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import argparse
10 | import operator
11 | import os
12 |
13 | import numpy as np
14 | import tensorflow as tf
15 |
16 |
17 | def parseargs():
18 |     msg = "Compare the variables of two checkpoints"
19 |     usage = "compare.py [<args>] [-h | --help]"
20 | parser = argparse.ArgumentParser(description=msg, usage=usage)
21 |
22 | parser.add_argument("--model", type=str, required=True,
23 | help="checkpoint dir")
24 | parser.add_argument("--part", type=str, required=True,
25 | help="partial model dir")
26 |
27 | return parser.parse_args()
28 |
29 | def main(_):
30 | tf.logging.set_verbosity(tf.logging.INFO)
31 |
32 | var_list = tf.contrib.framework.list_variables(FLAGS.model)
33 | var_part = tf.contrib.framework.list_variables(FLAGS.part)
34 | var_values, var_dtypes = {}, {}
35 | var_values_part = {}
36 |
37 | for (name, shape) in var_list:
38 | var_values[name] = np.zeros(shape)
39 | for (name, shape) in var_part:
40 | var_values_part[name] = np.zeros(shape)
41 |
42 | reader = tf.contrib.framework.load_checkpoint(FLAGS.model)
43 | reader_part = tf.contrib.framework.load_checkpoint(FLAGS.part)
44 | for name in var_values:
45 | if name in var_values_part:
46 | tensor_part = reader_part.get_tensor(name)
47 | tensor = reader.get_tensor(name)
48 |             # Compare the stored tensors element-wise
49 |             if np.array_equal(tensor, tensor_part):
50 |                 print('name ' + name + ' equals')
51 |             else:
52 |                 print('name ' + name + ' is different')
53 | tf.logging.info("Read from %s and %s", FLAGS.model, FLAGS.part)
54 |
55 | if __name__ == "__main__":
56 | FLAGS = parseargs()
57 | tf.app.run()
58 |
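59 | # Example invocation (a sketch; paths are hypothetical):
60 | #   python compare.py --model ckpt_a --part ckpt_b
61 | # Reports, for each shared variable, whether the stored tensors are equal.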
--------------------------------------------------------------------------------
/thumt/scripts/convert_old_model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2018 The THUMT Authors
4 |
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import argparse
10 |
11 | import numpy as np
12 | import tensorflow as tf
13 |
14 |
15 | def parseargs():
16 | parser = argparse.ArgumentParser(description="Convert old models")
17 |
18 | parser.add_argument("--input", type=str, required=True,
19 | help="Path of old model")
20 | parser.add_argument("--output", type=str, required=True,
21 | help="Path of output checkpoint")
22 |
23 | return parser.parse_args()
24 |
25 |
26 | def old_keys():
27 | keys = [
28 | "GRU_dec_attcontext",
29 | "GRU_dec_att",
30 | "GRU_dec_atthidden",
31 | "GRU_dec_inputoffset",
32 | "GRU_dec_inputemb",
33 | "GRU_dec_inputcontext",
34 | "GRU_dec_inputhidden",
35 | "GRU_dec_resetemb",
36 | "GRU_dec_resetcontext",
37 | "GRU_dec_resethidden",
38 | "GRU_dec_gateemb",
39 | "GRU_dec_gatecontext",
40 | "GRU_dec_gatehidden",
41 | "initer_b",
42 | "initer_W",
43 | "GRU_dec_probsemb",
44 | "GRU_enc_back_inputoffset",
45 | "GRU_enc_back_inputemb",
46 | "GRU_enc_back_inputhidden",
47 | "GRU_enc_back_resetemb",
48 | "GRU_enc_back_resethidden",
49 | "GRU_enc_back_gateemb",
50 | "GRU_enc_back_gatehidden",
51 | "GRU_enc_inputoffset",
52 | "GRU_enc_inputemb",
53 | "GRU_enc_inputhidden",
54 | "GRU_enc_resetemb",
55 | "GRU_enc_resethidden",
56 | "GRU_enc_gateemb",
57 | "GRU_enc_gatehidden",
58 | "GRU_dec_readoutoffset",
59 | "GRU_dec_readoutemb",
60 | "GRU_dec_readouthidden",
61 | "GRU_dec_readoutcontext",
62 | "GRU_dec_probsoffset",
63 | "GRU_dec_probs",
64 | "emb_src_b",
65 | "emb_src_emb",
66 | "emb_trg_b",
67 | "emb_trg_emb"
68 | ]
69 |
70 | return keys
71 |
72 |
73 | def new_keys():
74 | keys = [
75 | "rnnsearch/decoder/attention/k_transform/matrix_0",
76 | "rnnsearch/decoder/attention/logits/matrix_0",
77 | "rnnsearch/decoder/attention/q_transform/matrix_0",
78 | "rnnsearch/decoder/gru_cell/candidate/bias",
79 | "rnnsearch/decoder/gru_cell/candidate/matrix_0",
80 | "rnnsearch/decoder/gru_cell/candidate/matrix_1",
81 | "rnnsearch/decoder/gru_cell/candidate/matrix_2",
82 | "rnnsearch/decoder/gru_cell/reset_gate/matrix_0",
83 | "rnnsearch/decoder/gru_cell/reset_gate/matrix_1",
84 | "rnnsearch/decoder/gru_cell/reset_gate/matrix_2",
85 | "rnnsearch/decoder/gru_cell/update_gate/matrix_0",
86 | "rnnsearch/decoder/gru_cell/update_gate/matrix_1",
87 | "rnnsearch/decoder/gru_cell/update_gate/matrix_2",
88 | "rnnsearch/decoder/s_transform/bias",
89 | "rnnsearch/decoder/s_transform/matrix_0",
90 | "rnnsearch/deepout/matrix_0",
91 | "rnnsearch/encoder/backward/gru_cell/candidate/bias",
92 | "rnnsearch/encoder/backward/gru_cell/candidate/matrix_0",
93 | "rnnsearch/encoder/backward/gru_cell/candidate/matrix_1",
94 | "rnnsearch/encoder/backward/gru_cell/reset_gate/matrix_0",
95 | "rnnsearch/encoder/backward/gru_cell/reset_gate/matrix_1",
96 | "rnnsearch/encoder/backward/gru_cell/update_gate/matrix_0",
97 | "rnnsearch/encoder/backward/gru_cell/update_gate/matrix_1",
98 | "rnnsearch/encoder/forward/gru_cell/candidate/bias",
99 | "rnnsearch/encoder/forward/gru_cell/candidate/matrix_0",
100 | "rnnsearch/encoder/forward/gru_cell/candidate/matrix_1",
101 | "rnnsearch/encoder/forward/gru_cell/reset_gate/matrix_0",
102 | "rnnsearch/encoder/forward/gru_cell/reset_gate/matrix_1",
103 | "rnnsearch/encoder/forward/gru_cell/update_gate/matrix_0",
104 | "rnnsearch/encoder/forward/gru_cell/update_gate/matrix_1",
105 | "rnnsearch/maxout/bias",
106 | "rnnsearch/maxout/matrix_0",
107 | "rnnsearch/maxout/matrix_1",
108 | "rnnsearch/maxout/matrix_2",
109 | "rnnsearch/softmax/bias",
110 | "rnnsearch/softmax/matrix_0",
111 | "rnnsearch/source_embedding/bias",
112 | "rnnsearch/source_embedding/embedding",
113 | "rnnsearch/target_embedding/bias",
114 | "rnnsearch/target_embedding/embedding",
115 | ]
116 |
117 | return keys
118 |
119 |
120 | def main(args):
121 | values = dict(np.load(args.input))
122 | variables = {}
123 | o_keys = old_keys()
124 | n_keys = new_keys()
125 |
126 | for i, key in enumerate(o_keys):
127 | v = values[key]
128 | variables[n_keys[i]] = v
129 |
130 | with tf.Graph().as_default():
131 | with tf.device("/cpu:0"):
132 | tf_vars = [
133 | tf.get_variable(v, initializer=variables[v], dtype=tf.float32)
134 | for v in variables
135 | ]
136 | global_step = tf.Variable(0, name="global_step", trainable=False,
137 | dtype=tf.int64)
138 |
139 | saver = tf.train.Saver(tf_vars)
140 |
141 | with tf.Session() as sess:
142 | sess.run(tf.global_variables_initializer())
143 | saver.save(sess, args.output, global_step=global_step)
144 |
145 |
146 | if __name__ == "__main__":
147 | main(parseargs())
148 |
--------------------------------------------------------------------------------
/thumt/scripts/convert_vocab.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2018 The THUMT Authors
4 |
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | from six.moves import cPickle
10 | import sys
11 |
12 | if __name__ == "__main__":
13 |     with open(sys.argv[1], "rb") as fd:
14 | voc = cPickle.load(fd)
15 |
16 | ivoc = {}
17 |
18 | for key in voc:
19 | ivoc[voc[key]] = key
20 |
21 | with open(sys.argv[2], "w") as fd:
22 |         for key in sorted(ivoc):
23 | val = ivoc[key]
24 | fd.write(val + "\n")
25 |
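26 | # Example invocation (a sketch; file names are hypothetical): convert a pickled
27 | # word-to-id dict into a plain-text vocabulary, one word per line in id order:
28 | #   python convert_vocab.py vocab.pkl vocab.txt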
--------------------------------------------------------------------------------
/thumt/scripts/input_converter.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import argparse
9 | import os
10 | import random
11 | import six
12 |
13 | import tensorflow as tf
14 |
15 |
16 | def load_vocab(filename):
17 | with tf.gfile.Open(filename) as fd:
18 | count = 0
19 | vocab = {}
20 | for line in fd:
21 | word = line.strip()
22 | vocab[word] = count
23 | count += 1
24 |
25 | return vocab
26 |
27 |
28 | def to_example(dictionary):
29 | """ Convert python dictionary to tf.train.Example """
30 | features = {}
31 |
32 | for (k, v) in six.iteritems(dictionary):
33 | if not v:
34 |             raise ValueError("Empty generated field: %s" % str((k, v)))
35 |
36 | if isinstance(v[0], six.integer_types):
37 | int64_list = tf.train.Int64List(value=v)
38 | features[k] = tf.train.Feature(int64_list=int64_list)
39 | elif isinstance(v[0], float):
40 | float_list = tf.train.FloatList(value=v)
41 | features[k] = tf.train.Feature(float_list=float_list)
42 | elif isinstance(v[0], six.string_types):
43 | bytes_list = tf.train.BytesList(value=v)
44 | features[k] = tf.train.Feature(bytes_list=bytes_list)
45 | else:
46 | raise ValueError("Value is neither an int nor a float; "
47 | "v: %s type: %s" % (str(v[0]), str(type(v[0]))))
48 |
49 | return tf.train.Example(features=tf.train.Features(feature=features))
50 |
51 |
52 | def write_records(records, out_filename):
53 | """ Write to TensorFlow record """
54 | writer = tf.python_io.TFRecordWriter(out_filename)
55 |
56 | for count, record in enumerate(records):
57 | writer.write(record)
58 | if count % 10000 == 0:
59 | tf.logging.info("write: %d", count)
60 |
61 | writer.close()
62 |
63 |
64 | def convert_to_record(inputs, vocab, output_name, output_dir, num_shards,
65 | shuffle=False):
66 | """ Convert plain parallel text to TensorFlow record """
67 | source, target = inputs
68 | svocab, tvocab = vocab
69 | records = []
70 |
71 | with tf.gfile.Open(source) as src:
72 | with tf.gfile.Open(target) as tgt:
73 | for sline, tline in zip(src, tgt):
74 | sline = sline.strip().split()
75 | sline = [svocab[item] if item in svocab else svocab[FLAGS.unk]
76 | for item in sline] + [svocab[FLAGS.eos]]
77 | tline = tline.strip().split()
78 | tline = [tvocab[item] if item in tvocab else tvocab[FLAGS.unk]
79 | for item in tline] + [tvocab[FLAGS.eos]]
80 |
81 | feature = {
82 | "source": sline,
83 | "target": tline,
84 | "source_length": [len(sline)],
85 | "target_length": [len(tline)]
86 | }
87 | records.append(feature)
88 |
89 | output_files = []
90 | writers = []
91 |
92 |     for shard in range(num_shards):
93 | output_filename = "%s-%.5d-of-%.5d" % (output_name, shard, num_shards)
94 | output_file = os.path.join(output_dir, output_filename)
95 | output_files.append(output_file)
96 | writers.append(tf.python_io.TFRecordWriter(output_file))
97 |
98 | counter, shard = 0, 0
99 |
100 | if shuffle:
101 | random.shuffle(records)
102 |
103 | for record in records:
104 | counter += 1
105 | example = to_example(record)
106 | writers[shard].write(example.SerializeToString())
107 | shard = (shard + 1) % num_shards
108 |
109 | for writer in writers:
110 | writer.close()
111 |
112 |
113 | def parse_args():
114 |     msg = "Convert plain text inputs to the TFRecord format"
115 |     usage = "input_converter.py [<args>] [-h | --help]"
116 | parser = argparse.ArgumentParser(description=msg, usage=usage)
117 |
118 | parser.add_argument("--input", required=True, type=str, nargs=2,
119 | help="Path of input file")
120 | parser.add_argument("--output_name", required=True, type=str,
121 | help="Output name")
122 | parser.add_argument("--output_dir", required=True, type=str,
123 | help="Output directory")
124 | parser.add_argument("--vocab", nargs=2, required=True, type=str,
125 | help="Path of vocabulary")
126 | parser.add_argument("--num_shards", default=100, type=int,
127 | help="Number of output shards")
128 | parser.add_argument("--shuffle", action="store_true",
129 | help="Shuffle inputs")
130 |     parser.add_argument("--unk", default="<unk>", type=str,
131 |                         help="Unknown word symbol")
132 |     parser.add_argument("--eos", default="<eos>", type=str,
133 |                         help="End of sentence symbol")
134 |
135 | return parser.parse_args()
136 |
137 |
138 | def main(_):
139 | svocab = load_vocab(FLAGS.vocab[0])
140 | tvocab = load_vocab(FLAGS.vocab[1])
141 |
142 | # convert data
143 | convert_to_record(FLAGS.input, [svocab, tvocab], FLAGS.output_name,
144 | FLAGS.output_dir, FLAGS.num_shards, FLAGS.shuffle)
145 |
146 |
147 | if __name__ == "__main__":
148 | FLAGS = parse_args()
149 | tf.app.run()
150 |
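151 | # Example invocation (a sketch; paths are hypothetical):
152 | #   python input_converter.py --input corpus.src corpus.tgt \
153 | #       --output_name train --output_dir records \
154 | #       --vocab vocab.src.txt vocab.tgt.txt --num_shards 100 --shuffle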
--------------------------------------------------------------------------------
/thumt/scripts/shuffle_corpus.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import argparse
9 | import numpy
10 |
11 |
12 | def parseargs():
13 | parser = argparse.ArgumentParser(description="Shuffle corpus")
14 |
15 | parser.add_argument("--corpus", nargs="+", required=True,
16 | help="input corpora")
17 | parser.add_argument("--suffix", type=str, default="shuf",
18 | help="Suffix of output files")
19 | parser.add_argument("--seed", type=int, help="Random seed")
20 |
21 | return parser.parse_args()
22 |
23 |
24 | def main(args):
25 | name = args.corpus
26 | suffix = "." + args.suffix
27 | stream = [open(item, "r") for item in name]
28 | data = [fd.readlines() for fd in stream]
29 | minlen = min([len(lines) for lines in data])
30 |
31 | if args.seed:
32 | numpy.random.seed(args.seed)
33 |
34 | indices = numpy.arange(minlen)
35 | numpy.random.shuffle(indices)
36 |
37 | newstream = [open(item + suffix, "w") for item in name]
38 |
39 | for idx in indices.tolist():
40 | lines = [item[idx] for item in data]
41 |
42 | for line, fd in zip(lines, newstream):
43 | fd.write(line)
44 |
45 | for fdr, fdw in zip(stream, newstream):
46 | fdr.close()
47 | fdw.close()
48 |
49 |
50 | if __name__ == "__main__":
51 | parsed_args = parseargs()
52 | main(parsed_args)
53 |
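54 | # Example invocation (a sketch; file names are hypothetical). Shuffles parallel
55 | # files with one shared permutation, writing corpus.src.shuf / corpus.tgt.shuf:
56 | #   python shuffle_corpus.py --corpus corpus.src corpus.tgt --seed 42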
--------------------------------------------------------------------------------
/thumt/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The THUMT Authors
3 |
--------------------------------------------------------------------------------
/thumt/utils/bleu.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import math
9 |
10 | from collections import Counter
11 |
12 |
13 | def closest_length(candidate, references):
14 | clen = len(candidate)
15 | closest_diff = 9999
16 | closest_len = 9999
17 |
18 | for reference in references:
19 | rlen = len(reference)
20 | diff = abs(rlen - clen)
21 |
22 | if diff < closest_diff:
23 | closest_diff = diff
24 | closest_len = rlen
25 | elif diff == closest_diff:
26 | closest_len = rlen if rlen < closest_len else closest_len
27 |
28 | return closest_len
29 |
30 |
31 | def shortest_length(references):
32 | return min([len(ref) for ref in references])
33 |
34 |
35 | def modified_precision(candidate, references, n):
36 | tngrams = len(candidate) + 1 - n
37 | counts = Counter([tuple(candidate[i:i+n]) for i in range(tngrams)])
38 |
39 | if len(counts) == 0:
40 | return 0, 0
41 |
42 | max_counts = {}
43 | for reference in references:
44 | rngrams = len(reference) + 1 - n
45 | ngrams = [tuple(reference[i:i+n]) for i in range(rngrams)]
46 | ref_counts = Counter(ngrams)
47 | for ngram in counts:
48 | mcount = 0 if ngram not in max_counts else max_counts[ngram]
49 | rcount = 0 if ngram not in ref_counts else ref_counts[ngram]
50 | max_counts[ngram] = max(mcount, rcount)
51 |
52 | clipped_counts = {}
53 |
54 | for ngram, count in counts.items():
55 | clipped_counts[ngram] = min(count, max_counts[ngram])
56 |
57 | return float(sum(clipped_counts.values())), float(sum(counts.values()))
58 |
59 |
60 | def brevity_penalty(trans, refs, mode="closest"):
61 | bp_c = 0.0
62 | bp_r = 0.0
63 |
64 | for candidate, references in zip(trans, refs):
65 | bp_c += len(candidate)
66 |
67 | if mode == "shortest":
68 | bp_r += shortest_length(references)
69 | else:
70 | bp_r += closest_length(candidate, references)
71 |
72 | # Prevent zero divide
73 | bp_c = bp_c or 1.0
74 |
75 | return math.exp(min(0, 1.0 - bp_r / bp_c))
76 |
77 |
78 | def bleu(trans, refs, bp="closest", smooth=False, n=4, weights=None):
79 | p_norm = [0 for _ in range(n)]
80 | p_denorm = [0 for _ in range(n)]
81 |
82 | for candidate, references in zip(trans, refs):
83 | for i in range(n):
84 | ccount, tcount = modified_precision(candidate, references, i + 1)
85 | p_norm[i] += ccount
86 | p_denorm[i] += tcount
87 |
88 | bleu_n = [0 for _ in range(n)]
89 |
90 | for i in range(n):
91 | # add one smoothing
92 | if smooth and i > 0:
93 | p_norm[i] += 1
94 | p_denorm[i] += 1
95 |
96 | if p_norm[i] == 0 or p_denorm[i] == 0:
97 | bleu_n[i] = -9999
98 | else:
99 | bleu_n[i] = math.log(float(p_norm[i]) / float(p_denorm[i]))
100 |
101 | if weights:
102 | if len(weights) != n:
103 | raise ValueError("len(weights) != n: invalid weight number")
104 | log_precision = sum([bleu_n[i] * weights[i] for i in range(n)])
105 | else:
106 | log_precision = sum(bleu_n) / float(n)
107 |
108 | bp = brevity_penalty(trans, refs, bp)
109 |
110 | score = bp * math.exp(log_precision)
111 |
112 | return score
113 |
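114 | # A tiny worked example (sketch): an exact single-reference match scores 1.0,
115 | # since every n-gram precision is 1 and the brevity penalty is exp(0):
116 | #   bleu([["a", "b", "c", "d"]], [[["a", "b", "c", "d"]]])  # -> 1.0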
--------------------------------------------------------------------------------
/thumt/utils/hooks.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import datetime
9 | import operator
10 | import os
11 |
12 | import tensorflow as tf
13 | import thumt.utils.bleu as bleu
14 |
15 |
16 | def _get_saver():
17 | # Get saver from the SAVERS collection if present.
18 | collection_key = tf.GraphKeys.SAVERS
19 | savers = tf.get_collection(collection_key)
20 |
21 | if not savers:
22 |         raise RuntimeError("No items in collection {}. Please add a saver "
23 |                            "to the collection.".format(collection_key))
24 | elif len(savers) > 1:
25 | raise RuntimeError("More than one item in collection")
26 |
27 | return savers[0]
28 |
29 |
30 | def _save_log(filename, result):
31 | metric, global_step, score = result
32 |
33 | with open(filename, "a") as fd:
34 | time = datetime.datetime.now()
35 | msg = "%s: %s at step %d: %f\n" % (time, metric, global_step, score)
36 | fd.write(msg)
37 |
38 |
39 | def _read_checkpoint_def(filename):
40 | records = []
41 |
42 | with tf.gfile.GFile(filename) as fd:
43 | fd.readline()
44 |
45 | for line in fd:
46 | records.append(line.strip().split(":")[-1].strip()[1:-1])
47 |
48 | return records
49 |
50 |
51 | def _save_checkpoint_def(filename, checkpoint_names):
52 | keys = []
53 |
54 | for checkpoint_name in checkpoint_names:
55 | step = int(checkpoint_name.strip().split("-")[-1])
56 | keys.append((step, checkpoint_name))
57 |
58 | sorted_names = sorted(keys, key=operator.itemgetter(0),
59 | reverse=True)
60 |
61 | with tf.gfile.GFile(filename, "w") as fd:
62 | fd.write("model_checkpoint_path: \"%s\"\n" % checkpoint_names[0])
63 |
64 | for checkpoint_name in sorted_names:
65 | checkpoint_name = checkpoint_name[1]
66 | fd.write("all_model_checkpoint_paths: \"%s\"\n" % checkpoint_name)
67 |
68 |
69 | def _read_score_record(filename):
70 | # "checkpoint_name": score
71 | records = []
72 |
73 | if not tf.gfile.Exists(filename):
74 | return records
75 |
76 | with tf.gfile.GFile(filename) as fd:
77 | for line in fd:
78 | name, score = line.strip().split(":")
79 | name = name.strip()[1:-1]
80 | score = float(score)
81 | records.append([name, score])
82 |
83 | return records
84 |
85 |
86 | def _save_score_record(filename, records):
87 | keys = []
88 |
89 | for record in records:
90 | checkpoint_name = record[0]
91 | step = int(checkpoint_name.strip().split("-")[-1])
92 | keys.append((step, record))
93 |
94 | sorted_keys = sorted(keys, key=operator.itemgetter(0),
95 | reverse=True)
96 | sorted_records = [item[1] for item in sorted_keys]
97 |
98 | with tf.gfile.GFile(filename, "w") as fd:
99 | for record in sorted_records:
100 | checkpoint_name, score = record
101 | fd.write("\"%s\": %f\n" % (checkpoint_name, score))
102 |
103 |
104 | def _add_to_record(records, record, max_to_keep):
105 | added = None
106 | removed = None
107 | models = {}
108 |
109 | for (name, score) in records:
110 | models[name] = score
111 |
112 | if len(records) < max_to_keep:
113 | if record[0] not in models:
114 | added = record[0]
115 | records.append(record)
116 | else:
117 | sorted_records = sorted(records, key=lambda x: -x[1])
118 | worst_score = sorted_records[-1][1]
119 | current_score = record[1]
120 |
121 | if current_score >= worst_score:
122 | if record[0] not in models:
123 | added = record[0]
124 | removed = sorted_records[-1][0]
125 | records = sorted_records[:-1] + [record]
126 |
127 | # Sort
128 | records = sorted(records, key=lambda x: -x[1])
129 |
130 | return added, removed, records
131 |
132 |
133 | def _evaluate(eval_fn, input_fn, decode_fn, path, config):
134 | graph = tf.Graph()
135 | with graph.as_default():
136 | features = input_fn()
137 | refs = features["references"]
138 | placeholders = {
139 | "source": tf.placeholder(tf.int32, [None, None], "source"),
140 | "source_length": tf.placeholder(tf.int32, [None], "source_length")
141 | }
142 | predictions = eval_fn(placeholders)
143 | predictions = predictions[0][:, 0, :]
144 |
145 | all_refs = [[] for _ in range(len(refs))]
146 | all_outputs = []
147 |
148 | sess_creator = tf.train.ChiefSessionCreator(
149 | checkpoint_dir=path,
150 | config=config
151 | )
152 |
153 | with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
154 | while not sess.should_stop():
155 | feats = sess.run(features)
156 | outputs = sess.run(predictions, feed_dict={
157 | placeholders["source"]: feats["source"],
158 | placeholders["source_length"]: feats["source_length"]
159 | })
160 | # shape: [batch, len]
161 | outputs = outputs.tolist()
162 | # shape: ([batch, len], ..., [batch, len])
163 | references = [item.tolist() for item in feats["references"]]
164 |
165 | all_outputs.extend(outputs)
166 |
167 | for i in range(len(refs)):
168 | all_refs[i].extend(references[i])
169 |
170 | decoded_symbols = decode_fn(all_outputs)
171 | decoded_refs = [decode_fn(refs) for refs in all_refs]
172 | decoded_refs = [list(x) for x in zip(*decoded_refs)]
173 |
174 | return bleu.bleu(decoded_symbols, decoded_refs)
175 |
176 |
177 | class EvaluationHook(tf.train.SessionRunHook):
178 | """ Validate and save checkpoints every N steps or seconds.
179 | This hook only saves checkpoint according to a specific metric.
180 | """
181 |
182 | def __init__(self, eval_fn, eval_input_fn, eval_decode_fn, base_dir,
183 | session_config, max_to_keep=5, eval_secs=None,
184 | eval_steps=None, metric="BLEU"):
185 | """ Initializes a `EvaluationHook`.
186 | :param eval_fn: A function with signature (feature)
187 | :param eval_input_fn: A function with signature ()
188 | :param eval_decode_fn: A function with signature (inputs)
189 | :param base_dir: A string. Base directory for the checkpoint files.
190 | :param session_config: An instance of tf.ConfigProto
191 | :param max_to_keep: An integer. The maximum of checkpoints to save
192 | :param eval_secs: An integer, eval every N secs.
193 | :param eval_steps: An integer, eval every N steps.
194 |         :param metric: A string, the validation metric; only "BLEU" is
195 |             supported.
196 |         :raises ValueError: If `metric` is not "BLEU".
197 | """
198 | tf.logging.info("Create EvaluationHook.")
199 |
200 | if metric != "BLEU":
201 |             raise ValueError("Currently, EvaluationHook only supports BLEU")
202 |
203 | self._base_dir = base_dir.rstrip("/")
204 | self._session_config = session_config
205 | self._save_path = os.path.join(base_dir, "eval")
206 | self._record_name = os.path.join(self._save_path, "record")
207 | self._log_name = os.path.join(self._save_path, "log")
208 | self._eval_fn = eval_fn
209 | self._eval_input_fn = eval_input_fn
210 | self._eval_decode_fn = eval_decode_fn
211 | self._max_to_keep = max_to_keep
212 | self._metric = metric
213 | self._global_step = None
214 | self._timer = tf.train.SecondOrStepTimer(
215 | every_secs=eval_secs or None, every_steps=eval_steps or None
216 | )
217 |
218 | def begin(self):
219 | if self._timer.last_triggered_step() is None:
220 | self._timer.update_last_triggered_step(0)
221 |
222 | global_step = tf.train.get_global_step()
223 |
224 | if not tf.gfile.Exists(self._save_path):
225 | tf.logging.info("Making dir: %s" % self._save_path)
226 | tf.gfile.MakeDirs(self._save_path)
227 |
228 | params_pattern = os.path.join(self._base_dir, "*.json")
229 | params_files = tf.gfile.Glob(params_pattern)
230 |
231 | for name in params_files:
232 | new_name = name.replace(self._base_dir, self._save_path)
233 | tf.gfile.Copy(name, new_name, overwrite=True)
234 |
235 | if global_step is None:
236 | raise RuntimeError("Global step should be created first")
237 |
238 | self._global_step = global_step
239 |
240 | def before_run(self, run_context):
241 | args = tf.train.SessionRunArgs(self._global_step)
242 | return args
243 |
244 | def after_run(self, run_context, run_values):
245 | stale_global_step = run_values.results
246 |
247 | if self._timer.should_trigger_for_step(stale_global_step + 1):
248 | global_step = run_context.session.run(self._global_step)
249 |
250 | # Get the real value
251 | if self._timer.should_trigger_for_step(global_step):
252 | self._timer.update_last_triggered_step(global_step)
253 | # Save model
254 | save_path = os.path.join(self._base_dir, "model.ckpt")
255 | saver = _get_saver()
256 | tf.logging.info("Saving checkpoints for %d into %s." %
257 | (global_step, save_path))
258 | saver.save(run_context.session,
259 | save_path,
260 | global_step=global_step)
261 | # Do validation here
262 | tf.logging.info("Validating model at step %d" % global_step)
263 | score = _evaluate(self._eval_fn, self._eval_input_fn,
264 | self._eval_decode_fn,
265 | self._base_dir,
266 | self._session_config)
267 | tf.logging.info("%s at step %d: %f" %
268 | (self._metric, global_step, score))
269 |
270 | _save_log(self._log_name, (self._metric, global_step, score))
271 |
272 | checkpoint_filename = os.path.join(self._base_dir,
273 | "checkpoint")
274 | all_checkpoints = _read_checkpoint_def(checkpoint_filename)
275 | records = _read_score_record(self._record_name)
276 | latest_checkpoint = all_checkpoints[-1]
277 | record = [latest_checkpoint, score]
278 | added, removed, records = _add_to_record(records, record,
279 | self._max_to_keep)
280 |
281 | if added is not None:
282 | old_path = os.path.join(self._base_dir, added)
283 | new_path = os.path.join(self._save_path, added)
284 | old_files = tf.gfile.Glob(old_path + "*")
285 | tf.logging.info("Copying %s to %s" % (old_path, new_path))
286 |
287 | for o_file in old_files:
288 | n_file = o_file.replace(old_path, new_path)
289 | tf.gfile.Copy(o_file, n_file, overwrite=True)
290 |
291 | if removed is not None:
292 | filename = os.path.join(self._save_path, removed)
293 | tf.logging.info("Removing %s" % filename)
294 | files = tf.gfile.Glob(filename + "*")
295 |
296 | for name in files:
297 | tf.gfile.Remove(name)
298 |
299 | _save_score_record(self._record_name, records)
300 | checkpoint_filename = checkpoint_filename.replace(
301 | self._base_dir, self._save_path
302 | )
303 | _save_checkpoint_def(checkpoint_filename,
304 | [item[0] for item in records])
305 |
306 | best_score = records[0][1]
307 | tf.logging.info("Best score at step %d: %f" %
308 | (global_step, best_score))
309 |
310 | def end(self, session):
311 | last_step = session.run(self._global_step)
312 |
313 | if last_step != self._timer.last_triggered_step():
314 | global_step = last_step
315 | tf.logging.info("Validating model at step %d" % global_step)
316 | score = _evaluate(self._eval_fn, self._eval_input_fn,
317 | self._eval_decode_fn,
318 | self._base_dir,
319 | self._session_config)
320 | tf.logging.info("%s at step %d: %f" %
321 | (self._metric, global_step, score))
322 |
323 | checkpoint_filename = os.path.join(self._base_dir,
324 | "checkpoint")
325 | all_checkpoints = _read_checkpoint_def(checkpoint_filename)
326 | records = _read_score_record(self._record_name)
327 | latest_checkpoint = all_checkpoints[-1]
328 | record = [latest_checkpoint, score]
329 | added, removed, records = _add_to_record(records, record,
330 | self._max_to_keep)
331 |
332 | if added is not None:
333 | old_path = os.path.join(self._base_dir, added)
334 | new_path = os.path.join(self._save_path, added)
335 | old_files = tf.gfile.Glob(old_path + "*")
336 | tf.logging.info("Copying %s to %s" % (old_path, new_path))
337 |
338 | for o_file in old_files:
339 | n_file = o_file.replace(old_path, new_path)
340 | tf.gfile.Copy(o_file, n_file, overwrite=True)
341 |
342 | if removed is not None:
343 | filename = os.path.join(self._save_path, removed)
344 | tf.logging.info("Removing %s" % filename)
345 | files = tf.gfile.Glob(filename + "*")
346 |
347 | for name in files:
348 | tf.gfile.Remove(name)
349 |
350 | _save_score_record(self._record_name, records)
351 | checkpoint_filename = checkpoint_filename.replace(
352 | self._base_dir, self._save_path
353 | )
354 | _save_checkpoint_def(checkpoint_filename,
355 | [item[0] for item in records])
356 |
357 | best_score = records[0][1]
358 | tf.logging.info("Best score: %f" % best_score)
359 |
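360 | # Note (added commentary): EvaluationHook checkpoints to base_dir, decodes the
361 | # validation inputs, scores the output with bleu.bleu, and keeps the
362 | # `max_to_keep` best-scoring checkpoints under base_dir/eval.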
--------------------------------------------------------------------------------
/thumt/utils/inference.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import copy
9 | import tensorflow as tf
10 |
11 | from collections import namedtuple
12 | from tensorflow.python.util import nest
13 |
14 |
15 | class BeamSearchState(namedtuple("BeamSearchState",
16 | ("inputs", "state", "finish"))):
17 | pass
18 |
19 |
20 | def _get_inference_fn(model_fns, features):
21 | def inference_fn(inputs, state):
22 | local_features = {
23 | "source": features["source"],
24 | "source_length": features["source_length"],
25 | # [bos_id, ...] => [..., 0]
26 | "target": tf.pad(inputs[:, 1:], [[0, 0], [0, 1]]),
27 | "target_length": tf.fill([tf.shape(inputs)[0]],
28 | tf.shape(inputs)[1])
29 | }
30 |
31 | outputs = []
32 | next_state = []
33 |
34 | for (model_fn, model_state) in zip(model_fns, state):
35 | if model_state:
36 | output, new_state = model_fn(local_features, model_state)
37 | outputs.append(output)
38 | next_state.append(new_state)
39 | else:
40 | output = model_fn(local_features)
41 | outputs.append(output)
42 | next_state.append({})
43 |
44 | # Ensemble
45 | log_prob = tf.add_n(outputs) / float(len(outputs))
46 |
47 | return log_prob, next_state
48 |
49 | return inference_fn
50 |
51 |
52 | def _infer_shape(x):
53 | x = tf.convert_to_tensor(x)
54 |
55 | # If unknown rank, return dynamic shape
56 | if x.shape.dims is None:
57 | return tf.shape(x)
58 |
59 | static_shape = x.shape.as_list()
60 | dynamic_shape = tf.shape(x)
61 |
62 | ret = []
63 | for i in range(len(static_shape)):
64 | dim = static_shape[i]
65 | if dim is None:
66 | dim = dynamic_shape[i]
67 | ret.append(dim)
68 |
69 | return ret
70 |
71 |
72 | def _infer_shape_invariants(tensor):
73 | shape = tensor.shape.as_list()
74 | for i in range(1, len(shape) - 1):
75 | shape[i] = None
76 | return tf.TensorShape(shape)
77 |
78 |
79 | def _merge_first_two_dims(tensor):
80 | shape = _infer_shape(tensor)
81 | shape[0] *= shape[1]
82 | shape.pop(1)
83 | return tf.reshape(tensor, shape)
84 |
85 |
86 | def _split_first_two_dims(tensor, dim_0, dim_1):
87 | shape = _infer_shape(tensor)
88 | new_shape = [dim_0] + [dim_1] + shape[1:]
89 | return tf.reshape(tensor, new_shape)
90 |
91 |
92 | def _tile_to_beam_size(tensor, beam_size):
93 | """Tiles a given tensor by beam_size. """
94 | tensor = tf.expand_dims(tensor, axis=1)
95 | tile_dims = [1] * tensor.shape.ndims
96 | tile_dims[1] = beam_size
97 |
98 | return tf.tile(tensor, tile_dims)
99 |
100 |
101 | def _gather_2d(params, indices, name=None):
102 | """ Gather the 2nd dimension given indices
103 | :param params: A tensor with shape [batch_size, M, ...]
104 | :param indices: A tensor with shape [batch_size, N]
105 | :return: A tensor with shape [batch_size, N, ...]
106 | """
107 | batch_size = tf.shape(params)[0]
108 | range_size = tf.shape(indices)[1]
109 | batch_pos = tf.range(batch_size * range_size) // range_size
110 | batch_pos = tf.reshape(batch_pos, [batch_size, range_size])
111 | indices = tf.stack([batch_pos, indices], axis=-1)
112 | output = tf.gather_nd(params, indices, name=name)
113 |
114 | return output
115 |
116 |
117 | def _beam_search_step(time, func, state, batch_size, beam_size, alpha,
118 | pad_id, eos_id):
119 | # Compute log probabilities
120 | seqs, log_probs = state.inputs[:2]
121 | flat_seqs = _merge_first_two_dims(seqs)
122 | flat_state = nest.map_structure(lambda x: _merge_first_two_dims(x),
123 | state.state)
124 | step_log_probs, next_state = func(flat_seqs, flat_state)
125 | step_log_probs = _split_first_two_dims(step_log_probs, batch_size,
126 | beam_size)
127 | next_state = nest.map_structure(
128 | lambda x: _split_first_two_dims(x, batch_size, beam_size), next_state)
129 | curr_log_probs = tf.expand_dims(log_probs, 2) + step_log_probs
130 |
131 | # Apply length penalty
132 | length_penalty = tf.pow((5.0 + tf.to_float(time + 1)) / 6.0, alpha)
133 | curr_scores = curr_log_probs / length_penalty
134 | vocab_size = curr_scores.shape[-1].value or tf.shape(curr_scores)[-1]
135 |
136 | # Select top-k candidates
137 | # [batch_size, beam_size * vocab_size]
138 | curr_scores = tf.reshape(curr_scores, [-1, beam_size * vocab_size])
139 | # [batch_size, 2 * beam_size]
140 | top_scores, top_indices = tf.nn.top_k(curr_scores, k=2 * beam_size)
141 | # Shape: [batch_size, 2 * beam_size]
142 | beam_indices = top_indices // vocab_size
143 | symbol_indices = top_indices % vocab_size
144 | # Expand sequences
145 | # [batch_size, 2 * beam_size, time]
146 | candidate_seqs = _gather_2d(seqs, beam_indices)
147 | candidate_seqs = tf.concat([candidate_seqs,
148 | tf.expand_dims(symbol_indices, 2)], 2)
149 |
150 |
151 |     # Suppress finished sequences
152 | flags = tf.equal(symbol_indices, eos_id)
153 | # [batch, 2 * beam_size]
154 | alive_scores = top_scores + tf.to_float(flags) * tf.float32.min
155 | # [batch, beam_size]
156 | alive_scores, alive_indices = tf.nn.top_k(alive_scores, beam_size)
157 | alive_symbols = _gather_2d(symbol_indices, alive_indices)
158 | alive_indices = _gather_2d(beam_indices, alive_indices)
159 | alive_seqs = _gather_2d(seqs, alive_indices)
160 | # [batch_size, beam_size, time + 1]
161 | alive_seqs = tf.concat([alive_seqs, tf.expand_dims(alive_symbols, 2)], 2)
162 | alive_state = nest.map_structure(lambda x: _gather_2d(x, alive_indices),
163 | next_state)
164 | alive_log_probs = alive_scores * length_penalty
165 |
166 | # Select finished sequences
167 | prev_fin_flags, prev_fin_seqs, prev_fin_scores = state.finish
168 | # [batch, 2 * beam_size]
169 | step_fin_scores = top_scores + (1.0 - tf.to_float(flags)) * tf.float32.min
170 | # [batch, 3 * beam_size]
171 | fin_flags = tf.concat([prev_fin_flags, flags], axis=1)
172 | fin_scores = tf.concat([prev_fin_scores, step_fin_scores], axis=1)
173 | # [batch, beam_size]
174 | fin_scores, fin_indices = tf.nn.top_k(fin_scores, beam_size)
175 | fin_flags = _gather_2d(fin_flags, fin_indices)
176 | pad_seqs = tf.fill([batch_size, beam_size, 1],
177 | tf.constant(pad_id, tf.int32))
178 | prev_fin_seqs = tf.concat([prev_fin_seqs, pad_seqs], axis=2)
179 | fin_seqs = tf.concat([prev_fin_seqs, candidate_seqs], axis=1)
180 | fin_seqs = _gather_2d(fin_seqs, fin_indices)
181 |
182 | new_state = BeamSearchState(
183 | inputs=(alive_seqs, alive_log_probs, alive_scores),
184 | state=alive_state,
185 | finish=(fin_flags, fin_seqs, fin_scores),
186 | )
187 |
188 | return time + 1, new_state
189 |
190 |
191 | def beam_search(func, state, batch_size, beam_size, max_length, alpha,
192 | pad_id, bos_id, eos_id):
193 | init_seqs = tf.fill([batch_size, beam_size, 1], bos_id)
194 | init_log_probs = tf.constant([[0.] + [tf.float32.min] * (beam_size - 1)])
195 | init_log_probs = tf.tile(init_log_probs, [batch_size, 1])
196 | init_scores = tf.zeros_like(init_log_probs)
197 | fin_seqs = tf.zeros([batch_size, beam_size, 1], tf.int32)
198 | fin_scores = tf.fill([batch_size, beam_size], tf.float32.min)
199 | fin_flags = tf.zeros([batch_size, beam_size], tf.bool)
200 |
201 | state = BeamSearchState(
202 | inputs=(init_seqs, init_log_probs, init_scores),
203 | state=state,
204 | finish=(fin_flags, fin_seqs, fin_scores),
205 | )
206 |
207 | max_step = tf.reduce_max(max_length)
208 |
209 | def _is_finished(t, s):
210 | log_probs = s.inputs[1]
211 | finished_flags = s.finish[0]
212 | finished_scores = s.finish[2]
213 | max_lp = tf.pow(((5.0 + tf.to_float(max_step)) / 6.0), alpha)
214 | best_alive_score = log_probs[:, 0] / max_lp
215 | worst_finished_score = tf.reduce_min(
216 | finished_scores * tf.to_float(finished_flags), axis=1)
217 | add_mask = 1.0 - tf.to_float(tf.reduce_any(finished_flags, 1))
218 | worst_finished_score += tf.float32.min * add_mask
219 | bound_is_met = tf.reduce_all(tf.greater(worst_finished_score,
220 | best_alive_score))
221 |
222 | cond = tf.logical_and(tf.less(t, max_step),
223 | tf.logical_not(bound_is_met))
224 |
225 | return cond
226 |
227 | def _loop_fn(t, s):
228 | outs = _beam_search_step(t, func, s, batch_size, beam_size, alpha,
229 | pad_id, eos_id)
230 | return outs
231 |
232 | time = tf.constant(0, name="time")
233 | shape_invariants = BeamSearchState(
234 | inputs=(tf.TensorShape([None, None, None]),
235 | tf.TensorShape([None, None]),
236 | tf.TensorShape([None, None])),
237 | state=nest.map_structure(_infer_shape_invariants, state.state),
238 | finish=(tf.TensorShape([None, None]),
239 | tf.TensorShape([None, None, None]),
240 | tf.TensorShape([None, None]))
241 | )
242 | outputs = tf.while_loop(_is_finished, _loop_fn, [time, state],
243 | shape_invariants=[tf.TensorShape([]),
244 | shape_invariants],
245 | parallel_iterations=1,
246 | back_prop=False)
247 |
248 | final_state = outputs[1]
249 | alive_seqs = final_state.inputs[0]
250 | alive_scores = final_state.inputs[2]
251 | final_flags = final_state.finish[0]
252 | final_seqs = final_state.finish[1]
253 | final_scores = final_state.finish[2]
254 |
255 | alive_seqs.set_shape([None, beam_size, None])
256 | final_seqs.set_shape((None, beam_size, None))
257 |
258 | final_seqs = tf.where(tf.reduce_any(final_flags, 1), final_seqs,
259 | alive_seqs)
260 | final_scores = tf.where(tf.reduce_any(final_flags, 1), final_scores,
261 | alive_scores)
262 |
263 | return final_seqs, final_scores
264 |
265 |
266 | def create_inference_graph(model_fns, features, params):
267 | if not isinstance(model_fns, (list, tuple)):
268 |         raise ValueError("model_fns must be a list or tuple")
269 |
270 | features = copy.copy(features)
271 |
272 | decode_length = params.decode_length
273 | beam_size = params.beam_size
274 | top_beams = params.top_beams
275 | alpha = params.decode_alpha
276 |
277 | # Compute initial state if necessary
278 | states = []
279 | funcs = []
280 |
281 | for model_fn in model_fns:
282 | if callable(model_fn):
283 | # For non-incremental decoding
284 | states.append({})
285 | funcs.append(model_fn)
286 | else:
287 | # For incremental decoding where model_fn is a tuple:
288 | # (encoding_fn, decoding_fn)
289 | states.append(model_fn[0](features))
290 | funcs.append(model_fn[1])
291 |
292 | batch_size = tf.shape(features["source"])[0]
293 | pad_id = params.mapping["target"][params.pad]
294 | bos_id = params.mapping["target"][params.bos]
295 | eos_id = params.mapping["target"][params.eos]
296 |
297 |     # Expand the inputs into the beam size
298 | # [batch, length] => [batch, beam_size, length]
299 | features["source"] = tf.expand_dims(features["source"], 1)
300 | features["source"] = tf.tile(features["source"], [1, beam_size, 1])
301 | shape = tf.shape(features["source"])
302 |
303 | # [batch, beam_size, length] => [batch * beam_size, length]
304 | features["source"] = tf.reshape(features["source"],
305 | [shape[0] * shape[1], shape[2]])
306 |
307 | # For source sequence length
308 | features["source_length"] = tf.expand_dims(features["source_length"], 1)
309 | features["source_length"] = tf.tile(features["source_length"],
310 | [1, beam_size])
311 | shape = tf.shape(features["source_length"])
312 |
313 | max_length = features["source_length"] + decode_length
314 |
315 | # [batch, beam_size, length] => [batch * beam_size, length]
316 | features["source_length"] = tf.reshape(features["source_length"],
317 | [shape[0] * shape[1]])
318 | decoding_fn = _get_inference_fn(funcs, features)
319 | states = nest.map_structure(lambda x: _tile_to_beam_size(x, beam_size),
320 | states)
321 |
322 | seqs, scores = beam_search(decoding_fn, states, batch_size, beam_size,
323 | max_length, alpha, pad_id, bos_id, eos_id)
324 |
325 | return seqs[:, :top_beams, 1:], scores[:, :top_beams]
326 |
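327 | # Note (added commentary): scores use the GNMT-style length penalty
328 | #   lp(t) = ((5 + t) / 6) ** alpha,
329 | # i.e. beam entries are ranked by log-probability divided by lp(time).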
--------------------------------------------------------------------------------
/thumt/utils/inference_ctx.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import copy
9 | import tensorflow as tf
10 |
11 | from collections import namedtuple
12 | from tensorflow.python.util import nest
13 |
14 |
15 | class BeamSearchState(namedtuple("BeamSearchState",
16 | ("inputs", "state", "finish"))):
17 | pass
18 |
19 |
20 | def _get_inference_fn(model_fns, features):
21 | def inference_fn(inputs, state):
22 | local_features = {
23 | "source": features["source"],
24 | "source_length": features["source_length"],
25 | "context": features["context"],
26 | "context_length": features["context_length"],
27 | # [bos_id, ...] => [..., 0]
28 | "target": tf.pad(inputs[:, 1:], [[0, 0], [0, 1]]),
29 | "target_length": tf.fill([tf.shape(inputs)[0]],
30 | tf.shape(inputs)[1])
31 | }
32 |
33 | outputs = []
34 | next_state = []
35 |
36 | for (model_fn, model_state) in zip(model_fns, state):
37 | if model_state:
38 | output, new_state = model_fn(local_features, model_state)
39 | outputs.append(output)
40 | next_state.append(new_state)
41 | else:
42 | output = model_fn(local_features)
43 | outputs.append(output)
44 | next_state.append({})
45 |
46 | # Ensemble
47 | log_prob = tf.add_n(outputs) / float(len(outputs))
48 |
49 | return log_prob, next_state
50 |
51 | return inference_fn
52 |
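# Note (added for illustration, not in the original file): averaging the
# models' log-probabilities above (tf.add_n(outputs) / len(outputs)) is a
# geometric-mean ensemble of the predicted distributions (up to
# renormalization), not an arithmetic mean of probabilities.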
53 |
54 | def _infer_shape(x):
55 | x = tf.convert_to_tensor(x)
56 |
57 | # If unknown rank, return dynamic shape
58 | if x.shape.dims is None:
59 | return tf.shape(x)
60 |
61 | static_shape = x.shape.as_list()
62 | dynamic_shape = tf.shape(x)
63 |
64 | ret = []
65 | for i in range(len(static_shape)):
66 | dim = static_shape[i]
67 | if dim is None:
68 | dim = dynamic_shape[i]
69 | ret.append(dim)
70 |
71 | return ret
72 |
73 |
74 | def _infer_shape_invariants(tensor):
75 | shape = tensor.shape.as_list()
76 | for i in range(1, len(shape) - 1):
77 | shape[i] = None
78 | return tf.TensorShape(shape)
79 |
80 |
81 | def _merge_first_two_dims(tensor):
82 | shape = _infer_shape(tensor)
83 | shape[0] *= shape[1]
84 | shape.pop(1)
85 | return tf.reshape(tensor, shape)
86 |
87 |
88 | def _split_first_two_dims(tensor, dim_0, dim_1):
89 | shape = _infer_shape(tensor)
90 | new_shape = [dim_0] + [dim_1] + shape[1:]
91 | return tf.reshape(tensor, new_shape)
92 |
93 |
94 | def _tile_to_beam_size(tensor, beam_size):
95 | """Tiles a given tensor by beam_size. """
96 | tensor = tf.expand_dims(tensor, axis=1)
97 | tile_dims = [1] * tensor.shape.ndims
98 | tile_dims[1] = beam_size
99 |
100 | return tf.tile(tensor, tile_dims)
101 |
102 |
103 | def _gather_2d(params, indices, name=None):
104 | """ Gather the 2nd dimension given indices
105 | :param params: A tensor with shape [batch_size, M, ...]
106 | :param indices: A tensor with shape [batch_size, N]
107 | :return: A tensor with shape [batch_size, N, ...]
108 | """
109 | batch_size = tf.shape(params)[0]
110 | range_size = tf.shape(indices)[1]
111 | batch_pos = tf.range(batch_size * range_size) // range_size
112 | batch_pos = tf.reshape(batch_pos, [batch_size, range_size])
113 | indices = tf.stack([batch_pos, indices], axis=-1)
114 | output = tf.gather_nd(params, indices, name=name)
115 |
116 | return output
117 |
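# Worked example (hypothetical values, added for illustration): with
# params = [[10, 11, 12], [20, 21, 22]] and indices = [[2, 0], [1, 1]],
# batch_pos is [[0, 0], [1, 1]], the stacked indices are
# [[[0, 2], [0, 0]], [[1, 1], [1, 1]]], and tf.gather_nd returns
# [[12, 10], [21, 21]]: row i of the output picks the entries of
# params[i] selected by indices[i].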
118 |
119 | def _beam_search_step(time, func, state, batch_size, beam_size, alpha,
120 | pad_id, eos_id):
121 | # Compute log probabilities
122 |
123 | seqs, log_probs = state.inputs[:2]
124 | flat_seqs = _merge_first_two_dims(seqs)
125 | flat_state = nest.map_structure(lambda x: _merge_first_two_dims(x),
126 | state.state)
127 |
128 | step_log_probs, next_state = func(flat_seqs, flat_state)
129 | step_log_probs = _split_first_two_dims(step_log_probs, batch_size,
130 | beam_size)
131 |
132 | next_state = nest.map_structure(
133 | lambda x: _split_first_two_dims(x, batch_size, beam_size), next_state)
134 | curr_log_probs = tf.expand_dims(log_probs, 2) + step_log_probs
135 |
136 | # Apply length penalty
137 | length_penalty = tf.pow((5.0 + tf.to_float(time + 1)) / 6.0, alpha)
138 | curr_scores = curr_log_probs / length_penalty
139 | vocab_size = curr_scores.shape[-1].value or tf.shape(curr_scores)[-1]
140 |
141 | # Select top-k candidates
142 | # [batch_size, beam_size * vocab_size]
143 | curr_scores = tf.reshape(curr_scores, [-1, beam_size * vocab_size])
144 | # [batch_size, 2 * beam_size]
145 | top_scores, top_indices = tf.nn.top_k(curr_scores, k=2 * beam_size)
146 | # Shape: [batch_size, 2 * beam_size]
147 | beam_indices = top_indices // vocab_size
148 | symbol_indices = top_indices % vocab_size
149 | # Expand sequences
150 | # [batch_size, 2 * beam_size, time]
151 | candidate_seqs = _gather_2d(seqs, beam_indices)
152 | candidate_seqs = tf.concat([candidate_seqs,
153 | tf.expand_dims(symbol_indices, 2)], 2)
154 |
155 |
156 | # Suppress finished sequences
157 | flags = tf.equal(symbol_indices, eos_id)
158 | # [batch, 2 * beam_size]
159 | alive_scores = top_scores + tf.to_float(flags) * tf.float32.min
160 | # [batch, beam_size]
161 | alive_scores, alive_indices = tf.nn.top_k(alive_scores, beam_size)
162 | alive_symbols = _gather_2d(symbol_indices, alive_indices)
163 | alive_indices = _gather_2d(beam_indices, alive_indices)
164 | alive_seqs = _gather_2d(seqs, alive_indices)
165 | # [batch_size, beam_size, time + 1]
166 | alive_seqs = tf.concat([alive_seqs, tf.expand_dims(alive_symbols, 2)], 2)
167 | alive_state = nest.map_structure(lambda x: _gather_2d(x, alive_indices),
168 | next_state)
169 |
170 | alive_log_probs = alive_scores * length_penalty
171 |
172 | # Select finished sequences
173 | prev_fin_flags, prev_fin_seqs, prev_fin_scores = state.finish
174 | # [batch, 2 * beam_size]
175 | step_fin_scores = top_scores + (1.0 - tf.to_float(flags)) * tf.float32.min
176 | # [batch, 3 * beam_size]
177 | fin_flags = tf.concat([prev_fin_flags, flags], axis=1)
178 | fin_scores = tf.concat([prev_fin_scores, step_fin_scores], axis=1)
179 | # [batch, beam_size]
180 | fin_scores, fin_indices = tf.nn.top_k(fin_scores, beam_size)
181 | fin_flags = _gather_2d(fin_flags, fin_indices)
182 | pad_seqs = tf.fill([batch_size, beam_size, 1],
183 | tf.constant(pad_id, tf.int32))
184 | prev_fin_seqs = tf.concat([prev_fin_seqs, pad_seqs], axis=2)
185 | fin_seqs = tf.concat([prev_fin_seqs, candidate_seqs], axis=1)
186 | fin_seqs = _gather_2d(fin_seqs, fin_indices)
187 |
188 | new_state = BeamSearchState(
189 | inputs=(alive_seqs, alive_log_probs, alive_scores),
190 | state=alive_state,
191 | finish=(fin_flags, fin_seqs, fin_scores),
192 | )
193 |
194 | return time + 1, new_state
195 |
196 |
197 | def beam_search(func, state, batch_size, beam_size, max_length, alpha,
198 | pad_id, bos_id, eos_id):
199 | init_seqs = tf.fill([batch_size, beam_size, 1], bos_id)
200 | init_log_probs = tf.constant([[0.] + [tf.float32.min] * (beam_size - 1)])
201 | init_log_probs = tf.tile(init_log_probs, [batch_size, 1])
202 | init_scores = tf.zeros_like(init_log_probs)
203 | fin_seqs = tf.zeros([batch_size, beam_size, 1], tf.int32)
204 | fin_scores = tf.fill([batch_size, beam_size], tf.float32.min)
205 | fin_flags = tf.zeros([batch_size, beam_size], tf.bool)
206 |
207 | state = BeamSearchState(
208 | inputs=(init_seqs, init_log_probs, init_scores),
209 | state=state,
210 | finish=(fin_flags, fin_seqs, fin_scores),
211 | )
212 |
213 |
214 | max_step = tf.reduce_max(max_length)
215 |
216 | def _is_finished(t, s):
217 | log_probs = s.inputs[1]
218 | finished_flags = s.finish[0]
219 | finished_scores = s.finish[2]
220 | max_lp = tf.pow(((5.0 + tf.to_float(max_step)) / 6.0), alpha)
221 | best_alive_score = log_probs[:, 0] / max_lp
222 | worst_finished_score = tf.reduce_min(
223 | finished_scores * tf.to_float(finished_flags), axis=1)
224 | add_mask = 1.0 - tf.to_float(tf.reduce_any(finished_flags, 1))
225 | worst_finished_score += tf.float32.min * add_mask
226 | bound_is_met = tf.reduce_all(tf.greater(worst_finished_score,
227 | best_alive_score))
228 |
229 | cond = tf.logical_and(tf.less(t, max_step),
230 | tf.logical_not(bound_is_met))
231 |
232 | return cond
233 |
234 | def _loop_fn(t, s):
235 | outs = _beam_search_step(t, func, s, batch_size, beam_size, alpha,
236 | pad_id, eos_id)
237 | return outs
238 |
239 | time = tf.constant(0, name="time")
240 | shape_invariants = BeamSearchState(
241 | inputs=(tf.TensorShape([None, None, None]),
242 | tf.TensorShape([None, None]),
243 | tf.TensorShape([None, None])),
244 | state=nest.map_structure(_infer_shape_invariants, state.state),
245 | finish=(tf.TensorShape([None, None]),
246 | tf.TensorShape([None, None, None]),
247 | tf.TensorShape([None, None]))
248 | )
249 | outputs = tf.while_loop(_is_finished, _loop_fn, [time, state],
250 | shape_invariants=[tf.TensorShape([]),
251 | shape_invariants],
252 | parallel_iterations=1,
253 | back_prop=False)
254 |
255 | final_state = outputs[1]
256 | alive_seqs = final_state.inputs[0]
257 | alive_scores = final_state.inputs[2]
258 | final_flags = final_state.finish[0]
259 | final_seqs = final_state.finish[1]
260 | final_scores = final_state.finish[2]
261 |
262 | alive_seqs.set_shape([None, beam_size, None])
263 | final_seqs.set_shape((None, beam_size, None))
264 |
265 | final_seqs = tf.where(tf.reduce_any(final_flags, 1), final_seqs,
266 | alive_seqs)
267 | final_scores = tf.where(tf.reduce_any(final_flags, 1), final_scores,
268 | alive_scores)
269 |
270 | return final_seqs, final_scores
271 |
272 |
273 | def create_inference_graph(model_fns, features, params):
274 | if not isinstance(model_fns, (list, tuple)):
275 | raise ValueError("mode_fns must be a list or tuple")
276 |
277 | features = copy.copy(features)
278 |
279 | decode_length = params.decode_length
280 | beam_size = params.beam_size
281 | top_beams = params.top_beams
282 | alpha = params.decode_alpha
283 |
284 | # Compute initial state if necessary
285 | states = []
286 | funcs = []
287 |
288 | for model_fn in model_fns:
289 | if callable(model_fn):
290 | # For non-incremental decoding
291 | states.append({})
292 | funcs.append(model_fn)
293 | else:
294 | # For incremental decoding where model_fn is a tuple:
295 | # (encoding_fn, decoding_fn)
296 | states.append(model_fn[0](features))
297 | funcs.append(model_fn[1])
298 |
299 | batch_size = tf.shape(features["source"])[0]
300 | pad_id = params.mapping["target"][params.pad]
301 | bos_id = params.mapping["target"][params.bos]
302 | eos_id = params.mapping["target"][params.eos]
303 |
304 | # Expand the inputs to the beam size
305 | # [batch, length] => [batch, beam_size, length]
306 | features["source"] = tf.expand_dims(features["source"], 1)
307 | features["source"] = tf.tile(features["source"], [1, beam_size, 1])
308 | shape = tf.shape(features["source"])
309 |
310 | # [batch, beam_size, length] => [batch * beam_size, length]
311 | features["source"] = tf.reshape(features["source"],
312 | [shape[0] * shape[1], shape[2]])
313 |
314 | # For source sequence length
315 | features["source_length"] = tf.expand_dims(features["source_length"], 1)
316 | features["source_length"] = tf.tile(features["source_length"],
317 | [1, beam_size])
318 | shape = tf.shape(features["source_length"])
319 |
320 | max_length = features["source_length"] + decode_length
321 |
322 | # [batch, beam_size, length] => [batch * beam_size, length]
323 | features["source_length"] = tf.reshape(features["source_length"],
324 | [shape[0] * shape[1]])
325 |
326 |
327 | # Expand the context inputs in the same way as the source
328 | # [batch, length] => [batch, beam_size, length]
329 | features["context"] = tf.expand_dims(features["context"], 1)
330 | features["context"] = tf.tile(features["context"], [1, beam_size, 1])
331 | shape = tf.shape(features["context"])
332 |
333 | # [batch, beam_size, length] => [batch * beam_size, length]
334 | features["context"] = tf.reshape(features["context"],
335 | [shape[0] * shape[1], shape[2]])
336 |
337 | # For context sequence length
338 | features["context_length"] = tf.expand_dims(features["context_length"], 1)
339 | features["context_length"] = tf.tile(features["context_length"],
340 | [1, beam_size])
341 | shape = tf.shape(features["context_length"])
342 |
343 | # [batch, beam_size, length] => [batch * beam_size, length]
344 | features["context_length"] = tf.reshape(features["context_length"],
345 | [shape[0] * shape[1]])
346 |
347 |
348 | decoding_fn = _get_inference_fn(funcs, features)
349 | states = nest.map_structure(lambda x: _tile_to_beam_size(x, beam_size),
350 | states)
351 |
352 | seqs, scores = beam_search(decoding_fn, states, batch_size, beam_size,
353 | max_length, alpha, pad_id, bos_id, eos_id)
354 |
355 | return seqs[:, :top_beams, 1:], scores[:, :top_beams]
356 |
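Two details of _beam_search_step are easy to miss: candidate scores are normalized by the GNMT-style length penalty ((5 + t) / 6) ** alpha, and 2 * beam_size candidates are drawn per step so that beam_size alive hypotheses survive even if every top candidate emits EOS. A sketch of the normalization, with hypothetical numbers:

alpha = 0.6  # a typical decode_alpha value; an assumption, not from this file

def length_penalty(t):
    # Same formula as in _beam_search_step: ((5 + t) / 6) ** alpha
    return ((5.0 + t) / 6.0) ** alpha

# A short and a long hypothesis with the same per-token log-probability
per_token = -0.4
for length in (5, 30):
    print(length, (per_token * length) / length_penalty(length))
# The penalty grows with length, shrinking the advantage short hypotheses
# get from accumulating fewer negative log-probabilities.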
--------------------------------------------------------------------------------
/thumt/utils/optimize.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import tensorflow as tf
9 |
10 |
11 | def _get_loss_variable(graph=None):
12 | graph = graph or tf.get_default_graph()
13 | loss_tensors = tf.get_collection("loss")
14 |
15 | if len(loss_tensors) == 1:
16 | loss_tensor = loss_tensors[0]
17 | elif not loss_tensors:
18 | try:
19 | loss_tensor = graph.get_tensor_by_name("loss:0")
20 | except KeyError:
21 | return None
22 | else:
23 | tf.logging.error("Multiple tensors in loss collection.")
24 | return None
25 |
26 | return loss_tensor
27 |
28 |
29 | def _create_loss_variable(graph=None):
30 | graph = graph or tf.get_default_graph()
31 | if _get_loss_variable(graph) is not None:
32 | raise ValueError("'loss' already exists.")
33 |
34 | # Create in proper graph and base name_scope.
35 | with graph.as_default() as g, g.name_scope(None):
36 | tensor = tf.get_variable("loss", shape=[], dtype=tf.float32,
37 | initializer=tf.zeros_initializer(),
38 | trainable=False,
39 | collections=[tf.GraphKeys.GLOBAL_VARIABLES,
40 | "loss"])
41 |
42 | return tensor
43 |
44 |
45 | def _get_or_create_loss_variable(graph=None):
46 | graph = graph or tf.get_default_graph()
47 | loss_tensor = _get_loss_variable(graph)
48 | if loss_tensor is None:
49 | loss_tensor = _create_loss_variable(graph)
50 | return loss_tensor
51 |
52 |
53 | def _zero_variables(variables, name=None):
54 | ops = []
55 |
56 | for var in variables:
57 | with tf.device(var.device):
58 | op = var.assign(tf.zeros(var.shape.as_list()))
59 | ops.append(op)
60 |
61 | return tf.group(*ops, name=name or "zero_variables")
62 |
63 |
64 | def _replicate_variables(variables, device=None):
65 | new_vars = []
66 |
67 | for var in variables:
68 | device = device or var.device
69 | with tf.device(device):
70 | name = var.name.split(":")[0].rstrip("/") + "/replica"
71 | new_vars.append(tf.Variable(tf.zeros(var.shape.as_list()),
72 | name=name, trainable=False))
73 |
74 | return new_vars
75 |
76 |
77 | def _collect_gradients(gradients, variables):
78 | ops = []
79 |
80 | for grad, var in zip(gradients, variables):
81 | if isinstance(grad, tf.Tensor):
82 | ops.append(tf.assign_add(var, grad))
83 | else:
84 | ops.append(tf.scatter_add(var, grad.indices, grad.values))
85 |
86 | return tf.group(*ops, name="collect_gradients")
87 |
88 |
89 | def _scale_variables(variables, scale):
90 | if not isinstance(variables, (list, tuple)):
91 | return tf.assign(variables, scale * variables)
92 |
93 | ops = []
94 |
95 | for var in variables:
96 | ops.append(tf.assign(var, scale * var))
97 |
98 | return tf.group(*ops, name="scale_variables")
99 |
100 |
101 | def create_train_op(loss, optimizer, global_step, params):
102 | with tf.name_scope("create_train_op"):
103 | grads_and_vars = optimizer.compute_gradients(
104 | loss, colocate_gradients_with_ops=True)
105 | gradients = [item[0] for item in grads_and_vars]
106 | variables = [item[1] for item in grads_and_vars]
107 |
108 | if params.update_cycle == 1:
109 | zero_variables_op = tf.no_op("zero_variables")
110 | collect_op = tf.no_op("collect_op")
111 | scale_op = tf.no_op("scale_op")
112 | else:
113 | # collect
114 | loss_tensor = _get_or_create_loss_variable()
115 | slot_variables = _replicate_variables(variables)
116 | zero_variables_op = _zero_variables(slot_variables + [loss_tensor])
117 | collect_grads_op = _collect_gradients(gradients, slot_variables)
118 | collect_loss_op = tf.assign_add(loss_tensor, loss)
119 | collect_op = tf.group(collect_loss_op, collect_grads_op,
120 | name="collect_op")
121 | # scale
122 | scale = 1.0 / params.update_cycle
123 | scale_grads_op = _scale_variables(slot_variables, scale)
124 | scale_loss_op = _scale_variables(loss_tensor, scale)
125 | scale_op = tf.group(scale_grads_op, scale_loss_op, name="scale_op")
126 | gradients = slot_variables
127 | loss = tf.convert_to_tensor(loss_tensor)
128 |
129 | # Add summaries
130 | tf.summary.scalar("loss", loss)
131 | tf.summary.scalar("global_norm/gradient_norm",
132 | tf.global_norm(gradients))
133 |
134 | for gradient, variable in zip(gradients, variables):
135 | if isinstance(gradient, tf.IndexedSlices):
136 | grad_values = gradient.values
137 | else:
138 | grad_values = gradient
139 |
140 | if grad_values is not None:
141 | var_name = variable.name.replace(":", "_")
142 | tf.summary.histogram("gradients/%s" % var_name, grad_values)
143 | tf.summary.scalar("gradient_norm/%s" % var_name,
144 | tf.global_norm([grad_values]))
145 |
146 | # Gradient clipping (only applied when clip_grad_norm is a non-zero float)
147 | if isinstance(params.clip_grad_norm or None, float):
148 | gradients, _ = tf.clip_by_global_norm(gradients,
149 | params.clip_grad_norm)
150 |
151 | # Update variables
152 | grads_and_vars = list(zip(gradients, tf.trainable_variables()))
153 | train_op = optimizer.apply_gradients(grads_and_vars, global_step)
154 |
155 | ops = {
156 | "zero_op": zero_variables_op,
157 | "collect_op": collect_op,
158 | "scale_op": scale_op,
159 | "train_op": train_op
160 | }
161 |
162 | return loss, ops
163 |
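When params.update_cycle > 1, the four returned ops implement gradient accumulation: gradients of several sub-batches are summed into replica slots, averaged, and applied in a single update. A hedged sketch of the intended run order (the actual driver lives in the trainer; the function below is an illustration, not code from this repo):

def run_one_update(sess, ops, update_cycle):
    # Hypothetical driver for the ops returned by create_train_op.
    sess.run(ops["zero_op"])             # reset slot variables and the loss
    for _ in range(update_cycle):
        sess.run(ops["collect_op"])      # add one sub-batch's gradients
    sess.run(ops["scale_op"])            # divide the sums by update_cycle
    sess.run(ops["train_op"])            # apply the averaged gradients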
--------------------------------------------------------------------------------
/thumt/utils/parallel.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import operator
9 |
10 | import tensorflow as tf
11 |
12 |
13 | class GPUParamServerDeviceSetter(object):
14 |
15 | def __init__(self, worker_device, ps_devices):
16 | self.ps_devices = ps_devices
17 | self.worker_device = worker_device
18 | self.ps_sizes = [0] * len(self.ps_devices)
19 |
20 | def __call__(self, op):
21 | if op.device:
22 | return op.device
23 | if op.type not in ["Variable", "VariableV2", "VarHandleOp"]:
24 | return self.worker_device
25 |
26 | # Gets the least loaded ps_device
27 | device_index, _ = min(enumerate(self.ps_sizes),
28 | key=operator.itemgetter(1))
29 | device_name = self.ps_devices[device_index]
30 | var_size = op.outputs[0].get_shape().num_elements()
31 | self.ps_sizes[device_index] += var_size
32 |
33 | return device_name
34 |
35 |
36 | def _maybe_repeat(x, n):
37 | if isinstance(x, list):
38 | assert len(x) == n
39 | return x
40 | else:
41 | return [x] * n
42 |
43 |
44 | def _create_device_setter(is_cpu_ps, worker, num_gpus):
45 | if is_cpu_ps:
46 | # tf.train.replica_device_setter supports placing variables on the CPU,
47 | # all on one GPU, or on ps_servers defined in a cluster_spec.
48 | return tf.train.replica_device_setter(
49 | worker_device=worker, ps_device="/cpu:0", ps_tasks=1)
50 | else:
51 | gpus = ["/gpu:%d" % i for i in range(num_gpus)]
52 | return GPUParamServerDeviceSetter(worker, gpus)
53 |
54 |
55 | # Data-level parallelism
56 | def data_parallelism(devices, fn, *args, **kwargs):
57 | num_worker = len(devices)
58 |
59 | # Replicate args and kwargs
60 | if args:
61 | new_args = [_maybe_repeat(arg, num_worker) for arg in args]
62 | # Transpose
63 | new_args = [list(x) for x in zip(*new_args)]
64 | else:
65 | new_args = [[] for _ in range(num_worker)]
66 |
67 | new_kwargs = [{} for _ in range(num_worker)]
68 |
69 | for k, v in kwargs.items():
70 | vals = _maybe_repeat(v, num_worker)
71 |
72 | for i in range(num_worker):
73 | new_kwargs[i][k] = vals[i]
74 |
75 | fns = _maybe_repeat(fn, num_worker)
76 |
77 | # Now make the parallel call.
78 | outputs = []
79 |
80 | for i in range(num_worker):
81 | worker = "/gpu:%d" % i
82 | device_setter = _create_device_setter(False, worker, len(devices))
83 | with tf.variable_scope(tf.get_variable_scope(), reuse=(i != 0)):
84 | with tf.name_scope("parallel_%d" % i):
85 | with tf.device(device_setter):
86 | outputs.append(fns[i](*new_args[i], **new_kwargs[i]))
87 |
88 | if isinstance(outputs[0], tuple):
89 | outputs = list(zip(*outputs))
90 | outputs = tuple([list(o) for o in outputs])
91 |
92 | return outputs
93 |
94 |
95 | def shard_features(features, device_list):
96 | num_datashards = len(device_list)
97 |
98 | sharded_features = {}
99 |
100 | for k, v in features.items():
101 | v = tf.convert_to_tensor(v)
102 | if not v.shape.as_list():
103 | v = tf.expand_dims(v, axis=-1)
104 | v = tf.tile(v, [num_datashards])
105 | with tf.device(v.device):
106 | sharded_features[k] = tf.split(v, num_datashards, 0)
107 |
108 | datashard_to_features = []
109 |
110 | for d in range(num_datashards):
111 | feat = {
112 | k: v[d] for k, v in sharded_features.items()
113 | }
114 | datashard_to_features.append(feat)
115 |
116 | return datashard_to_features
117 |
118 |
119 | def parallel_model(model_fn, features, devices, use_cpu=False):
120 | devices = ["gpu:%d" % d for d in devices]
121 |
122 | if use_cpu:
123 | devices += ["cpu:0"]
124 |
125 | if len(devices) == 1:
126 | return [model_fn(features)]
127 |
128 | features = shard_features(features, devices)
129 |
130 | outputs = data_parallelism(devices, model_fn, features)
131 | return outputs
132 |
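shard_features gives each device an equal slice of the batch along axis 0 (scalars are first tiled so that every shard receives a copy). A NumPy analogue of the splitting rule, with toy data; as with tf.split, the batch size must be divisible by the number of shards:

import numpy as np

num_shards = 2  # e.g. two GPUs
features = {"source": np.arange(8).reshape(4, 2)}  # batch of 4 sentences

split = {k: np.split(v, num_shards, axis=0) for k, v in features.items()}
shards = [{k: v[d] for k, v in split.items()} for d in range(num_shards)]
# shards[0]["source"] holds sentences 0-1 and shards[1]["source"] holds
# sentences 2-3; data_parallelism then runs model_fn(shards[d]) on device d.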
--------------------------------------------------------------------------------
/thumt/utils/sample.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
--------------------------------------------------------------------------------
/thumt/utils/utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The THUMT Authors
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import tensorflow as tf
9 |
10 |
11 | def session_run(monitored_session, args):
12 | # Call raw TF session directly
13 | return monitored_session._tf_sess().run(args)
14 |
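MonitoredSession.run invokes every registered SessionRunHook; going through the private _tf_sess() runs the fetches on the underlying raw tf.Session, so utility evaluations do not advance hooks, counters, or checkpoints. A hedged usage sketch (the surrounding session and loss_op are assumptions, not code from this file):

# with tf.train.MonitoredTrainingSession() as sess:
#     value = session_run(sess, loss_op)  # raw run; before_run/after_run
#                                         # hooks are not invoked
#     value = sess.run(loss_op)           # by contrast, this fires all hooks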
--------------------------------------------------------------------------------