├── .gitignore ├── LICENSE ├── README.md ├── configs ├── config_template.yaml ├── ende6.yaml └── news6.yaml ├── evaluate.py ├── models ├── __init__.py ├── deeprnn.py ├── indrnn.py ├── model.py ├── parallel.py ├── rnnsearch.py └── transformer.py ├── multi-bleu.perl ├── third_party ├── __init__.py └── tensor2tensor │ ├── __init__.py │ ├── avg_checkpoints.py │ ├── common_attention.py │ ├── common_layers.py │ └── expert_utils.py ├── train.py ├── train_wkd.py ├── utils.py └── vocab.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *.cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # Jupyter Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # SageMath parsed files 79 | *.sage.py 80 | 81 | # Environments 82 | .env 83 | .venv 84 | env/ 85 | venv/ 86 | ENV/ 87 | 88 | # Spyder project settings 89 | .spyderproject 90 | .spyproject 91 | 92 | # Rope project settings 93 | .ropeproject 94 | 95 | # mkdocs documentation 96 | /site 97 | 98 | # mypy 99 | .mypy_cache/ 100 | 101 | corpora 102 | logdir 103 | preprocessed 104 | .idea/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A Simpler Implementation of the Transformer 2 | 3 | This project was originally forked from . 4 | Many changes have been made to make it easy to use and flexible. 5 | Several new features have been implemented and tested. 6 | Some functions are taken from . 7 | 8 | ## Highlight Features 9 | - Batching data with a bucketing mechanism, which allows higher utilization of computational resources. 10 | - Beam search that supports batching and length penalty. 11 | - Using YAML to configure all hyper-parameters, as well as all other settings. 12 | - Support for caching decoder outputs, which accelerates decoding on CPUs. 13 | 14 | ## Usage 15 | Create a new config file. 16 | 17 | `cp config_template.yaml your_config.yaml` 18 | 19 | Configure *train.src_path*, *train.dst_path*, *src_vocab* and *dst_vocab* in *your_config.yaml*. 20 | After that, run the following command to build the vocabulary files. 21 | 22 | `python vocab.py -c your_config.yaml` 23 | 24 | Edit *src\_vocab_size* and *dst\_vocab_size* in *your_config.yaml* according to the vocabulary files generated in the previous step. 25 | 26 | Run the following command to start training: 27 | 28 | `python train.py -c your_config.yaml` 29 | 30 | 31 | ## Contact 32 | Raise an issue on [GitHub](https://github.com/chqiwang/transformer) or send an email to . 33 | -------------------------------------------------------------------------------- /configs/config_template.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | model: 'Transformer' 3 | src_vocab: 4 | dst_vocab: 5 | src_vocab_size: 6 | dst_vocab_size: 7 | hidden_units: 512 8 | scale_embedding: True 9 | tie_embedding_and_softmax: True 10 | tie_embeddings: False 11 | attention_dropout_rate: 0.0 12 | residual_dropout_rate: 0.1 13 | num_blocks: 6 14 | num_heads: 8 15 | ff_activation: 'relu' 16 | model_dir: 17 | train: 18 | num_gpus: 8 19 | src_path: 20 | dst_path: 21 | tokens_per_batch: 30000 22 | max_length: 125 23 | num_epochs: 100 24 | num_steps: 300000 25 | save_freq: 1000 26 | show_freq: 1 27 | summary_freq: 100 28 | grads_clip: 0 29 | optimizer: 'adam_decay' 30 | learning_rate: 1 31 | warmup_steps: 4000 32 | label_smoothing: 0.1 33 | toleration: 10 34 | eval_on_dev: False 35 | dev: 36 | batch_size: 256 37 | src_path: 38 | ref_path: 39 | output_path: 40 | 41 | test: 42 | batch_size: 256 43 | max_target_length: 200 44 | lp_alpha: 0.6 45 | beam_size: 4 46 | num_gpus: 8 47 | 48 | set1: 49 | src_path: 50 | ref_path: 51 | output_path: 52 | cmd: 53 | -------------------------------------------------------------------------------- /configs/ende6.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | model: Transformer 3 | src_vocab: '../../parallel/ende/vocab.bpe.32000' 4 | dst_vocab: '../../parallel/ende/vocab.bpe.32000' 5 | src_vocab_size: 37008 6 | dst_vocab_size: 37008 7 | hidden_units: 512 8 | scale_embedding: True 9 | tie_embeddings: True 10 | tie_embedding_and_softmax: True 11 | attention_dropout_rate: 0.0 12 | residual_dropout_rate: 0.1 13 | num_blocks: 6 14 | num_heads: 8 15 | ff_activation: 'relu' 16 | model_dir: 'model-ende6' 17 | train: 18 | num_gpus: 8 19 | src_path: '../../parallel/ende/train.tok.clean.bpe.32000.en' 20 | dst_path: '../../parallel/ende/train.tok.clean.bpe.32000.de' 21 | tokens_per_batch: 30000 22 | max_length:
150 23 | num_epochs: 100 24 | num_steps: 100000 25 | save_freq: 1000 26 | show_freq: 1 27 | summary_freq: 100 28 | grads_clip: 0 29 | optimizer: 'adam_decay' 30 | learning_rate: 1 31 | warmup_steps: 4000 32 | label_smoothing: 0.1 33 | toleration: # Empty value denotes that we save model anyway 34 | eval_on_dev: True 35 | dev: 36 | batch_size: 256 37 | src_path: '../../parallel/ende/newstest2013.tok.bpe.32000.en' 38 | ref_path: '../../parallel/ende/newstest2013.tok.de' 39 | output_path: 'model-ende6/newstest2013.output' 40 | cmd: > 41 | perl -ple 's{{(\S)-(\S)}}{{$1 ##AT##-##AT## $2}}g' < {ref} > /tmp/ende.ref && 42 | perl -ple 's{{(\S)-(\S)}}{{$1 ##AT##-##AT## $2}}g' < {output} > /tmp/ende.output && 43 | perl multi-bleu.perl /tmp/ende.ref < /tmp/ende.output 2>/dev/null | awk '{{print($3)}}' | awk -F, '{{print $1}}' 44 | test: 45 | batch_size: 256 46 | max_target_length: 200 47 | lp_alpha: 0.6 48 | beam_size: 4 49 | num_gpus: 8 50 | # batch_size: 1 51 | # max_target_length: 200 52 | # beam_size: 1 53 | # num_gpus: 1 54 | # set_wmt13: 55 | # src_path: '../../parallel/ende/newstest2013.tok.bpe.32000.en' 56 | # ref_path: '../../parallel/ende/newstest2013.tok.de' 57 | # output_path: 'model-ende6-pt4/newstest2013.output' 58 | # cmd: > 59 | # perl -ple 's{{(\S)-(\S)}}{{$1 ##AT##-##AT## $2}}g' < {ref} > /tmp/ende.ref && 60 | # perl -ple 's{{(\S)-(\S)}}{{$1 ##AT##-##AT## $2}}g' < {output} > /tmp/ende.output && 61 | # perl multi-bleu.perl /tmp/ende.ref < /tmp/ende.output 2>/dev/null | awk '{{print($3)}}' | awk -F, '{{print $1}}' 62 | set_wmt14: 63 | src_path: '../../parallel/ende/newstest2014.tok.bpe.32000.en' 64 | dst_path: '../../parallel/ende/newstest2014.tok.bpe.32000.de' 65 | ref_path: '../../parallel/ende/newstest2014.tok.de' 66 | output_path: 'model-ende6-pt4/newstest2014.output' 67 | cmd: > 68 | perl -ple 's{{(\S)-(\S)}}{{$1 ##AT##-##AT## $2}}g' < {ref} > /tmp/ende.ref && 69 | perl -ple 's{{(\S)-(\S)}}{{$1 ##AT##-##AT## $2}}g' < {output} > /tmp/ende.output && 70 | perl multi-bleu.perl /tmp/ende.ref < /tmp/ende.output 2>/dev/null | awk '{{print($3)}}' | awk -F, '{{print $1}}' 71 | # set_train: 72 | # src_path: '../../parallel/ende/train.tok.clean.bpe.32000.en' 73 | # output_path: 'model-ende6/train.de' 74 | -------------------------------------------------------------------------------- /configs/news6.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | model: Transformer 3 | src_vocab: '../../parallel/news180w/zh.vocab' 4 | dst_vocab: '../../parallel/news180w/en.vocab' 5 | src_vocab_size: 9031 6 | dst_vocab_size: 34218 7 | hidden_units: 512 8 | scale_embedding: True 9 | num_shards: 1 10 | tie_embedding_and_softmax: True 11 | attention_dropout_rate: 0.0 12 | residual_dropout_rate: 0.1 13 | num_blocks: 6 14 | num_heads: 8 15 | ff_activation: 'relu' 16 | model_dir: 'model-news6' 17 | train: 18 | num_gpus: 8 19 | src_path: '../../parallel/news180w/zh.char.txt' 20 | dst_path: '../../parallel/news180w/en.bpe.txt' 21 | tokens_per_batch: 30000 22 | max_length: 125 23 | num_epochs: 100 24 | num_steps: 100000 25 | save_freq: 1000 26 | show_freq: 1 27 | summary_freq: 100 28 | grads_clip: 0 29 | optimizer: 'adam_decay' 30 | learning_rate: 1 31 | warmup_steps: 4000 32 | label_smoothing: 0.1 33 | toleration: 34 | eval_on_dev: True 35 | dev: 36 | batch_size: 256 37 | src_path: '../../parallel/test/nist02.char.txt' 38 | ref_path: '../../parallel/test/nist02.ref' 39 | output_path: 'model-news6/nist02.output' 40 | 41 | test: 42 | batch_size: 256 43 | 
max_target_length: 200 44 | lp_alpha: 0.6 45 | beam_size: 4 46 | num_gpus: 8 47 | 48 | set_nist02: 49 | src_path: '../../parallel/test/nist02.char.txt' 50 | ref_path: '../../parallel/test/nist02.ref' 51 | output_path: 'model-news6/nist02.output' 52 | set_nist03: 53 | src_path: '../../parallel/test/nist03.char.txt' 54 | ref_path: '../../parallel/test/nist03.ref' 55 | output_path: 'model-news6/nist03.output' 56 | set_nist04: 57 | src_path: '../../parallel/test/nist04.char.txt' 58 | ref_path: '../../parallel/test/nist04.ref' 59 | output_path: 'model-news6/nist04.output' 60 | set_nist05: 61 | src_path: '../../parallel/test/nist05.char.txt' 62 | ref_path: '../../parallel/test/nist05.ref' 63 | output_path: 'model-news6/nist05.output' 64 | 65 | # set_train: 66 | # src_path: '../../parallel/news180w/zh.char.txt' 67 | # output_path: 'model-news6/train.en' 68 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import codecs 4 | import commands 5 | import os 6 | import time 7 | import logging 8 | import tensorflow as tf 9 | import numpy as np 10 | from argparse import ArgumentParser 11 | from tempfile import mkstemp 12 | 13 | import yaml 14 | 15 | from models import * 16 | from utils import DataReader, AttrDict, expand_feed_dict 17 | 18 | 19 | def roll_back_to_previous_version(config): 20 | with tf.Graph().as_default(): 21 | with tf.Session() as sess: 22 | var_list = tf.contrib.framework.list_variables(config.model_dir) 23 | var_names, var_shapes = zip(*var_list) 24 | reader = tf.contrib.framework.load_checkpoint(config.model_dir) 25 | var_values = [reader.get_tensor(name) for name in var_names] 26 | new_var_list = [] 27 | for name, value in zip(var_names, var_values): 28 | if name == 'encoder/src_embedding/kernel': 29 | name = 'src_embedding' 30 | elif name == 'decoder/dst_embedding/kernel': 31 | name = 'dst_embedding' 32 | elif name == 'decoder/softmax/kernel': 33 | name = 'dst_softmax' 34 | new_var_list.append(tf.get_variable(name, initializer=value)) 35 | sess.run(tf.global_variables_initializer()) 36 | saver = tf.train.Saver(new_var_list) 37 | saver.save(sess, os.path.join(config.model_dir, 'new_version')) 38 | config.num_shards = 1 39 | 40 | 41 | class Evaluator(object): 42 | """ 43 | Evaluate the model. 44 | """ 45 | def __init__(self): 46 | pass 47 | 48 | def init_from_config(self, config): 49 | self.model = eval(config.model)(config, config.test.num_gpus) 50 | self.model.build_test_model() 51 | 52 | sess_config = tf.ConfigProto() 53 | sess_config.gpu_options.allow_growth = True 54 | sess_config.allow_soft_placement = True 55 | self.sess = tf.Session(config=sess_config) 56 | 57 | # Restore model. 58 | try: 59 | tf.train.Saver().restore(self.sess, tf.train.latest_checkpoint(config.model_dir)) 60 | except tf.errors.NotFoundError: 61 | roll_back_to_previous_version(config) 62 | tf.train.Saver().restore(self.sess, tf.train.latest_checkpoint(config.model_dir)) 63 | 64 | self.data_reader = DataReader(config) 65 | 66 | def init_from_frozen_graphdef(self, config): 67 | frozen_graph_path = os.path.join(config.model_dir, 'frozen_graph.pb') 68 | # If the file doesn't existed, create it. 
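        # Two code paths: on the first run, when frozen_graph.pb does not exist yet,
        # fall back to init_from_config(), re-save the variables to a temporary
        # checkpoint and freeze them into a single GraphDef that keeps only the
        # 'loss_sum' and 'predictions' nodes; on later runs, load frozen_graph.pb
        # directly and look the placeholders up by name ('import/src_pl_*',
        # 'import/dst_pl_*') instead of rebuilding the model.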
69 | if not os.path.exists(frozen_graph_path): 70 | logging.warning('The frozen graph does not existed, use \'init_from_config\' instead' 71 | 'and create a frozen graph for next use.') 72 | self.init_from_config(config) 73 | saver = tf.train.Saver() 74 | save_dir = '/tmp/graph-{}'.format(os.getpid()) 75 | os.mkdir(save_dir) 76 | save_path = '{}/ckpt'.format(save_dir) 77 | saver.save(sess=self.sess, save_path=save_path) 78 | 79 | with tf.Session(graph=tf.Graph()) as sess: 80 | clear_devices = True 81 | output_node_names = ['loss_sum', 'predictions'] 82 | # We import the meta graph in the current default Graph 83 | saver = tf.train.import_meta_graph(save_path + '.meta', clear_devices=clear_devices) 84 | 85 | # We restore the weights 86 | saver.restore(sess, save_path) 87 | 88 | # We use a built-in TF helper to export variables to constants 89 | output_graph_def = tf.graph_util.convert_variables_to_constants( 90 | sess, # The session is used to retrieve the weights 91 | tf.get_default_graph().as_graph_def(), # The graph_def is used to retrieve the nodes 92 | output_node_names # The output node names are used to select the useful nodes 93 | ) 94 | 95 | # Finally we serialize and dump the output graph to the filesystem 96 | with tf.gfile.GFile(frozen_graph_path, "wb") as f: 97 | f.write(output_graph_def.SerializeToString()) 98 | logging.info("%d ops in the final graph." % len(output_graph_def.node)) 99 | 100 | # Remove temp files. 101 | os.system('rm -rf ' + save_dir) 102 | else: 103 | sess_config = tf.ConfigProto() 104 | sess_config.gpu_options.allow_growth = True 105 | sess_config.allow_soft_placement = True 106 | self.sess = tf.Session(config=sess_config) 107 | self.data_reader = DataReader(config) 108 | 109 | # We load the protobuf file from the disk and parse it to retrieve the 110 | # unserialized graph_def 111 | with tf.gfile.GFile(frozen_graph_path, "rb") as f: 112 | graph_def = tf.GraphDef() 113 | graph_def.ParseFromString(f.read()) 114 | 115 | # Import the graph_def into current the default graph. 116 | tf.import_graph_def(graph_def) 117 | graph = tf.get_default_graph() 118 | self.model = AttrDict() 119 | 120 | def collect_placeholders(prefix): 121 | ret = [] 122 | idx = 0 123 | while True: 124 | try: 125 | ret.append(graph.get_tensor_by_name('import/{}_{}:0'.format(prefix, idx))) 126 | idx += 1 127 | except KeyError: 128 | return tuple(ret) 129 | 130 | self.model['src_pls'] = collect_placeholders('src_pl') 131 | self.model['dst_pls'] = collect_placeholders('dst_pl') 132 | self.model['predictions'] = graph.get_tensor_by_name('import/predictions:0') 133 | 134 | def init_from_existed(self, model, sess, data_reader): 135 | self.sess = sess 136 | self.model = model 137 | self.data_reader = data_reader 138 | 139 | def beam_search(self, X): 140 | return self.sess.run(self.model.predictions, feed_dict=expand_feed_dict({self.model.src_pls: X})) 141 | 142 | def loss(self, X, Y): 143 | return self.sess.run(self.model.loss_sum, feed_dict=expand_feed_dict({self.model.src_pls: X, self.model.dst_pls: Y})) 144 | 145 | def translate(self, src_path, output_path, batch_size): 146 | logging.info('Translate %s.' 
% src_path) 147 | _, tmp = mkstemp() 148 | fd = codecs.open(tmp, 'w', 'utf8') 149 | count = 0 150 | token_count = 0 151 | epsilon = 1e-6 152 | start = time.time() 153 | for X in self.data_reader.get_test_batches(src_path, batch_size): 154 | Y = self.beam_search(X) 155 | Y = Y[:len(X)] 156 | sents = self.data_reader.indices_to_words(Y) 157 | assert len(X) == len(sents) 158 | for sent in sents: 159 | print(sent, file=fd) 160 | count += len(X) 161 | token_count += np.sum(np.not_equal(Y, 3)) # 3: 162 | time_span = time.time() - start 163 | logging.info('{0} sentences ({1} tokens) processed in {2:.2f} minutes (speed: {3:.4f} sec/token).'. 164 | format(count, token_count, time_span / 60, time_span / (token_count + epsilon))) 165 | fd.close() 166 | # Remove BPE flag, if have. 167 | os.system("sed -r 's/(@@ )|(@@ ?$)//g' %s > %s" % (tmp, output_path)) 168 | os.remove(tmp) 169 | logging.info('The result file was saved in %s.' % output_path) 170 | 171 | def ppl(self, src_path, dst_path, batch_size): 172 | logging.info('Calculate PPL for %s and %s.' % (src_path, dst_path)) 173 | token_count = 0 174 | loss_sum = 0 175 | for batch in self.data_reader.get_test_batches_with_target(src_path, dst_path, batch_size): 176 | X, Y = batch 177 | loss_sum += self.loss(X, Y) 178 | token_count += np.sum(np.greater(Y, 0)) 179 | # Compute PPL 180 | ppl = np.exp(loss_sum / token_count) 181 | logging.info('PPL: %.4f' % ppl) 182 | return ppl 183 | 184 | def evaluate(self, batch_size, **kargs): 185 | """Evaluate the model on dev set.""" 186 | src_path = kargs['src_path'] 187 | output_path = kargs['output_path'] 188 | cmd = kargs['cmd'] if 'cmd' in kargs else\ 189 | "perl multi-bleu.perl {ref} < {output} 2>/dev/null | awk '{{print($3)}}' | awk -F, '{{print $1}}'" 190 | cmd = cmd.strip() 191 | logging.info('Evaluation command: ' + cmd) 192 | self.translate(src_path, output_path, batch_size) 193 | bleu = None 194 | if 'ref_path' in kargs: 195 | ref_path = kargs['ref_path'] 196 | try: 197 | bleu = commands.getoutput(cmd.format(**{'ref': ref_path, 'output': output_path})) 198 | bleu = float(bleu) 199 | except ValueError, e: 200 | logging.warning('An error raised when calculate BLEU: {}'.format(e)) 201 | bleu = 0 202 | logging.info('BLEU: {}'.format(bleu)) 203 | if 'dst_path' in kargs: 204 | self.ppl(src_path, kargs['dst_path'], batch_size) 205 | return bleu 206 | 207 | 208 | if __name__ == '__main__': 209 | parser = ArgumentParser() 210 | parser.add_argument('-c', '--config', dest='config') 211 | args = parser.parse_args() 212 | # Read config 213 | config = AttrDict(yaml.load(open(args.config))) 214 | # Logger 215 | logging.basicConfig(level=logging.INFO) 216 | evaluator = Evaluator() 217 | if config.test.frozen: 218 | evaluator.init_from_frozen_graphdef(config) 219 | else: 220 | evaluator.init_from_config(config) 221 | for attr in config.test: 222 | if attr.startswith('set'): 223 | evaluator.evaluate(config.test.batch_size, **config.test[attr]) 224 | logging.info("Done") 225 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from transformer import Transformer 2 | from deeprnn import DeepRNN 3 | from rnnsearch import RNNSearch 4 | from parallel import PTransformer 5 | -------------------------------------------------------------------------------- /models/deeprnn.py: -------------------------------------------------------------------------------- 1 | from rnnsearch import * 2 | 3 | 4 | 
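# Minimal illustrative sketch (not called anywhere in this module) of the block
# pattern DeepRNN.encoder_impl applies below: layer norm -> GRU -> dropout ->
# residual add, repeated from block 2 onward. It assumes `tf`, `GRUCell` and
# `common_layers` are provided by the wildcard import from `rnnsearch` above,
# exactly as the real implementation does.
def _residual_gru_stack_sketch(x, sequence_lengths, num_blocks, hidden_units,
                               dropout_rate, is_training):
    for i in xrange(2, num_blocks):
        x_ln = common_layers.layer_norm(x, name='LN_sketch_%d' % i)
        cell = GRUCell(num_units=hidden_units, reuse=tf.AUTO_REUSE,
                       name='cell_sketch_%d' % i)
        # One recurrent block over the normalized input.
        y, _ = tf.nn.dynamic_rnn(cell, x_ln, sequence_length=sequence_lengths,
                                 dtype=tf.float32, scope='rnn_sketch_%d' % i)
        y = tf.layers.dropout(y, rate=dropout_rate, training=is_training)
        # Residual connection around the block.
        x = y + x_ln
    return x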
class DeepRNN(RNNSearch): 5 | def __init__(self, *args, **kargs): 6 | super(DeepRNN, self).__init__(*args, **kargs) 7 | 8 | def encoder_impl(self, encoder_input, is_training): 9 | residual_dropout_rate = self._config.residual_dropout_rate if is_training else 0.0 10 | 11 | # Mask 12 | encoder_mask = tf.to_int32(tf.not_equal(encoder_input, 0)) 13 | sequence_lengths = tf.reduce_sum(encoder_mask, axis=1) 14 | 15 | # Embedding 16 | encoder_output = embedding(encoder_input, 17 | vocab_size=self._config.src_vocab_size, 18 | dense_size=self._config.hidden_units, 19 | kernel=self._src_embedding, 20 | multiplier=self._config.hidden_units ** 0.5 if self._config.scale_embedding else 1.0, 21 | name="src_embedding") 22 | 23 | # Dropout 24 | # encoder_output = tf.layers.dropout(encoder_output, rate=residual_dropout_rate, training=is_training) 25 | encoder_output = common_layers.layer_norm(encoder_output, name='LN_0') 26 | 27 | # Bi-directional RNN 28 | cell_fw = GRUCell(num_units=self._config.hidden_units, 29 | name='fw_cell_0') 30 | cell_bw = GRUCell(num_units=self._config.hidden_units, 31 | name='bw_cell_1') 32 | 33 | encoder_outputs, _ = tf.nn.bidirectional_dynamic_rnn( 34 | cell_fw=cell_fw, cell_bw=cell_bw, 35 | inputs=encoder_output, 36 | sequence_length=sequence_lengths, 37 | dtype=tf.float32 38 | ) 39 | 40 | encoder_output = tf.concat(encoder_outputs, axis=2) 41 | encoder_output = dense(encoder_output, output_size=self._config.hidden_units) 42 | encoder_output = tf.layers.dropout(encoder_output, rate=residual_dropout_rate, training=is_training) 43 | 44 | for i in xrange(2, self._config.num_blocks): 45 | encoder_output = common_layers.layer_norm(encoder_output, name='LN_%d' % i) 46 | 47 | cell = GRUCell(num_units=self._config.hidden_units, 48 | reuse=tf.AUTO_REUSE, name='cell_%s' % i) 49 | 50 | encoder_output_, _ = tf.nn.dynamic_rnn(cell, encoder_output, 51 | sequence_length=sequence_lengths, 52 | dtype=tf.float32, scope='rnn_%d' % i) 53 | encoder_output_ = tf.layers.dropout(encoder_output_, rate=residual_dropout_rate, training=is_training) 54 | 55 | if i >= 2: 56 | encoder_output = encoder_output_ + encoder_output 57 | else: 58 | encoder_output = encoder_output_ 59 | 60 | encoder_output = common_layers.layer_norm(encoder_output, name='LN_%d' % self._config.num_blocks) 61 | # Mask 62 | encoder_output *= tf.expand_dims(tf.to_float(encoder_mask), axis=-1) 63 | 64 | return encoder_output 65 | 66 | def decoder_impl(self, decoder_input, encoder_output, is_training): 67 | 68 | residual_dropout_rate = self._config.residual_dropout_rate if is_training else 0.0 69 | 70 | attention_bias = tf.equal(tf.reduce_sum(tf.abs(encoder_output), axis=-1, keepdims=True), 0.0) 71 | attention_bias = tf.to_float(attention_bias) * (- 1e9) 72 | 73 | decoder_output = embedding(decoder_input, 74 | vocab_size=self._config.dst_vocab_size, 75 | dense_size=self._config.hidden_units, 76 | kernel=self._dst_embedding, 77 | multiplier=self._config.hidden_units ** 0.5 if self._config.scale_embedding else 1.0, 78 | name="dst_embedding") 79 | decoder_output = tf.layers.dropout(decoder_output, rate=residual_dropout_rate, training=is_training) 80 | 81 | for i in xrange(self._config.num_blocks): 82 | decoder_output = common_layers.layer_norm(decoder_output, name='LN_%d' % i) 83 | 84 | if i % 3 == 1: 85 | cell = AttentionGRUCell(num_units=self._config.hidden_units, 86 | attention_memories=encoder_output, 87 | attention_bias=attention_bias, 88 | reuse=tf.AUTO_REUSE, 89 | name='cell_%s' % i) 90 | else: 91 | cell = 
GRUCell(num_units=self._config.hidden_units, 92 | reuse=tf.AUTO_REUSE, 93 | name='cell_%s' % i) 94 | 95 | decoder_output_, _ = tf.nn.dynamic_rnn(cell=cell, inputs=decoder_output, dtype=tf.float32, 96 | scope='rnn_%d' % i) 97 | decoder_output_ = tf.layers.dropout(decoder_output_, rate=residual_dropout_rate, training=is_training) 98 | if i >= 2: 99 | decoder_output = decoder_output_ + decoder_output 100 | else: 101 | decoder_output = decoder_output_ 102 | 103 | decoder_output = common_layers.layer_norm(decoder_output, name='LN_%d' % self._config.num_blocks) 104 | return decoder_output 105 | 106 | def decoder_with_caching_impl(self, decoder_input, decoder_cache, encoder_output, is_training): 107 | residual_dropout_rate = self._config.residual_dropout_rate if is_training else 0.0 108 | 109 | attention_bias = tf.equal(tf.reduce_sum(tf.abs(encoder_output), axis=-1, keepdims=True), 0.0) 110 | attention_bias = tf.to_float(attention_bias) * (- 1e9) 111 | 112 | decoder_input = decoder_input[:, -1] 113 | 114 | decoder_output = embedding(decoder_input, 115 | vocab_size=self._config.dst_vocab_size, 116 | dense_size=self._config.hidden_units, 117 | kernel=self._dst_embedding, 118 | multiplier=self._config.hidden_units ** 0.5 if self._config.scale_embedding else 1.0, 119 | name="dst_embedding") 120 | decoder_output = tf.layers.dropout(decoder_output, rate=residual_dropout_rate, training=is_training) 121 | 122 | decoder_cache = \ 123 | tf.cond(tf.equal(tf.shape(decoder_cache)[1], 0), 124 | lambda: tf.zeros([tf.shape(decoder_input)[0], 125 | 1, 126 | self._config.num_blocks, 127 | self._config.hidden_units]), 128 | lambda: decoder_cache) 129 | # Unstack cache 130 | states = tf.unstack(decoder_cache[:, -1, :, :], num=self._config.num_blocks, axis=1) 131 | new_states = [] 132 | for i in xrange(self._config.num_blocks): 133 | decoder_output = common_layers.layer_norm(decoder_output, name='LN_%s' % i) 134 | if i % 3 == 1: 135 | cell = AttentionGRUCell(num_units=self._config.hidden_units, 136 | attention_memories=encoder_output, 137 | attention_bias=attention_bias, 138 | reuse=tf.AUTO_REUSE, 139 | name='cell_%s' % i) 140 | else: 141 | cell = GRUCell(num_units=self._config.hidden_units, 142 | reuse=tf.AUTO_REUSE, 143 | name='cell_%s' % i) 144 | 145 | with tf.variable_scope('rnn_%s' % i): 146 | decoder_output_, decoder_state = cell(decoder_output, states[i]) 147 | 148 | # if i % 3 == 1: 149 | # # We can log attention weights here. 
150 | # cell.get_attention_weights() 151 | 152 | decoder_output_ = tf.layers.dropout(decoder_output_, rate=residual_dropout_rate, training=is_training) 153 | if i >= 2: 154 | decoder_output = decoder_output_ + decoder_output 155 | else: 156 | decoder_output = decoder_output_ 157 | new_states.append(decoder_state) 158 | 159 | decoder_state = tf.stack(new_states, axis=1)[:, None, :, :] 160 | decoder_output = common_layers.layer_norm(decoder_output, name='LN_%d' % self._config.num_blocks) 161 | 162 | return decoder_output[:, None, :], decoder_state 163 | -------------------------------------------------------------------------------- /models/indrnn.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | from rnnsearch import RNNSearch 3 | 4 | 5 | class DeepRNN(RNNSearch): 6 | def __init__(self, *args, **kargs): 7 | super(DeepRNN, self).__init__(*args, **kargs) 8 | 9 | def encoder_impl(self, encoder_input, is_training): 10 | # residual_dropout_rate = self._config.residual_dropout_rate if is_training else 0.0 11 | 12 | # Mask 13 | encoder_mask = tf.to_int32(tf.not_equal(encoder_input, 0)) 14 | sequence_lengths = tf.reduce_sum(encoder_mask, axis=1) 15 | 16 | recurrent_initializer = tf.random_uniform_initializer(0.0, 1.0) 17 | 18 | # Embedding 19 | encoder_output = embedding(encoder_input, 20 | vocab_size=self._config.src_vocab_size, 21 | dense_size=self._config.hidden_units, 22 | kernel=self._src_embedding, 23 | multiplier=self._config.hidden_units ** 0.5 if self._config.scale_embedding else 1.0, 24 | name="src_embedding") 25 | 26 | # Dropout 27 | # encoder_output = tf.layers.dropout(encoder_output, rate=residual_dropout_rate, training=is_training) 28 | encoder_output = tf.layers.batch_normalization(encoder_output, training=is_training, name='BN_0') 29 | 30 | # Bi-directional RNN 31 | cell_fw = IndRNNCell(num_units=self._config.hidden_units, 32 | recurrent_initializer=recurrent_initializer, 33 | name='fw_cell_0') 34 | cell_bw = IndRNNCell(num_units=self._config.hidden_units, 35 | recurrent_initializer=recurrent_initializer, 36 | name='bw_cell_1') 37 | 38 | encoder_outputs, _ = tf.nn.bidirectional_dynamic_rnn( 39 | cell_fw=cell_fw, cell_bw=cell_bw, 40 | inputs=encoder_output, 41 | sequence_length=sequence_lengths, 42 | dtype=tf.float32 43 | ) 44 | 45 | encoder_output = tf.concat(encoder_outputs, axis=2) 46 | encoder_output = dense(encoder_output, output_size=self._config.hidden_units) 47 | # encoder_output = tf.layers.dropout(encoder_output, rate=residual_dropout_rate, training=is_training) 48 | 49 | for i in xrange(2, self._config.num_blocks): 50 | encoder_output = tf.layers.batch_normalization(encoder_output, training=is_training, name='BN_%d' % i) 51 | 52 | cell = IndRNNCell(num_units=self._config.hidden_units, 53 | recurrent_initializer=recurrent_initializer, 54 | reuse=tf.AUTO_REUSE, name='cell_%s' % i) 55 | 56 | encoder_output, _ = tf.nn.dynamic_rnn(cell, encoder_output, 57 | sequence_length=sequence_lengths, 58 | dtype=tf.float32, scope='rnn_%d' % i) 59 | # encoder_output = tf.layers.dropout(encoder_output, rate=residual_dropout_rate, training=is_training) 60 | 61 | encoder_output = tf.layers.batch_normalization(encoder_output, 62 | training=is_training, 63 | name='BN_%d' % self._config.num_blocks) 64 | # Mask 65 | encoder_output *= tf.expand_dims(tf.to_float(encoder_mask), axis=-1) 66 | 67 | return encoder_output 68 | 69 | def decoder_impl(self, decoder_input, encoder_output, is_training): 70 | 71 | # residual_dropout_rate = 
self._config.residual_dropout_rate if is_training else 0.0 72 | 73 | attention_bias = tf.equal(tf.reduce_sum(tf.abs(encoder_output), axis=-1, keepdims=True), 0.0) 74 | attention_bias = tf.to_float(attention_bias) * (- 1e9) 75 | 76 | recurrent_initializer = tf.random_uniform_initializer(0.0, 1.0) 77 | 78 | decoder_output = embedding(decoder_input, 79 | vocab_size=self._config.dst_vocab_size, 80 | dense_size=self._config.hidden_units, 81 | kernel=self._dst_embedding, 82 | multiplier=self._config.hidden_units ** 0.5 if self._config.scale_embedding else 1.0, 83 | name="dst_embedding") 84 | # decoder_output = tf.layers.dropout(decoder_output, rate=residual_dropout_rate, training=is_training) 85 | 86 | for i in xrange(self._config.num_blocks): 87 | decoder_output = tf.layers.batch_normalization(decoder_output, training=is_training, name='BN_%d' % i) 88 | 89 | if i % 3 == 0: 90 | cell = AttentionIndRNNCell(num_units=self._config.hidden_units, 91 | attention_memories=encoder_output, 92 | attention_bias=attention_bias, 93 | recurrent_initializer=recurrent_initializer, 94 | reuse=tf.AUTO_REUSE, 95 | name='cell_%s' % i) 96 | else: 97 | cell = IndRNNCell(num_units=self._config.hidden_units, 98 | recurrent_initializer=recurrent_initializer, 99 | reuse=tf.AUTO_REUSE, 100 | name='cell_%s' % i) 101 | 102 | decoder_output, _ = tf.nn.dynamic_rnn(cell=cell, inputs=decoder_output, dtype=tf.float32, scope='rnn_%d' % i) 103 | # decoder_output = tf.layers.dropout(decoder_output, rate=residual_dropout_rate, training=is_training) 104 | decoder_output = tf.layers.batch_normalization(decoder_output, 105 | training=is_training, 106 | name='BN_%d' % self._config.num_blocks) 107 | return decoder_output 108 | 109 | def decoder_with_caching_impl(self, decoder_input, decoder_cache, encoder_output, is_training): 110 | # residual_dropout_rate = self._config.residual_dropout_rate if is_training else 0.0 111 | 112 | attention_bias = tf.equal(tf.reduce_sum(tf.abs(encoder_output), axis=-1, keepdims=True), 0.0) 113 | attention_bias = tf.to_float(attention_bias) * (- 1e9) 114 | 115 | recurrent_initializer = tf.random_uniform_initializer(0.0, 1.0) 116 | 117 | decoder_input = decoder_input[:, -1] 118 | 119 | decoder_output = embedding(decoder_input, 120 | vocab_size=self._config.dst_vocab_size, 121 | dense_size=self._config.hidden_units, 122 | kernel=self._dst_embedding, 123 | multiplier=self._config.hidden_units ** 0.5 if self._config.scale_embedding else 1.0, 124 | name="dst_embedding") 125 | # decoder_output = tf.layers.dropout(decoder_output, rate=residual_dropout_rate, training=is_training) 126 | 127 | decoder_cache = \ 128 | tf.cond(tf.equal(tf.shape(decoder_cache)[1], 0), 129 | lambda: tf.zeros([tf.shape(decoder_input)[0], 130 | 1, 131 | self._config.num_blocks, 132 | self._config.hidden_units]), 133 | lambda: decoder_cache) 134 | # Unstack cache 135 | states = tf.unstack(decoder_cache[:, -1, :, :], num=self._config.num_blocks, axis=1) 136 | new_states = [] 137 | for i in xrange(self._config.num_blocks): 138 | decoder_output = tf.layers.batch_normalization(decoder_output, training=is_training, name='BN_%s' % i) 139 | if i % 3 == 0: 140 | cell = AttentionIndRNNCell(num_units=self._config.hidden_units, 141 | attention_memories=encoder_output, 142 | attention_bias=attention_bias, 143 | recurrent_initializer=recurrent_initializer, 144 | reuse=tf.AUTO_REUSE, 145 | name='cell_%s' % i) 146 | else: 147 | cell = IndRNNCell(num_units=self._config.hidden_units, 148 | recurrent_initializer=recurrent_initializer, 149 | 
reuse=tf.AUTO_REUSE, 150 | name='cell_%s' % i) 151 | 152 | with tf.variable_scope('rnn_%s' % i): 153 | decoder_output, decoder_state = cell(decoder_output, states[i]) 154 | 155 | # if i % 3 == 0: 156 | # # We can log attention weights here. 157 | # cell.get_attention_weights() 158 | 159 | # decoder_output = tf.layers.dropout(decoder_output, rate=residual_dropout_rate, training=is_training) 160 | new_states.append(decoder_state) 161 | 162 | decoder_state = tf.stack(new_states, axis=1)[:, None, :, :] 163 | decoder_output = tf.layers.batch_normalization(decoder_output, 164 | training=is_training, 165 | name='BN_%d' % self._config.num_blocks) 166 | 167 | return decoder_output[:, None, :], decoder_state 168 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import defaultdict 3 | from tensorflow.python.ops import init_ops 4 | from utils import * 5 | 6 | 7 | class Model(object): 8 | def __init__(self, config, num_gpus): 9 | self._config = config 10 | 11 | self._devices = ['/gpu:%d' % i for i in range(num_gpus)] if num_gpus > 0 else ['/cpu:0'] 12 | 13 | # Placeholders and saver. 14 | src_pls = [] 15 | dst_pls = [] 16 | for i, device in enumerate(self._devices): 17 | with tf.device(device): 18 | src_pls.append(tf.placeholder(dtype=tf.int32, shape=[None, None], name='src_pl_{}'.format(i))) 19 | dst_pls.append(tf.placeholder(dtype=tf.int32, shape=[None, None], name='dst_pl_{}'.format(i))) 20 | self.src_pls = tuple(src_pls) 21 | self.dst_pls = tuple(dst_pls) 22 | 23 | self.encoder_scope = self._config.encoder_scope or 'encoder' 24 | self.decoder_scope = self._config.decoder_scope or 'decoder' 25 | 26 | self.losses = defaultdict(list) # self.losses[name][device] 27 | self.grads_and_vars = defaultdict(list) # self.grads_and_vars[name][device] 28 | 29 | # Uniform scaling initializer. 
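        # With scale=1.0, mode='fan_avg' and a uniform distribution, this is
        # equivalent to Glorot (Xavier) uniform initialization.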
30 | self._initializer = init_ops.variance_scaling_initializer(scale=1.0, mode='fan_avg', distribution='uniform') 31 | 32 | self.prepare_shared_weights() 33 | 34 | self._use_cache = True 35 | self._use_daisy_chain_getter = True 36 | 37 | def prepare_shared_weights(self): 38 | 39 | partitions = self._config.num_shards or 16 40 | 41 | def get_weights(name, shape): 42 | vocab_size, hidden_size = shape 43 | if partitions > 1: 44 | inter_points = np.linspace(0, vocab_size, partitions + 1, dtype=np.int) 45 | parts = [] 46 | pre_point = 0 47 | for i, p in enumerate(inter_points[1:]): 48 | parts.append(tf.get_variable(name=name + '_' + str(i), 49 | shape=[p - pre_point, hidden_size])) 50 | pre_point = p 51 | return common_layers.eu.ConvertGradientToTensor(tf.concat(parts, 0, name)) 52 | else: 53 | return tf.get_variable(name=name, shape=shape) 54 | 55 | src_embedding = get_weights('src_embedding', 56 | shape=[self._config.src_vocab_size, self._config.hidden_units]) 57 | if self._config.tie_embeddings: 58 | assert self._config.src_vocab_size == self._config.dst_vocab_size and \ 59 | self._config.src_vocab == self._config.dst_vocab 60 | dst_embedding = src_embedding 61 | else: 62 | dst_embedding = get_weights('dst_embedding', 63 | shape=[self._config.dst_vocab_size, self._config.hidden_units]) 64 | 65 | if self._config.tie_embedding_and_softmax: 66 | dst_softmax = dst_embedding 67 | else: 68 | dst_softmax = get_weights('dst_softmax', 69 | shape=[self._config.dst_vocab_size, self._config.hidden_units]) 70 | 71 | self._src_embedding = src_embedding 72 | self._dst_embedding = dst_embedding 73 | self._dst_softmax = dst_softmax 74 | 75 | def prepare_training(self): 76 | # Optimizer 77 | self.global_step = tf.get_variable(name='global_step', dtype=tf.int64, shape=[], 78 | trainable=False, initializer=tf.zeros_initializer) 79 | 80 | self.learning_rate = tf.convert_to_tensor(self._config.train.learning_rate, dtype=tf.float32) 81 | 82 | if self._config.train.optimizer == 'adam': 83 | self._optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate) 84 | elif self._config.train.optimizer == 'adam_decay': 85 | self.learning_rate *= learning_rate_decay(self._config, self.global_step) 86 | self._optimizer = tf.train.AdamOptimizer( 87 | learning_rate=self.learning_rate, beta1=0.9, beta2=0.98, epsilon=1e-9) 88 | elif self._config.train.optimizer == 'sgd': 89 | self._optimizer = tf.train.GradientDescentOptimizer(learning_rate=self.learning_rate) 90 | elif self._config.train.optimizer == 'mom': 91 | self._optimizer = tf.train.MomentumOptimizer(self.learning_rate, momentum=0.9) 92 | else: 93 | raise Exception('Unknown optimizer: {}.'.format(self._config.train.optimizer)) 94 | 95 | tf.summary.scalar('learning_rate', self.learning_rate) 96 | 97 | def build_train_model(self, test=True, teacher_model=None, reuse=None): 98 | """Build model for training. """ 99 | logging.info('Build train model.') 100 | self.prepare_training() 101 | 102 | cache = {} 103 | load = dict([(d, 0) for d in self._devices]) 104 | for i, (X, Y, device) in enumerate(zip(self.src_pls, self.dst_pls, self._devices)): 105 | 106 | def daisy_chain_getter(getter, name, *args, **kwargs): 107 | """Get a variable and cache in a daisy chain.""" 108 | device_var_key = (device, name) 109 | if device_var_key in cache: 110 | # if we have the variable on the correct device, return it. 
111 | return cache[device_var_key] 112 | if name in cache: 113 | # if we have it on a different device, copy it from the last device 114 | v = tf.identity(cache[name]) 115 | else: 116 | var = getter(name, *args, **kwargs) 117 | v = tf.identity(var._ref()) # pylint: disable=protected-access 118 | # update the cache 119 | cache[name] = v 120 | cache[device_var_key] = v 121 | return v 122 | 123 | def balanced_device_setter(op): 124 | """Balance variables to all devices.""" 125 | if op.type in {'Variable', 'VariableV2', 'VarHandleOp'}: 126 | # return self._sync_device 127 | min_load = min(load.values()) 128 | min_load_devices = [d for d in load if load[d] == min_load] 129 | chosen_device = random.choice(min_load_devices) 130 | load[chosen_device] += op.outputs[0].get_shape().num_elements() 131 | return chosen_device 132 | return device 133 | 134 | device_setter = balanced_device_setter 135 | custom_getter = daisy_chain_getter if self._use_daisy_chain_getter else None 136 | 137 | with tf.variable_scope(tf.get_variable_scope(), 138 | initializer=self._initializer, 139 | custom_getter=custom_getter, 140 | reuse=reuse): 141 | with tf.device(device_setter): 142 | logging.info('Build model on %s.' % device) 143 | encoder_output = self.encoder(X, is_training=True, reuse=i > 0 or None) 144 | decoder_output = self.decoder(shift_right(Y), encoder_output, is_training=True, reuse=i > 0 or None) 145 | if teacher_model is not None: 146 | with tf.variable_scope('teacher'): 147 | teacher_encoder_output = teacher_model.encoder(X, is_training=False, reuse=i > 0 or None) 148 | teacher_decoder_output = teacher_model.decoder(shift_right(Y), 149 | teacher_encoder_output, 150 | is_training=False, 151 | reuse=i > 0 or None) 152 | _, teacher_probs = teacher_model.test_loss(teacher_decoder_output, 153 | Y, 154 | reuse=i > 0 or None) 155 | self.train_output(decoder_output, Y, teacher_probs=teacher_probs, reuse=i > 0 or None) 156 | 157 | else: 158 | self.train_output(decoder_output, Y, teacher_probs=None, reuse=i > 0 or None) 159 | 160 | self.summary_op = tf.summary.merge_all() 161 | 162 | # We may want to test the model during training. 163 | if test: 164 | self.build_test_model(reuse=True) 165 | 166 | def build_test_model(self, reuse=None): 167 | """Build model for inference.""" 168 | logging.info('Build test model.') 169 | with tf.variable_scope(tf.get_variable_scope(), reuse=reuse): 170 | preds_list = [] 171 | loss_sum = 0 172 | for i, (X, Y, device) in enumerate(zip(self.src_pls, self.dst_pls, self._devices)): 173 | with tf.device(device): 174 | logging.info('Build model on %s.' % device) 175 | dec_input = shift_right(Y) 176 | 177 | # Avoid errors caused by empty input by a condition phrase. 
178 | enc_output = self.encoder(X, is_training=False, reuse=i > 0 or None) 179 | preds = self.beam_search(enc_output, use_cache=self._use_cache, reuse=i > 0 or None) 180 | dec_output = self.decoder(dec_input, enc_output, is_training=False, reuse=True) 181 | loss, _ = self.test_loss(dec_output, Y, reuse=True) 182 | 183 | loss_sum += loss 184 | preds_list.append(preds) 185 | 186 | max_length = tf.reduce_max([tf.shape(pred)[1] for pred in preds_list]) 187 | 188 | def pad_to_max_length(input, length): 189 | """Pad the input (with rank 2) with 3() to the given length in the second axis.""" 190 | shape = tf.shape(input) 191 | padding = tf.ones([shape[0], length - shape[1]], dtype=tf.int32) * 3 192 | return tf.concat([input, padding], axis=1) 193 | 194 | preds_list = [pad_to_max_length(pred, max_length) for pred in preds_list] 195 | self.predictions = tf.concat(preds_list, axis=0, name='predictions') 196 | self.loss_sum = tf.identity(loss_sum, name='loss_sum') 197 | 198 | def register_loss(self, name, loss): 199 | self.losses[name].append(loss) 200 | # Filter out variables of the teacher model. 201 | vars = [v for v in tf.trainable_variables() if not v.name.startswith('teacher')] 202 | grads_and_vars = self._optimizer.compute_gradients(loss, vars) 203 | grads_and_vars_not_none = [] 204 | for g, v in grads_and_vars: 205 | # Avoid exception when g is None. 206 | if g is None: 207 | logging.warning('Gradient of {} to {} is None.'.format(name, v.name)) 208 | else: 209 | grads_and_vars_not_none.append((g, v)) 210 | self.grads_and_vars[name].append(grads_and_vars_not_none) 211 | 212 | if not tf.get_variable_scope().reuse: 213 | grads_norm = tf.global_norm([gv[0] for gv in grads_and_vars_not_none]) 214 | tf.summary.scalar(name.format(name), loss) 215 | tf.summary.scalar('{}_grads_norm'.format(name), grads_norm) 216 | 217 | def get_train_op(self, increase_global_step=True, name=None): 218 | global_step = self.global_step if increase_global_step else None 219 | if name: 220 | avg_loss = tf.reduce_mean(self.losses[name]) 221 | grads_and_vars_list = self.grads_and_vars[name] 222 | grads_and_vars = average_gradients(grads_and_vars_list) 223 | else: 224 | summed_grads_and_vars = {} 225 | avg_loss = 0 226 | for name in self.losses: 227 | avg_loss += tf.reduce_mean(self.losses[name]) 228 | grads_and_vars_list = self.grads_and_vars[name] 229 | grads_and_vars = average_gradients(grads_and_vars_list) 230 | for g, v in grads_and_vars: 231 | if v in summed_grads_and_vars: 232 | summed_grads_and_vars[v] += g 233 | else: 234 | summed_grads_and_vars[v] = g 235 | summed_grads_and_vars = [(summed_grads_and_vars[v], v) for v in summed_grads_and_vars] 236 | grads_and_vars = summed_grads_and_vars 237 | 238 | # Gradients clipping 239 | if self._config.train.grads_clip: 240 | grads, _ = tf.clip_by_global_norm([g for g, _ in grads_and_vars], 241 | self._config.train.grads_clip) 242 | vars = [v for _, v in grads_and_vars] 243 | grads_and_vars = list(zip(grads, vars)) 244 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 245 | with tf.control_dependencies(update_ops): 246 | train_op = self._optimizer.apply_gradients(grads_and_vars, global_step=global_step) 247 | 248 | return train_op, avg_loss 249 | 250 | def encoder(self, encoder_input, is_training, reuse): 251 | """Encoder.""" 252 | with tf.variable_scope(self.encoder_scope, reuse=reuse): 253 | return self.encoder_impl(encoder_input, is_training) 254 | 255 | def decoder(self, decoder_input, encoder_output, is_training, reuse): 256 | """Decoder""" 257 | with 
tf.variable_scope(self.decoder_scope, reuse=reuse): 258 | return self.decoder_impl(decoder_input, encoder_output, is_training) 259 | 260 | def decoder_with_caching(self, decoder_input, decoder_cache, encoder_output, is_training, reuse): 261 | """Incremental Decoder""" 262 | with tf.variable_scope(self.decoder_scope, reuse=reuse): 263 | return self.decoder_with_caching_impl(decoder_input, decoder_cache, encoder_output, is_training) 264 | 265 | def beam_search(self, encoder_output, use_cache, reuse): 266 | """Beam search in graph.""" 267 | beam_size, batch_size = self._config.test.beam_size, tf.shape(encoder_output)[0] 268 | inf = 1e10 269 | 270 | if beam_size == 1: 271 | return self.greedy_search(encoder_output, use_cache, reuse) 272 | 273 | def get_bias_scores(scores, bias): 274 | """ 275 | If a sequence is finished, we only allow one alive branch. This function aims to give one branch a zero score 276 | and the rest -inf score. 277 | Args: 278 | scores: A real value array with shape [batch_size * beam_size, beam_size]. 279 | bias: A bool array with shape [batch_size * beam_size]. 280 | 281 | Returns: 282 | A real value array with shape [batch_size * beam_size, beam_size]. 283 | """ 284 | bias = tf.to_float(bias) 285 | b = tf.constant([0.0] + [-inf] * (beam_size - 1)) 286 | b = tf.tile(b[None, :], multiples=[batch_size * beam_size, 1]) 287 | return scores * (1 - bias[:, None]) + b * bias[:, None] 288 | 289 | def get_bias_preds(preds, bias): 290 | """ 291 | If a sequence is finished, all of its branch should be (3). 292 | Args: 293 | preds: A int array with shape [batch_size * beam_size, beam_size]. 294 | bias: A bool array with shape [batch_size * beam_size]. 295 | 296 | Returns: 297 | A int array with shape [batch_size * beam_size]. 298 | """ 299 | bias = tf.to_int32(bias) 300 | return preds * (1 - bias[:, None]) + bias[:, None] * 3 301 | 302 | # Prepare beam search inputs. 303 | # [batch_size, 1, *, hidden_units] 304 | encoder_output = encoder_output[:, None, :, :] 305 | # [batch_size, beam_size, *, hidden_units] 306 | encoder_output = tf.tile(encoder_output, multiples=[1, beam_size, 1, 1]) 307 | encoder_output = tf.reshape(encoder_output, [batch_size * beam_size, -1, encoder_output.get_shape()[-1].value]) 308 | # [[, , ..., ]], shape: [batch_size * beam_size, 1] 309 | preds = tf.ones([batch_size * beam_size, 1], dtype=tf.int32) * 2 310 | scores = tf.constant([0.0] + [-inf] * (beam_size - 1), dtype=tf.float32) # [beam_size] 311 | scores = tf.tile(scores, multiples=[batch_size]) # [batch_size * beam_size] 312 | lengths = tf.zeros([batch_size * beam_size], dtype=tf.float32) 313 | bias = tf.zeros_like(scores, dtype=tf.bool) 314 | 315 | if use_cache: 316 | cache = tf.zeros([batch_size * beam_size, 0, self._config.num_blocks, self._config.hidden_units]) 317 | else: 318 | cache = tf.zeros([0, 0, 0, 0]) 319 | 320 | def step(i, bias, preds, scores, lengths, cache): 321 | # Where are we. 322 | i += 1 323 | 324 | # Call decoder and get predictions. 325 | if use_cache: 326 | decoder_output, cache = \ 327 | self.decoder_with_caching(preds, cache, encoder_output, is_training=False, reuse=reuse) 328 | else: 329 | decoder_output = self.decoder(preds, encoder_output, is_training=False, reuse=reuse) 330 | 331 | _, next_preds, next_scores = self.test_output(decoder_output, reuse=reuse) 332 | 333 | next_preds = get_bias_preds(next_preds, bias) 334 | next_scores = get_bias_scores(next_scores, bias) 335 | 336 | # Update scores. 
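            # The ranking a few lines below uses the GNMT length penalty:
            #     lp(Y) = ((5 + |Y|) / (5 + 1)) ** lp_alpha
            # and candidates are compared by score / lp(Y), where lp_alpha comes from
            # config.test.lp_alpha and |Y| counts tokens other than the end symbol (id 3).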
337 | scores = scores[:, None] + next_scores # [batch_size * beam_size, beam_size] 338 | scores = tf.reshape(scores, shape=[batch_size, beam_size ** 2]) # [batch_size, beam_size * beam_size] 339 | 340 | # LP scores. 341 | lengths = lengths[:, None] + tf.to_float(tf.not_equal(next_preds, 3)) # [batch_size * beam_size, beam_size] 342 | lengths = tf.reshape(lengths, shape=[batch_size, beam_size ** 2]) # [batch_size, beam_size * beam_size] 343 | lp = tf.pow((5 + lengths) / (5 + 1), self._config.test.lp_alpha) # Length penalty 344 | lp_scores = scores / lp # following GNMT 345 | 346 | # Pruning 347 | _, k_indices = tf.nn.top_k(lp_scores, k=beam_size) 348 | base_indices = tf.reshape(tf.tile(tf.range(batch_size)[:, None], multiples=[1, beam_size]), shape=[-1]) 349 | base_indices *= beam_size ** 2 350 | k_indices = base_indices + tf.reshape(k_indices, shape=[-1]) # [batch_size * beam_size] 351 | 352 | # Update lengths. 353 | lengths = tf.reshape(lengths, [-1]) 354 | lengths = tf.gather(lengths, k_indices) 355 | 356 | # Update scores. 357 | scores = tf.reshape(scores, [-1]) 358 | scores = tf.gather(scores, k_indices) 359 | 360 | # Update predictions. 361 | next_preds = tf.gather(tf.reshape(next_preds, shape=[-1]), indices=k_indices) 362 | preds = tf.gather(preds, indices=k_indices / beam_size) 363 | if use_cache: 364 | cache = tf.gather(cache, indices=k_indices / beam_size) 365 | preds = tf.concat((preds, next_preds[:, None]), axis=1) # [batch_size * beam_size, i] 366 | 367 | # Whether sequences finished. 368 | bias = tf.equal(preds[:, -1], 3) # ? 369 | 370 | return i, bias, preds, scores, lengths, cache 371 | 372 | def not_finished(i, bias, preds, scores, lengths, cache): 373 | return tf.logical_and( 374 | tf.reduce_any(tf.logical_not(bias)), 375 | tf.less_equal( 376 | i, 377 | tf.reduce_min([tf.shape(encoder_output)[1] + 50, self._config.test.max_target_length]) 378 | ) 379 | ) 380 | 381 | i, bias, preds, scores, lengths, cache = \ 382 | tf.while_loop(cond=not_finished, 383 | body=step, 384 | loop_vars=[0, bias, preds, scores, lengths, cache], 385 | shape_invariants=[ 386 | tf.TensorShape([]), 387 | tf.TensorShape([None]), 388 | tf.TensorShape([None, None]), 389 | tf.TensorShape([None]), 390 | tf.TensorShape([None]), 391 | tf.TensorShape([None, None, None, None])], 392 | back_prop=False) 393 | 394 | scores = tf.reshape(scores, shape=[batch_size, beam_size]) 395 | preds = tf.reshape(preds, shape=[batch_size, beam_size, -1]) # [batch_size, beam_size, max_length] 396 | 397 | max_indices = tf.to_int32(tf.argmax(scores, axis=-1)) # [batch_size] 398 | max_indices += tf.range(batch_size) * beam_size 399 | preds = tf.reshape(preds, shape=[batch_size * beam_size, -1]) 400 | 401 | final_preds = tf.gather(preds, indices=max_indices) 402 | final_preds = final_preds[:, 1:] # remove flag 403 | return final_preds 404 | 405 | def greedy_search(self, encoder_output, use_cache, reuse): 406 | """Greedy search in graph.""" 407 | batch_size = tf.shape(encoder_output)[0] 408 | 409 | preds = tf.ones([batch_size, 1], dtype=tf.int32) * 2 410 | scores = tf.zeros([batch_size], dtype=tf.float32) 411 | finished = tf.zeros([batch_size], dtype=tf.bool) 412 | cache = tf.zeros([batch_size, 0, self._config.num_blocks, self._config.hidden_units]) 413 | 414 | def step(i, finished, preds, scores, cache): 415 | # Where are we. 416 | i += 1 417 | 418 | # Call decoder and get predictions. 
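# decoder_with_caching only computes the representation of the newest target
# position in each block and appends it to `cache`
# ([batch_size, step, num_blocks, hidden_units]), so every greedy step attends
# from a single query position instead of re-running the whole target prefix.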
419 | decoder_output, cache = self.decoder_with_caching(preds, cache, encoder_output, is_training=False, reuse=reuse) 420 | _, next_preds, next_scores = self.test_output(decoder_output, reuse=reuse) 421 | next_preds = next_preds[:, None, 0] 422 | next_scores = next_scores[:, 0] 423 | 424 | # Update. 425 | scores = scores + next_scores 426 | preds = tf.concat([preds, next_preds], axis=1) 427 | 428 | # Whether sequences finished. 429 | has_eos = tf.equal(next_preds[:, 0], 3) 430 | finished = tf.logical_or(finished, has_eos) 431 | 432 | return i, finished, preds, scores, cache 433 | 434 | def not_finished(i, finished, preds, scores, cache): 435 | return tf.logical_and( 436 | tf.reduce_any(tf.logical_not(finished)), 437 | tf.less_equal( 438 | i, 439 | tf.reduce_min([tf.shape(encoder_output)[1] + 50, self._config.test.max_target_length]) 440 | ) 441 | ) 442 | 443 | i, finished, preds, scores, cache = \ 444 | tf.while_loop(cond=not_finished, 445 | body=step, 446 | loop_vars=[0, finished, preds, scores, cache], 447 | shape_invariants=[ 448 | tf.TensorShape([]), 449 | tf.TensorShape([None]), 450 | tf.TensorShape([None, None]), 451 | tf.TensorShape([None]), 452 | tf.TensorShape([None, None, None, None])], 453 | back_prop=False) 454 | 455 | preds = preds[:, 1:] # remove flag 456 | return preds 457 | 458 | def test_output(self, decoder_output, reuse): 459 | """During test, we only need the last prediction at each time.""" 460 | with tf.variable_scope(self.decoder_scope, reuse=reuse): 461 | last_logits = dense(decoder_output[:, -1], self._config.dst_vocab_size, use_bias=False, 462 | kernel=self._dst_softmax, name='dst_softmax', reuse=None) 463 | next_pred = tf.to_int32(tf.argmax(last_logits, axis=-1)) 464 | z = tf.nn.log_softmax(last_logits) 465 | next_scores, next_preds = tf.nn.top_k(z, k=self._config.test.beam_size, sorted=False) 466 | next_preds = tf.to_int32(next_preds) 467 | return next_pred, next_preds, next_scores 468 | 469 | def test_loss(self, decoder_output, Y, reuse): 470 | """This function help users to compute PPL during test.""" 471 | with tf.variable_scope(self.decoder_scope, reuse=reuse): 472 | logits = dense(decoder_output, self._config.dst_vocab_size, use_bias=False, 473 | kernel=self._dst_softmax, name="decoder", reuse=None) 474 | mask = tf.to_float(tf.not_equal(Y, 0)) 475 | labels = tf.one_hot(Y, depth=self._config.dst_vocab_size) 476 | loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels) 477 | loss_sum = tf.reduce_sum(loss * mask) 478 | probs = tf.nn.softmax(logits) 479 | return loss_sum, probs 480 | 481 | def train_output(self, decoder_output, Y, teacher_probs, reuse): 482 | """Calculate loss and accuracy.""" 483 | with tf.variable_scope(self.decoder_scope, reuse=reuse): 484 | logits = dense(decoder_output, self._config.dst_vocab_size, use_bias=False, 485 | kernel=self._dst_softmax, name='decoder', reuse=None) 486 | preds = tf.to_int32(tf.argmax(logits, axis=-1)) 487 | mask = tf.to_float(tf.not_equal(Y, 0)) 488 | 489 | # Token-level accuracy 490 | acc = tf.reduce_sum(tf.to_float(tf.equal(preds, Y)) * mask) / tf.reduce_sum(mask) 491 | if not tf.get_variable_scope().reuse: 492 | tf.summary.scalar('accuracy', acc) 493 | 494 | if teacher_probs is not None: 495 | # Knowledge distillation 496 | loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=teacher_probs) 497 | else: 498 | # Smoothed loss 499 | loss = common_layers.smoothing_cross_entropy(logits=logits, labels=Y, 500 | vocab_size=self._config.dst_vocab_size, 501 | confidence=1 - 
self._config.train.label_smoothing) 502 | loss = tf.reduce_sum(loss * mask) / tf.reduce_sum(mask) 503 | 504 | self.register_loss('ml_loss', loss) 505 | 506 | def encoder_impl(self, encoder_input, is_training): 507 | """ 508 | This is an interface leave to be implemented by sub classes. 509 | Args: 510 | encoder_input: A tensor with shape [batch_size, src_length] 511 | is_training: A boolean 512 | 513 | Returns: A Tensor with shape [batch_size, src_length, num_hidden] 514 | 515 | """ 516 | raise NotImplementedError() 517 | 518 | def decoder_impl(self, decoder_input, encoder_output, is_training): 519 | """ 520 | This is an interface leave to be implemented by sub classes. 521 | Args: 522 | decoder_input: A Tensor with shape [batch_size, dst_length] 523 | encoder_output: A Tensor with shape [batch_size, src_length, num_hidden] 524 | is_training: A boolean. 525 | 526 | Returns: A Tensor with shape [batch_size, dst_length, num_hidden] 527 | 528 | """ 529 | raise NotImplementedError() 530 | 531 | def decoder_with_caching_impl(self, decoder_input, decoder_cache, encoder_output, is_training): 532 | """ 533 | This is an interface leave to be implemented by sub classes. 534 | Args: 535 | decoder_input: A Tensor with shape [batch_size, dst_length] 536 | decoder_cache: A Tensor with shape [batch_size, *, *, num_hidden] 537 | encoder_output: A Tensor with shape [batch_size, src_length, num_hidden] 538 | is_training: A boolean. 539 | 540 | Returns: A Tensor with shape [batch_size, dst_length, num_hidden] 541 | 542 | """ 543 | raise NotImplementedError() 544 | -------------------------------------------------------------------------------- /models/parallel.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | from transformer import * 4 | 5 | 6 | def pad_begin(input, k): 7 | shp = tf.shape(input) 8 | return tf.concat([tf.ones([shp[0], k], dtype=tf.int32) * 2, input], 1) 9 | 10 | 11 | def decoder_self_attention_bias(length, k): 12 | max_length = 500 13 | max_length = (max_length // k) * k 14 | m = np.zeros([max_length, max_length], dtype=np.float32) 15 | for i in xrange(max_length // k): 16 | m[i * k: i * k + k, :i * k + k] = 1.0 17 | m = tf.convert_to_tensor(m) 18 | m = m[:length, :length] 19 | ret = -1e9 * (1.0 - m) 20 | return tf.reshape(ret, [1, 1, length, length]) 21 | 22 | 23 | class PTransformer(Transformer): 24 | def __init__(self, *args, **kargs): 25 | super(PTransformer, self).__init__(*args, **kargs) 26 | self._use_cache = True 27 | 28 | def decoder_impl(self, decoder_input, encoder_output, is_training): 29 | attention_dropout_rate = self._config.attention_dropout_rate if is_training else 0.0 30 | residual_dropout_rate = self._config.residual_dropout_rate if is_training else 0.0 31 | 32 | num_parallel = self._config.num_parallel 33 | padded_decoder_input = pad_begin(decoder_input, num_parallel - 1) 34 | length = tf.floor_div(tf.shape(decoder_input)[1] + self._config.num_parallel - 1, 35 | self._config.num_parallel) * self._config.num_parallel 36 | padded_decoder_input = padded_decoder_input[:, :length] 37 | 38 | encoder_padding = tf.equal(tf.reduce_sum(tf.abs(encoder_output), axis=-1), 0.0) 39 | encoder_attention_bias = common_attention.attention_bias_ignore_padding(encoder_padding) 40 | decoder_output = embedding(padded_decoder_input, 41 | vocab_size=self._config.dst_vocab_size, 42 | dense_size=self._config.hidden_units, 43 | kernel=self._dst_embedding, 44 | multiplier=self._config.hidden_units ** 0.5 if 
self._config.scale_embedding else 1.0, 45 | name="dst_embedding") 46 | # Positional Encoding 47 | decoder_output = common_attention.add_timing_signal_1d(decoder_output) 48 | 49 | # Dropout 50 | decoder_output = tf.layers.dropout(decoder_output, 51 | rate=residual_dropout_rate, 52 | training=is_training) 53 | # Bias for preventing peeping later information 54 | self_attention_bias = decoder_self_attention_bias(tf.shape(decoder_output)[1], self._config.num_parallel) 55 | # Blocks 56 | for i in range(self._config.num_blocks): 57 | with tf.variable_scope("block_{}".format(i)): 58 | # Multihead Attention (self-attention) 59 | decoder_output = residual(decoder_output, 60 | multihead_attention( 61 | query_antecedent=decoder_output, 62 | memory_antecedent=None, 63 | bias=self_attention_bias, 64 | total_key_depth=self._config.hidden_units, 65 | total_value_depth=self._config.hidden_units, 66 | num_heads=self._config.num_heads, 67 | dropout_rate=attention_dropout_rate, 68 | output_depth=self._config.hidden_units, 69 | name="decoder_self_attention", 70 | summaries=True), 71 | dropout_rate=residual_dropout_rate) 72 | 73 | # Multihead Attention (vanilla attention) 74 | decoder_output = residual(decoder_output, 75 | multihead_attention( 76 | query_antecedent=decoder_output, 77 | memory_antecedent=encoder_output, 78 | bias=encoder_attention_bias, 79 | total_key_depth=self._config.hidden_units, 80 | total_value_depth=self._config.hidden_units, 81 | output_depth=self._config.hidden_units, 82 | num_heads=self._config.num_heads, 83 | dropout_rate=attention_dropout_rate, 84 | name="decoder_vanilla_attention", 85 | summaries=True), 86 | dropout_rate=residual_dropout_rate) 87 | 88 | # Position-wise Feed Forward 89 | decoder_output = residual(decoder_output, 90 | ff_hidden( 91 | decoder_output, 92 | hidden_size=self._config.ff_hidden_units, 93 | output_size=self._config.hidden_units, 94 | activation=self._ff_activation), 95 | dropout_rate=residual_dropout_rate) 96 | 97 | decoder_output = decoder_output[:, :tf.shape(decoder_input)[1]] 98 | 99 | return decoder_output 100 | 101 | def decoder_with_caching_impl(self, decoder_input, decoder_cache, encoder_output, is_training): 102 | 103 | attention_dropout_rate = self._config.attention_dropout_rate if is_training else 0.0 104 | residual_dropout_rate = self._config.residual_dropout_rate if is_training else 0.0 105 | 106 | encoder_padding = tf.equal(tf.reduce_sum(tf.abs(encoder_output), axis=-1), 0.0) 107 | encoder_attention_bias = common_attention.attention_bias_ignore_padding(encoder_padding) 108 | 109 | padded_decoder_input = pad_begin(decoder_input, self._config.num_parallel - 1) 110 | decoder_output = embedding(padded_decoder_input, 111 | vocab_size=self._config.dst_vocab_size, 112 | dense_size=self._config.hidden_units, 113 | kernel=self._dst_embedding, 114 | multiplier=self._config.hidden_units ** 0.5 if self._config.scale_embedding else 1.0, 115 | name="dst_embedding") 116 | 117 | # Positional Encoding 118 | decoder_output = common_attention.add_timing_signal_1d(decoder_output) 119 | 120 | # Dropout 121 | decoder_output = tf.layers.dropout(decoder_output, 122 | rate=residual_dropout_rate, 123 | training=is_training) 124 | 125 | num_parallel = self._config.num_parallel 126 | new_cache = [] 127 | 128 | # Blocks 129 | for i in range(self._config.num_blocks): 130 | with tf.variable_scope("block_{}".format(i)): 131 | # Multihead Attention (self-attention) 132 | decoder_output = residual(decoder_output[:, -num_parallel:, :], 133 | multihead_attention( 134 | 
query_antecedent=decoder_output, 135 | memory_antecedent=None, 136 | bias=None, 137 | total_key_depth=self._config.hidden_units, 138 | total_value_depth=self._config.hidden_units, 139 | num_heads=self._config.num_heads, 140 | dropout_rate=attention_dropout_rate, 141 | num_queries=num_parallel, 142 | output_depth=self._config.hidden_units, 143 | name="decoder_self_attention", 144 | summaries=True), 145 | dropout_rate=residual_dropout_rate) 146 | 147 | # Multihead Attention (vanilla attention) 148 | decoder_output = residual(decoder_output, 149 | multihead_attention( 150 | query_antecedent=decoder_output, 151 | memory_antecedent=encoder_output, 152 | bias=encoder_attention_bias, 153 | total_key_depth=self._config.hidden_units, 154 | total_value_depth=self._config.hidden_units, 155 | output_depth=self._config.hidden_units, 156 | num_heads=self._config.num_heads, 157 | dropout_rate=attention_dropout_rate, 158 | num_queries=num_parallel, 159 | name="decoder_vanilla_attention", 160 | summaries=True), 161 | dropout_rate=residual_dropout_rate) 162 | 163 | # Position-wise Feed Forward 164 | decoder_output = residual(decoder_output, 165 | ff_hidden( 166 | decoder_output, 167 | hidden_size=self._config.ff_hidden_units, 168 | output_size=self._config.hidden_units, 169 | activation=self._ff_activation), 170 | dropout_rate=residual_dropout_rate) 171 | 172 | decoder_output = tf.concat([decoder_cache[:, :, i, :], decoder_output], axis=1) 173 | new_cache.append(decoder_output[:, :, None, :]) 174 | 175 | new_cache = tf.concat(new_cache, axis=2) # [batch_size, n_step, num_blocks, num_hidden] 176 | 177 | return decoder_output, new_cache 178 | 179 | def test_output_multiple(self, decoder_output, k, reuse): 180 | """Predict num_parallel tokens at once.""" 181 | 182 | num_parallel = self._config.num_parallel 183 | with tf.variable_scope(self.decoder_scope, reuse=reuse): 184 | last_logits = dense(decoder_output[:, -num_parallel:], self._config.dst_vocab_size, use_bias=False, 185 | kernel=self._dst_softmax, name='dst_softmax', reuse=None) 186 | next_pred = tf.to_int32(tf.argmax(last_logits, axis=-1)) # [B, P] 187 | z = tf.nn.log_softmax(last_logits) 188 | next_scores, next_preds = tf.nn.top_k(z, k=k, sorted=False) # [B, P, K] 189 | next_preds = tf.to_int32(next_preds) 190 | 191 | return next_pred, next_preds, next_scores 192 | 193 | def beam_search(self, encoder_output, use_cache, reuse): 194 | """Beam search in graph.""" 195 | beam_size, batch_size = self._config.test.beam_size, tf.shape(encoder_output)[0] 196 | 197 | if beam_size == 1: 198 | return self.greedy_search(encoder_output, use_cache, reuse) 199 | 200 | inf = 1e10 201 | 202 | def get_bias_scores(scores, bias): 203 | """ 204 | If a sequence is finished, we only allow one alive branch. This function aims to give one branch a zero score 205 | and the rest -inf score. 206 | Args: 207 | scores: A real value array with shape [batch_size * beam_size, beam_size]. 208 | bias: A bool array with shape [batch_size * beam_size]. 209 | 210 | Returns: 211 | A real value array with shape [batch_size * beam_size, beam_size]. 212 | """ 213 | bias = tf.to_float(bias) 214 | b = tf.constant([0.0] + [-inf] * (beam_size - 1)) 215 | b = tf.tile(b[None, :], multiples=[batch_size * beam_size, 1]) 216 | return scores * (1 - bias[:, None]) + b * bias[:, None] 217 | 218 | def get_bias_preds(preds, bias): 219 | """ 220 | If a sequence is finished, all of its branch should be (3). 221 | Args: 222 | preds: A int array with shape [batch_size * beam_size, beam_size]. 
223 | bias: A bool array with shape [batch_size * beam_size]. 224 | 225 | Returns: 226 | A int array with shape [batch_size * beam_size]. 227 | """ 228 | bias = tf.to_int32(bias) 229 | return preds * (1 - bias[:, None]) + bias[:, None] * 3 230 | 231 | # Prepare beam search inputs. 232 | # [batch_size, 1, *, hidden_units] 233 | encoder_output = encoder_output[:, None, :, :] 234 | # [batch_size, beam_size, *, hidden_units] 235 | encoder_output = tf.tile(encoder_output, multiples=[1, beam_size, 1, 1]) 236 | encoder_output = tf.reshape(encoder_output, [batch_size * beam_size, -1, encoder_output.get_shape()[-1].value]) 237 | # [[, , ..., ]], shape: [batch_size * beam_size, 1] 238 | preds = tf.ones([batch_size * beam_size, 1], dtype=tf.int32) * 2 239 | scores = tf.constant([0.0] + [-inf] * (beam_size - 1), dtype=tf.float32) # [beam_size] 240 | scores = tf.tile(scores, multiples=[batch_size]) # [batch_size * beam_size] 241 | lengths = tf.zeros([batch_size * beam_size], dtype=tf.float32) 242 | bias = tf.zeros_like(scores, dtype=tf.bool) 243 | Cache = namedtuple('Cache', ['decoder_cache', 'next_preds', 'next_scores']) 244 | caches = Cache( 245 | decoder_cache=tf.zeros([batch_size * beam_size, 0, self._config.num_blocks, self._config.hidden_units]), 246 | next_preds=tf.zeros([batch_size * beam_size, 0, self._config.test.beam_size], dtype=tf.int32), 247 | next_scores=tf.zeros([batch_size * beam_size, 0, self._config.test.beam_size])) 248 | 249 | def step(i, bias, preds, scores, lengths, caches): 250 | # Where are we. 251 | i += 1 252 | 253 | # Call decoder and get predictions. 254 | if use_cache: 255 | 256 | def compute(): 257 | decoder_output, decoder_cache = \ 258 | self.decoder_with_caching(preds, caches.decoder_cache, encoder_output, 259 | is_training=False, reuse=reuse) 260 | _, next_preds, next_scores = self.test_output_multiple(decoder_output, 261 | k=self._config.test.beam_size, 262 | reuse=reuse) 263 | new_caches = Cache(decoder_cache=decoder_cache, next_preds=next_preds, next_scores=next_scores) 264 | return new_caches 265 | 266 | def hit(): 267 | return caches 268 | 269 | cond = tf.equal(tf.shape(caches.next_preds)[1], 0) 270 | caches = tf.cond(cond, compute, hit) 271 | next_preds = caches.next_preds[:, 0] 272 | next_scores = caches.next_scores[:, 0] 273 | caches = Cache(decoder_cache=caches.decoder_cache, 274 | next_preds=caches.next_preds[:, 1:], 275 | next_scores=caches.next_scores[:, 1:]) 276 | else: 277 | decoder_output = self.decoder(preds, encoder_output, is_training=False, reuse=reuse) 278 | _, next_preds, next_scores = self.test_output(decoder_output, reuse=reuse) 279 | 280 | next_preds = get_bias_preds(next_preds, bias) 281 | next_scores = get_bias_scores(next_scores, bias) 282 | 283 | # Update scores. 284 | scores = scores[:, None] + next_scores # [batch_size * beam_size, beam_size] 285 | scores = tf.reshape(scores, shape=[batch_size, beam_size ** 2]) # [batch_size, beam_size * beam_size] 286 | 287 | # LP scores. 
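# The length penalty below follows GNMT: lp = ((5 + |Y|) / (5 + 1)) ** lp_alpha,
# and candidates are ranked by score / lp so long hypotheses are not unduly
# punished for accumulating more negative log-probabilities. For instance, with
# lp_alpha = 1.0 a 10-token candidate is divided by (15 / 6) = 2.5, while
# lp_alpha = 0.0 disables the penalty altogether.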
288 | lengths = lengths[:, None] + tf.to_float(tf.not_equal(next_preds, 3)) # [batch_size * beam_size, beam_size] 289 | lengths = tf.reshape(lengths, shape=[batch_size, beam_size ** 2]) # [batch_size, beam_size * beam_size] 290 | lp = tf.pow((5 + lengths) / (5 + 1), self._config.test.lp_alpha) # Length penalty 291 | lp_scores = scores / lp # following GNMT 292 | 293 | # Pruning 294 | _, k_indices = tf.nn.top_k(lp_scores, k=beam_size) 295 | base_indices = tf.reshape(tf.tile(tf.range(batch_size)[:, None], multiples=[1, beam_size]), shape=[-1]) 296 | base_indices *= beam_size ** 2 297 | k_indices = base_indices + tf.reshape(k_indices, shape=[-1]) # [batch_size * beam_size] 298 | 299 | # Update lengths. 300 | lengths = tf.reshape(lengths, [-1]) 301 | lengths = tf.gather(lengths, k_indices) 302 | 303 | # Update scores. 304 | scores = tf.reshape(scores, [-1]) 305 | scores = tf.gather(scores, k_indices) 306 | 307 | # Update predictions. 308 | next_preds = tf.gather(tf.reshape(next_preds, shape=[-1]), indices=k_indices) 309 | preds = tf.gather(preds, indices=k_indices / beam_size) 310 | if use_cache: 311 | caches = Cache(decoder_cache=tf.gather(caches.decoder_cache, indices=k_indices / beam_size), 312 | next_preds=tf.gather(caches.next_preds, indices=k_indices / beam_size), 313 | next_scores=tf.gather(caches.next_scores, indices=k_indices / beam_size)) 314 | preds = tf.concat((preds, next_preds[:, None]), axis=1) # [batch_size * beam_size, i] 315 | 316 | # Whether sequences finished. 317 | bias = tf.equal(preds[:, -1], 3) # ? 318 | 319 | return i, bias, preds, scores, lengths, caches 320 | 321 | def not_finished(i, bias, preds, scores, lengths, caches): 322 | return tf.logical_and( 323 | tf.reduce_any(tf.logical_not(bias)), 324 | tf.less_equal( 325 | i, 326 | tf.reduce_min([tf.shape(encoder_output)[1] + 50, self._config.test.max_target_length]) 327 | ) 328 | ) 329 | 330 | i, bias, preds, scores, lengths, caches = \ 331 | tf.while_loop(cond=not_finished, 332 | body=step, 333 | loop_vars=[0, bias, preds, scores, lengths, caches], 334 | shape_invariants=[ 335 | tf.TensorShape([]), 336 | tf.TensorShape([None]), 337 | tf.TensorShape([None, None]), 338 | tf.TensorShape([None]), 339 | tf.TensorShape([None]), 340 | Cache(decoder_cache=tf.TensorShape([None, None, None, None]), 341 | next_preds=tf.TensorShape([None, None, None]), 342 | next_scores=tf.TensorShape([None, None, None])) 343 | ], 344 | back_prop=False) 345 | 346 | scores = tf.reshape(scores, shape=[batch_size, beam_size]) 347 | preds = tf.reshape(preds, shape=[batch_size, beam_size, -1]) # [batch_size, beam_size, max_length] 348 | 349 | max_indices = tf.to_int32(tf.argmax(scores, axis=-1)) # [batch_size] 350 | max_indices += tf.range(batch_size) * beam_size 351 | preds = tf.reshape(preds, shape=[batch_size * beam_size, -1]) 352 | 353 | final_preds = tf.gather(preds, indices=max_indices) 354 | final_preds = final_preds[:, 1:] # remove flag 355 | return final_preds 356 | 357 | def greedy_search(self, encoder_output, use_cache, reuse): 358 | """Beam search in graph.""" 359 | batch_size = tf.shape(encoder_output)[0] 360 | num_parallel = self._config.num_parallel 361 | 362 | preds = tf.ones([batch_size, 1], dtype=tf.int32) * 2 363 | scores = tf.zeros([batch_size], dtype=tf.float32) 364 | finished = tf.zeros([batch_size], dtype=tf.bool) 365 | cache = tf.zeros([batch_size, 0, self._config.num_blocks, self._config.hidden_units]) 366 | 367 | def step(i, finished, preds, scores, cache): 368 | # Where are we. 
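# Unlike Transformer.greedy_search, this loop emits num_parallel tokens per
# iteration: test_output_multiple returns a prediction for each of the last
# num_parallel positions, so the step counter advances by num_parallel and the
# target is produced in roughly 1/num_parallel as many decoder calls.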
369 | i += num_parallel 370 | 371 | # Call decoder and get predictions. 372 | decoder_output, cache = self.decoder_with_caching(preds, cache, encoder_output, is_training=False, reuse=reuse) 373 | _, next_preds, next_scores = self.test_output_multiple(decoder_output, k=1, reuse=reuse) 374 | next_preds = next_preds[:, :, 0] 375 | next_scores = tf.reduce_sum(next_scores[:, :, 0], axis=1) 376 | 377 | # Update. 378 | scores = scores + next_scores 379 | preds = tf.concat([preds, next_preds], axis=1) 380 | 381 | # Whether sequences finished. 382 | has_eos = tf.reduce_any(tf.equal(next_preds, 3), axis=1) 383 | finished = tf.logical_or(finished, has_eos) 384 | 385 | return i, finished, preds, scores, cache 386 | 387 | def not_finished(i, finished, preds, scores, cache): 388 | return tf.logical_and( 389 | tf.reduce_any(tf.logical_not(finished)), 390 | tf.less_equal( 391 | i, 392 | tf.reduce_min([tf.shape(encoder_output)[1] + 50, self._config.test.max_target_length]) 393 | ) 394 | ) 395 | 396 | i, finished, preds, scores, cache = \ 397 | tf.while_loop(cond=not_finished, 398 | body=step, 399 | loop_vars=[0, finished, preds, scores, cache], 400 | shape_invariants=[ 401 | tf.TensorShape([]), 402 | tf.TensorShape([None]), 403 | tf.TensorShape([None, None]), 404 | tf.TensorShape([None]), 405 | tf.TensorShape([None, None, None, None])], 406 | back_prop=False) 407 | 408 | preds = preds[:, 1:] # remove flag 409 | return preds 410 | 411 | # def test_output_beam(self, decoder_output, reuse): 412 | # beam_size = self._config.test.beam_size 413 | # num_parallel = self._config.num_parallel 414 | # eos = 3 415 | # k = int(np.ceil(np.power(beam_size, 1.0 / num_parallel))) 416 | # _, next_preds, next_scores = self.test_output_multiple(decoder_output, k, reuse) # [batch_size, p, k] 417 | # 418 | # batch_size = tf.shape(decoder_output)[0] 419 | # scores = next_scores[:, 0, :] # [batch_size, k**1] 420 | # preds = next_preds[:, 0, :, None] # [batch_size, k**1, 1] 421 | # finished = np.zeros_like(scores) # [batch_size, k**1] 422 | # 423 | # def get_biased_scores(scores, finished): 424 | # pass 425 | # 426 | # def get_biased_preds(preds, finished): 427 | # pass 428 | # 429 | # for i in range(1, self._config.num_parallel): 430 | # cur_preds = next_preds[:, i, :] # [batch_size, k] 431 | # cur_scores = next_scores[:, i, :] # [batch_size, k] 432 | # 433 | # finished = finished[:, :, None] 434 | # finished = tf.tile(finished, [1, 1, k]) # [batch_size, k**i, k] 435 | # finished = tf.mul(finished, tf.to_float(tf.equal(cur_preds, eos)[:, None, :])) # [batch_size, k**i, k] 436 | # # finished = tf.reshape(finished, [batch_size, -1]) # [batch_size, k**(i+1)] 437 | # 438 | # scores = scores[:, :, None] 439 | # scores = tf.tile(scores, [1, 1, k]) 440 | # scores += finished * cur_scores[:, None, :] 441 | # scores = tf.reshape(scores, [batch_size, -1]) 442 | # 443 | # preds = preds[:, :, None, :] 444 | # preds = tf.tile(preds, [1, 1, k, 1]) # [batch_size, k**i, k, i] 445 | # preds = tf.reshape(preds, [batch_size, k ** (i+1), -1]) # [batch_size, k**(i+1), i] 446 | # cur_preds = tf.tile(cur_preds[:, :, None], [1, k ** i, 1]) # [batch_size, k**(i+1), 1] 447 | # preds = tf.concat([preds, cur_preds], axis=2) # [batch_size, k**(i+1), i+1] 448 | # 449 | # # Select the top beam_size predictions. 
450 | # top_scores, top_indices = tf.nn.top_k(scores, beam_size) 451 | # flatten_top_indices = tf.reshape(top_indices, [-1]) 452 | # base_indices = tf.reshape(tf.tile(tf.range(batch_size)[:, None], multiples=[1, beam_size]), shape=[-1]) 453 | # base_indices *= k**num_parallel 454 | # flatten_top_indices += base_indices 455 | # flatten_preds = tf.reshape(preds, [-1, num_parallel]) 456 | # top_preds = tf.gather(flatten_preds, flatten_top_indices) 457 | # top_preds = tf.reshape(top_preds, [batch_size, beam_size, num_parallel]) 458 | # 459 | # return top_preds, top_scores # [batch_size, beam_size, num_parallel], [batch_size, beam_size] 460 | 461 | # def faster_beam_search(self, encoder_output, use_cache, reuse): 462 | # """Beam search in graph.""" 463 | # beam_size, batch_size = self._config.test.beam_size, tf.shape(encoder_output)[0] 464 | # inf = 1e10 465 | # 466 | # def get_bias_scores(scores, bias): 467 | # """ 468 | # If a sequence is finished, we only allow one alive branch. 469 | # This function aims to give one branch a zero score and the rest -inf score. 470 | # Args: 471 | # scores: A real value array with shape [batch_size * beam_size, beam_size]. 472 | # bias: A bool array with shape [batch_size * beam_size]. 473 | # 474 | # Returns: 475 | # A real value array with shape [batch_size * beam_size, beam_size]. 476 | # """ 477 | # bias = tf.to_float(bias) 478 | # b = tf.constant([0.0] + [-inf] * (beam_size - 1)) 479 | # b = tf.tile(b[None, :], multiples=[batch_size * beam_size, 1]) 480 | # return scores * (1 - bias[:, None]) + b * bias[:, None] 481 | # 482 | # def get_bias_preds(preds, bias): 483 | # """ 484 | # If a sequence is finished, all of its branch should be (3). 485 | # Args: 486 | # preds: A int array with shape [batch_size * beam_size, beam_size]. 487 | # bias: A bool array with shape [batch_size * beam_size]. 488 | # 489 | # Returns: 490 | # A int array with shape [batch_size * beam_size]. 491 | # """ 492 | # bias = tf.to_int32(bias) 493 | # return preds * (1 - bias[:, None]) + bias[:, None] * 3 494 | # 495 | # # Prepare beam search inputs. 496 | # # [batch_size, 1, *, hidden_units] 497 | # encoder_output = encoder_output[:, None, :, :] 498 | # # [batch_size, beam_size, *, hidden_units] 499 | # encoder_output = tf.tile(encoder_output, multiples=[1, beam_size, 1, 1]) 500 | # encoder_output = tf.reshape(encoder_output, [batch_size * beam_size, -1, encoder_output.get_shape()[-1].value]) 501 | # # [[, , ..., ]], shape: [batch_size * beam_size, 1] 502 | # preds = tf.ones([batch_size * beam_size, 1], dtype=tf.int32) * 2 503 | # scores = tf.constant([0.0] + [-inf] * (beam_size - 1), dtype=tf.float32) # [beam_size] 504 | # scores = tf.tile(scores, multiples=[batch_size]) # [batch_size * beam_size] 505 | # lengths = tf.zeros([batch_size * beam_size], dtype=tf.float32) 506 | # bias = tf.zeros_like(scores, dtype=tf.bool) 507 | # 508 | # if use_cache: 509 | # cache = tf.zeros([batch_size * beam_size, 0, self._config.num_blocks, self._config.hidden_units]) 510 | # else: 511 | # cache = tf.zeros([0, 0, 0, 0]) 512 | # 513 | # def step(i, bias, preds, scores, lengths, cache): 514 | # # Where are we. 515 | # i += 1 516 | # 517 | # # Call decoder and get predictions. 
518 | # if use_cache: 519 | # decoder_output, cache = \ 520 | # self.decoder_with_caching(preds, cache, encoder_output, is_training=False, reuse=reuse) 521 | # else: 522 | # decoder_output = self.decoder(preds, encoder_output, is_training=False, reuse=reuse) 523 | # 524 | # # next_preds: [batch_size, beam_size, num_parallel] 525 | # # next_scores: [batch_size, beam_size] 526 | # _, next_k_preds, next_k_scores = self.test_output(decoder_output, reuse=reuse) 527 | # 528 | # next_preds = get_bias_preds(next_preds, bias) 529 | # next_scores = get_bias_scores(next_scores, bias) 530 | # 531 | # # Update scores. 532 | # scores = scores[:, None] + next_scores # [batch_size * beam_size, beam_size] 533 | # scores = tf.reshape(scores, shape=[batch_size, beam_size ** 2]) # [batch_size, beam_size * beam_size] 534 | # 535 | # # LP scores. 536 | # lengths = lengths[:, None] + tf.to_float(tf.not_equal(next_preds, 3)) # [batch_size * beam_size, beam_size] 537 | # lengths = tf.reshape(lengths, shape=[batch_size, beam_size ** 2]) # [batch_size, beam_size * beam_size] 538 | # lp = tf.pow((5 + lengths) / (5 + 1), self._config.test.lp_alpha) # Length penalty 539 | # lp_scores = scores / lp # following GNMT 540 | # 541 | # # Pruning 542 | # _, k_indices = tf.nn.top_k(lp_scores, k=beam_size) 543 | # base_indices = tf.reshape(tf.tile(tf.range(batch_size)[:, None], multiples=[1, beam_size]), shape=[-1]) 544 | # base_indices *= beam_size ** 2 545 | # k_indices = base_indices + tf.reshape(k_indices, shape=[-1]) # [batch_size * beam_size] 546 | # 547 | # # Update lengths. 548 | # lengths = tf.reshape(lengths, [-1]) 549 | # lengths = tf.gather(lengths, k_indices) 550 | # 551 | # # Update scores. 552 | # scores = tf.reshape(scores, [-1]) 553 | # scores = tf.gather(scores, k_indices) 554 | # 555 | # # Update predictions. 556 | # next_preds = tf.gather(tf.reshape(next_preds, shape=[-1]), indices=k_indices) 557 | # preds = tf.gather(preds, indices=k_indices / beam_size) 558 | # if use_cache: 559 | # cache = tf.gather(cache, indices=k_indices / beam_size) 560 | # preds = tf.concat((preds, next_preds[:, None]), axis=1) # [batch_size * beam_size, i] 561 | # 562 | # # Whether sequences finished. 563 | # bias = tf.equal(preds[:, -1], 3) # ? 
564 | # 565 | # return i, bias, preds, scores, lengths, cache 566 | # 567 | # def not_finished(i, bias, preds, scores, lengths, cache): 568 | # return tf.logical_and( 569 | # tf.reduce_any(tf.logical_not(bias)), 570 | # tf.less_equal( 571 | # i, 572 | # tf.reduce_min([tf.shape(encoder_output)[1] + 50, self._config.test.max_target_length]) 573 | # ) 574 | # ) 575 | # 576 | # i, bias, preds, scores, lengths, cache = \ 577 | # tf.while_loop(cond=not_finished, 578 | # body=step, 579 | # loop_vars=[0, bias, preds, scores, lengths, cache], 580 | # shape_invariants=[ 581 | # tf.TensorShape([]), 582 | # tf.TensorShape([None]), 583 | # tf.TensorShape([None, None]), 584 | # tf.TensorShape([None]), 585 | # tf.TensorShape([None]), 586 | # tf.TensorShape([None, None, None, None])], 587 | # back_prop=False) 588 | # 589 | # scores = tf.reshape(scores, shape=[batch_size, beam_size]) 590 | # preds = tf.reshape(preds, shape=[batch_size, beam_size, -1]) # [batch_size, beam_size, max_length] 591 | # 592 | # max_indices = tf.to_int32(tf.argmax(scores, axis=-1)) # [batch_size] 593 | # max_indices += tf.range(batch_size) * beam_size 594 | # preds = tf.reshape(preds, shape=[batch_size * beam_size, -1]) 595 | # 596 | # final_preds = tf.gather(preds, indices=max_indices) 597 | # final_preds = final_preds[:, 1:] # remove flag 598 | # return final_preds 599 | 600 | # def test_loss(self, decoder_output, Y, reuse): 601 | # """This function help users to compute PPL during test.""" 602 | # with tf.variable_scope(self.decoder_scope, reuse=reuse): 603 | # logits = dense(decoder_output, self._config.dst_vocab_size, use_bias=False, 604 | # kernel=self._dst_softmax, name="decoder", reuse=None) 605 | # mask = tf.to_float(tf.not_equal(Y, 0)) 606 | # labels = tf.one_hot(Y, depth=self._config.dst_vocab_size) 607 | # loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels) 608 | # loss_sum = tf.reduce_sum(loss * mask) 609 | # # Position-wise PPL 610 | # lengths = tf.reduce_sum(mask, axis=1, keep_dims=True) 611 | # lengths_mask = tf.to_float(tf.greater(lengths, 12)) 612 | # loss_sum = tf.Print(loss_sum, 613 | # [tf.reduce_sum(loss * lengths_mask, axis=0)[:12] / 614 | # (tf.reduce_sum(lengths_mask) + 1e-6)], 615 | # summarize=15) 616 | # probs = tf.nn.softmax(logits) 617 | # return loss_sum, probs 618 | 619 | def train_output(self, decoder_output, Y, teacher_probs, reuse): 620 | """Calculate loss and accuracy.""" 621 | with tf.variable_scope(self.decoder_scope, reuse=reuse): 622 | logits = dense(decoder_output, self._config.dst_vocab_size, use_bias=False, 623 | kernel=self._dst_softmax, name='decoder', reuse=None) 624 | preds = tf.to_int32(tf.argmax(logits, axis=-1)) 625 | mask = tf.to_float(tf.not_equal(Y, 0)) 626 | 627 | # Token-level accuracy 628 | acc = tf.reduce_sum(tf.to_float(tf.equal(preds, Y)) * mask) / tf.reduce_sum(mask) 629 | if not tf.get_variable_scope().reuse: 630 | tf.summary.scalar('accuracy', acc) 631 | 632 | if teacher_probs is not None: 633 | # Knowledge distillation 634 | loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=teacher_probs) 635 | else: 636 | # Smoothed loss 637 | loss = common_layers.smoothing_cross_entropy(logits=logits, labels=Y, 638 | vocab_size=self._config.dst_vocab_size, 639 | confidence=1 - self._config.train.label_smoothing) 640 | loss = tf.reduce_sum(loss * mask) / tf.reduce_sum(mask) 641 | 642 | self.register_loss('ml_loss', loss) 643 | -------------------------------------------------------------------------------- /models/rnnsearch.py: 
-------------------------------------------------------------------------------- 1 | from tensorflow.python.ops.rnn_cell import GRUCell 2 | from model import Model 3 | from utils import * 4 | 5 | 6 | class RNNSearch(Model): 7 | def __init__(self, *args, **kargs): 8 | super(RNNSearch, self).__init__(*args, **kargs) 9 | self._use_daisy_chain_getter = False 10 | 11 | def encoder_impl(self, encoder_input, is_training): 12 | dropout_rate = self._config.dropout_rate if is_training else 0.0 13 | 14 | # Mask 15 | encoder_mask = tf.to_int32(tf.not_equal(encoder_input, 0)) 16 | sequence_lengths = tf.reduce_sum(encoder_mask, axis=1) 17 | 18 | # Embedding 19 | encoder_output = embedding(encoder_input, 20 | vocab_size=self._config.src_vocab_size, 21 | dense_size=self._config.hidden_units, 22 | kernel=self._src_embedding, 23 | multiplier=self._config.hidden_units ** 0.5 if self._config.scale_embedding else 1.0, 24 | name="src_embedding") 25 | 26 | # Dropout 27 | encoder_output = tf.layers.dropout(encoder_output, rate=dropout_rate, training=is_training) 28 | 29 | cell_fw = GRUCell(num_units=self._config.hidden_units, name='fw_cell') 30 | cell_bw = GRUCell(num_units=self._config.hidden_units, name='bw_cell') 31 | 32 | # RNN 33 | encoder_outputs, _ = tf.nn.bidirectional_dynamic_rnn( 34 | cell_fw=cell_fw, cell_bw=cell_bw, 35 | inputs=encoder_output, 36 | sequence_length=sequence_lengths, 37 | dtype=tf.float32 38 | ) 39 | 40 | encoder_output = tf.concat(encoder_outputs, axis=2) 41 | 42 | # Dropout 43 | encoder_output = tf.layers.dropout(encoder_output, rate=dropout_rate, training=is_training) 44 | 45 | # Mask 46 | encoder_output *= tf.expand_dims(tf.to_float(encoder_mask), axis=-1) 47 | 48 | return encoder_output 49 | 50 | def decoder_impl(self, decoder_input, encoder_output, is_training): 51 | dropout_rate = self._config.dropout_rate if is_training else 0.0 52 | 53 | attention_bias = tf.equal(tf.reduce_sum(tf.abs(encoder_output), axis=-1, keepdims=True), 0.0) 54 | attention_bias = tf.to_float(attention_bias) * (- 1e9) 55 | 56 | decoder_output = embedding(decoder_input, 57 | vocab_size=self._config.dst_vocab_size, 58 | dense_size=self._config.hidden_units, 59 | kernel=self._dst_embedding, 60 | multiplier=self._config.hidden_units ** 0.5 if self._config.scale_embedding else 1.0, 61 | name="dst_embedding") 62 | decoder_output = tf.layers.dropout(decoder_output, rate=dropout_rate, training=is_training) 63 | cell = AttentionGRUCell(num_units=self._config.hidden_units, 64 | attention_memories=encoder_output, 65 | attention_bias=attention_bias, 66 | reuse=tf.AUTO_REUSE, 67 | name='attention_cell') 68 | decoder_output, _ = tf.nn.dynamic_rnn(cell=cell, inputs=decoder_output, dtype=tf.float32) 69 | decoder_output = tf.layers.dropout(decoder_output, rate=dropout_rate, training=is_training) 70 | 71 | return decoder_output 72 | 73 | def decoder_with_caching_impl(self, decoder_input, decoder_cache, encoder_output, is_training): 74 | dropout_rate = self._config.dropout_rate if is_training else 0.0 75 | decoder_input = decoder_input[:, -1] 76 | attention_bias = tf.equal(tf.reduce_sum(tf.abs(encoder_output), axis=-1, keepdims=True), 0.0) 77 | attention_bias = tf.to_float(attention_bias) * (- 1e9) 78 | decoder_output = embedding(decoder_input, 79 | vocab_size=self._config.dst_vocab_size, 80 | dense_size=self._config.hidden_units, 81 | kernel=self._dst_embedding, 82 | multiplier=self._config.hidden_units ** 0.5 if self._config.scale_embedding else 1.0, 83 | name="dst_embedding") 84 | cell = 
AttentionGRUCell(num_units=self._config.hidden_units, 85 | attention_memories=encoder_output, 86 | attention_bias=attention_bias, 87 | reuse=tf.AUTO_REUSE, 88 | name='attention_cell') 89 | decoder_cache = tf.cond(tf.equal(tf.shape(decoder_cache)[1], 0), 90 | lambda: tf.zeros([tf.shape(decoder_input)[0], 1, 1, self._config.hidden_units]), 91 | lambda: decoder_cache) 92 | with tf.variable_scope('rnn'): 93 | decoder_output, _ = cell(decoder_output, decoder_cache[:, -1, -1, :]) 94 | decoder_output = tf.layers.dropout(decoder_output, rate=dropout_rate, training=is_training) 95 | return decoder_output[:, None, :], decoder_output[:, None, None, :] 96 | -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | from tensorflow.python.ops.rnn_cell import GRUCell 2 | 3 | from model import Model 4 | from utils import * 5 | 6 | 7 | class Transformer(Model): 8 | def __init__(self, *args, **kargs): 9 | super(Transformer, self).__init__(*args, **kargs) 10 | activations = {"relu": tf.nn.relu, 11 | "sigmoid": tf.sigmoid, 12 | "tanh": tf.tanh, 13 | "swish": lambda x: x * tf.sigmoid(x), 14 | "glu": lambda x, y: x * tf.sigmoid(y)} 15 | self._ff_activation = activations[self._config.ff_activation or 'relu'] 16 | 17 | def encoder_impl(self, encoder_input, is_training): 18 | 19 | attention_dropout_rate = self._config.attention_dropout_rate if is_training else 0.0 20 | residual_dropout_rate = self._config.residual_dropout_rate if is_training else 0.0 21 | 22 | # Mask 23 | encoder_padding = tf.equal(encoder_input, 0) 24 | encoder_attention_bias = common_attention.attention_bias_ignore_padding(encoder_padding) 25 | # encoder_attention_bias = tf.tile(encoder_attention_bias, 26 | # [1, self._config.num_heads, tf.shape(encoder_attention_bias)[-1], 1]) 27 | 28 | # Embedding 29 | encoder_output = embedding(encoder_input, 30 | vocab_size=self._config.src_vocab_size, 31 | dense_size=self._config.hidden_units, 32 | kernel=self._src_embedding, 33 | multiplier=self._config.hidden_units ** 0.5 if self._config.scale_embedding else 1.0, 34 | name="src_embedding") 35 | # Add positional signal 36 | encoder_output = common_attention.add_timing_signal_1d(encoder_output) 37 | # Dropout 38 | encoder_output = tf.layers.dropout(encoder_output, 39 | rate=residual_dropout_rate, 40 | training=is_training) 41 | 42 | # Blocks 43 | for i in range(self._config.num_blocks): 44 | with tf.variable_scope("block_{}".format(i)): 45 | # Multihead Attention 46 | encoder_output = residual(encoder_output, 47 | multihead_attention( 48 | query_antecedent=encoder_output, 49 | memory_antecedent=None, 50 | bias=encoder_attention_bias, 51 | total_key_depth=self._config.hidden_units, 52 | total_value_depth=self._config.hidden_units, 53 | output_depth=self._config.hidden_units, 54 | num_heads=self._config.num_heads, 55 | dropout_rate=attention_dropout_rate, 56 | name='encoder_self_attention', 57 | summaries=True), 58 | dropout_rate=residual_dropout_rate) 59 | 60 | # Feed Forward 61 | encoder_output = residual(encoder_output, 62 | ff_hidden( 63 | inputs=encoder_output, 64 | hidden_size=self._config.ff_hidden_units, 65 | output_size=self._config.hidden_units, 66 | activation=self._ff_activation), 67 | dropout_rate=residual_dropout_rate) 68 | # Mask padding part to zeros. 
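# Zeroing the padded positions lets downstream code recover the source padding
# mask from the encoder output alone: decoder_impl below rebuilds it via
# tf.reduce_sum(tf.abs(encoder_output), axis=-1) == 0, matching the convention of
# embedding_to_padding in third_party/tensor2tensor/common_attention.py.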
69 | encoder_output *= tf.expand_dims(1.0 - tf.to_float(encoder_padding), axis=-1) 70 | return encoder_output 71 | 72 | def decoder_impl(self, decoder_input, encoder_output, is_training): 73 | 74 | attention_dropout_rate = self._config.attention_dropout_rate if is_training else 0.0 75 | residual_dropout_rate = self._config.residual_dropout_rate if is_training else 0.0 76 | 77 | encoder_padding = tf.equal(tf.reduce_sum(tf.abs(encoder_output), axis=-1), 0.0) 78 | encoder_attention_bias = common_attention.attention_bias_ignore_padding(encoder_padding) 79 | # encoder_attention_bias = tf.tile(encoder_attention_bias, 80 | # [1, self._config.num_heads, tf.shape(encoder_attention_bias)[-1], 1]) 81 | 82 | decoder_output = embedding(decoder_input, 83 | vocab_size=self._config.dst_vocab_size, 84 | dense_size=self._config.hidden_units, 85 | kernel=self._dst_embedding, 86 | multiplier=self._config.hidden_units ** 0.5 if self._config.scale_embedding else 1.0, 87 | name="dst_embedding") 88 | # Positional Encoding 89 | decoder_output = common_attention.add_timing_signal_1d(decoder_output) 90 | # Dropout 91 | decoder_output = tf.layers.dropout(decoder_output, 92 | rate=residual_dropout_rate, 93 | training=is_training) 94 | # Bias for preventing peeping later information 95 | self_attention_bias = common_attention.attention_bias_lower_triangle(tf.shape(decoder_input)[1]) 96 | 97 | # Blocks 98 | for i in range(self._config.num_blocks): 99 | with tf.variable_scope("block_{}".format(i)): 100 | # Multihead Attention (self-attention) 101 | decoder_output = residual(decoder_output, 102 | multihead_attention( 103 | query_antecedent=decoder_output, 104 | memory_antecedent=None, 105 | bias=self_attention_bias, 106 | total_key_depth=self._config.hidden_units, 107 | total_value_depth=self._config.hidden_units, 108 | num_heads=self._config.num_heads, 109 | dropout_rate=attention_dropout_rate, 110 | output_depth=self._config.hidden_units, 111 | name="decoder_self_attention", 112 | summaries=True), 113 | dropout_rate=residual_dropout_rate) 114 | 115 | # Multihead Attention (vanilla attention) 116 | decoder_output = residual(decoder_output, 117 | multihead_attention( 118 | query_antecedent=decoder_output, 119 | memory_antecedent=encoder_output, 120 | bias=encoder_attention_bias, 121 | total_key_depth=self._config.hidden_units, 122 | total_value_depth=self._config.hidden_units, 123 | output_depth=self._config.hidden_units, 124 | num_heads=self._config.num_heads, 125 | dropout_rate=attention_dropout_rate, 126 | name="decoder_vanilla_attention", 127 | summaries=True), 128 | dropout_rate=residual_dropout_rate) 129 | 130 | # Feed Forward 131 | decoder_output = residual(decoder_output, 132 | ff_hidden( 133 | decoder_output, 134 | hidden_size=self._config.ff_hidden_units, 135 | output_size=self._config.hidden_units, 136 | activation=self._ff_activation), 137 | dropout_rate=residual_dropout_rate) 138 | return decoder_output 139 | 140 | def decoder_with_caching_impl(self, decoder_input, decoder_cache, encoder_output, is_training): 141 | 142 | attention_dropout_rate = self._config.attention_dropout_rate if is_training else 0.0 143 | residual_dropout_rate = self._config.residual_dropout_rate if is_training else 0.0 144 | 145 | encoder_padding = tf.equal(tf.reduce_sum(tf.abs(encoder_output), axis=-1), 0.0) 146 | encoder_attention_bias = common_attention.attention_bias_ignore_padding(encoder_padding) 147 | # encoder_attention_bias = tf.tile(encoder_attention_bias, 148 | # [1, self._config.num_heads, 1, 1]) 149 | 150 | decoder_output = 
embedding(decoder_input, 151 | vocab_size=self._config.dst_vocab_size, 152 | dense_size=self._config.hidden_units, 153 | kernel=self._dst_embedding, 154 | multiplier=self._config.hidden_units ** 0.5 if self._config.scale_embedding else 1.0, 155 | name="dst_embedding") 156 | # Positional Encoding 157 | decoder_output = common_attention.add_timing_signal_1d(decoder_output) 158 | # Dropout 159 | decoder_output = tf.layers.dropout(decoder_output, 160 | rate=residual_dropout_rate, 161 | training=is_training) 162 | 163 | new_cache = [] 164 | 165 | # Blocks 166 | for i in range(self._config.num_blocks): 167 | with tf.variable_scope("block_{}".format(i)): 168 | # Multihead Attention (self-attention) 169 | decoder_output = residual(decoder_output[:, -1:, :], 170 | multihead_attention( 171 | query_antecedent=decoder_output, 172 | memory_antecedent=None, 173 | bias=None, 174 | total_key_depth=self._config.hidden_units, 175 | total_value_depth=self._config.hidden_units, 176 | num_heads=self._config.num_heads, 177 | dropout_rate=attention_dropout_rate, 178 | num_queries=1, 179 | output_depth=self._config.hidden_units, 180 | name="decoder_self_attention", 181 | summaries=True), 182 | dropout_rate=residual_dropout_rate) 183 | 184 | # Multihead Attention (vanilla attention) 185 | decoder_output = residual(decoder_output, 186 | multihead_attention( 187 | query_antecedent=decoder_output, 188 | memory_antecedent=encoder_output, 189 | bias=encoder_attention_bias, 190 | total_key_depth=self._config.hidden_units, 191 | total_value_depth=self._config.hidden_units, 192 | output_depth=self._config.hidden_units, 193 | num_heads=self._config.num_heads, 194 | dropout_rate=attention_dropout_rate, 195 | num_queries=1, 196 | name="decoder_vanilla_attention", 197 | summaries=True), 198 | dropout_rate=residual_dropout_rate) 199 | 200 | # Feed Forward 201 | decoder_output = residual(decoder_output, 202 | ff_hidden( 203 | decoder_output, 204 | hidden_size=self._config.ff_hidden_units, 205 | output_size=self._config.hidden_units, 206 | activation=self._ff_activation), 207 | dropout_rate=residual_dropout_rate) 208 | 209 | decoder_output = tf.concat([decoder_cache[:, :, i, :], decoder_output], axis=1) 210 | new_cache.append(decoder_output[:, :, None, :]) 211 | 212 | new_cache = tf.concat(new_cache, axis=2) # [batch_size, n_step, num_blocks, num_hidden] 213 | 214 | return decoder_output, new_cache 215 | -------------------------------------------------------------------------------- /multi-bleu.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 
5 | 6 | # $Id$ 7 | use warnings; 8 | use strict; 9 | 10 | my $lowercase = 0; 11 | if ($ARGV[0] eq "-lc") { 12 | $lowercase = 1; 13 | shift; 14 | } 15 | 16 | my $stem = $ARGV[0]; 17 | if (!defined $stem) { 18 | print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n"; 19 | print STDERR "Reads the references from reference or reference0, reference1, ...\n"; 20 | exit(1); 21 | } 22 | 23 | $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0"; 24 | 25 | my @REF; 26 | my $ref=0; 27 | while(-e "$stem$ref") { 28 | &add_to_ref("$stem$ref",\@REF); 29 | $ref++; 30 | } 31 | &add_to_ref($stem,\@REF) if -e $stem; 32 | die("ERROR: could not find reference file $stem") unless scalar @REF; 33 | 34 | sub add_to_ref { 35 | my ($file,$REF) = @_; 36 | my $s=0; 37 | open(REF,$file) or die "Can't read $file"; 38 | while() { 39 | chop; 40 | push @{$$REF[$s++]}, $_; 41 | } 42 | close(REF); 43 | } 44 | 45 | my(@CORRECT,@TOTAL,$length_translation,$length_reference); 46 | my $s=0; 47 | while() { 48 | chop; 49 | $_ = lc if $lowercase; 50 | my @WORD = split; 51 | my %REF_NGRAM = (); 52 | my $length_translation_this_sentence = scalar(@WORD); 53 | my ($closest_diff,$closest_length) = (9999,9999); 54 | foreach my $reference (@{$REF[$s]}) { 55 | # print "$s $_ <=> $reference\n"; 56 | $reference = lc($reference) if $lowercase; 57 | my @WORD = split(' ',$reference); 58 | my $length = scalar(@WORD); 59 | my $diff = abs($length_translation_this_sentence-$length); 60 | if ($diff < $closest_diff) { 61 | $closest_diff = $diff; 62 | $closest_length = $length; 63 | # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n"; 64 | } elsif ($diff == $closest_diff) { 65 | $closest_length = $length if $length < $closest_length; 66 | # from two references with the same closeness to me 67 | # take the *shorter* into account, not the "first" one. 68 | } 69 | for(my $n=1;$n<=4;$n++) { 70 | my %REF_NGRAM_N = (); 71 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 72 | my $ngram = "$n"; 73 | for(my $w=0;$w<$n;$w++) { 74 | $ngram .= " ".$WORD[$start+$w]; 75 | } 76 | $REF_NGRAM_N{$ngram}++; 77 | } 78 | foreach my $ngram (keys %REF_NGRAM_N) { 79 | if (!defined($REF_NGRAM{$ngram}) || 80 | $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) { 81 | $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram}; 82 | # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}
\n"; 83 | } 84 | } 85 | } 86 | } 87 | $length_translation += $length_translation_this_sentence; 88 | $length_reference += $closest_length; 89 | for(my $n=1;$n<=4;$n++) { 90 | my %T_NGRAM = (); 91 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 92 | my $ngram = "$n"; 93 | for(my $w=0;$w<$n;$w++) { 94 | $ngram .= " ".$WORD[$start+$w]; 95 | } 96 | $T_NGRAM{$ngram}++; 97 | } 98 | foreach my $ngram (keys %T_NGRAM) { 99 | $ngram =~ /^(\d+) /; 100 | my $n = $1; 101 | # my $corr = 0; 102 | # print "$i e $ngram $T_NGRAM{$ngram}
\n"; 103 | $TOTAL[$n] += $T_NGRAM{$ngram}; 104 | if (defined($REF_NGRAM{$ngram})) { 105 | if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) { 106 | $CORRECT[$n] += $T_NGRAM{$ngram}; 107 | # $corr = $T_NGRAM{$ngram}; 108 | # print "$i e correct1 $T_NGRAM{$ngram}
\n"; 109 | } 110 | else { 111 | $CORRECT[$n] += $REF_NGRAM{$ngram}; 112 | # $corr = $REF_NGRAM{$ngram}; 113 | # print "$i e correct2 $REF_NGRAM{$ngram}
\n"; 114 | } 115 | } 116 | # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram}; 117 | # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n" 118 | } 119 | } 120 | $s++; 121 | } 122 | my $brevity_penalty = 1; 123 | my $bleu = 0; 124 | 125 | my @bleu=(); 126 | 127 | for(my $n=1;$n<=4;$n++) { 128 | if (defined ($TOTAL[$n])){ 129 | $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0; 130 | # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n"; 131 | }else{ 132 | $bleu[$n]=0; 133 | } 134 | } 135 | 136 | if ($length_reference==0){ 137 | printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n"; 138 | exit(1); 139 | } 140 | 141 | if ($length_translation<$length_reference) { 142 | $brevity_penalty = exp(1-$length_reference/$length_translation); 143 | } 144 | $bleu = $brevity_penalty * exp((my_log( $bleu[1] ) + 145 | my_log( $bleu[2] ) + 146 | my_log( $bleu[3] ) + 147 | my_log( $bleu[4] ) ) / 4) ; 148 | printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n", 149 | 100*$bleu, 150 | 100*$bleu[1], 151 | 100*$bleu[2], 152 | 100*$bleu[3], 153 | 100*$bleu[4], 154 | $brevity_penalty, 155 | $length_translation / $length_reference, 156 | $length_translation, 157 | $length_reference; 158 | 159 | sub my_log { 160 | return -9999999999 unless $_[0]; 161 | return log($_[0]); 162 | } 163 | -------------------------------------------------------------------------------- /third_party/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chqiwang/transformer/b251331bd05e2e61b94ff38f992dd6bc2cc79ef7/third_party/__init__.py -------------------------------------------------------------------------------- /third_party/tensor2tensor/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chqiwang/transformer/b251331bd05e2e61b94ff38f992dd6bc2cc79ef7/third_party/tensor2tensor/__init__.py -------------------------------------------------------------------------------- /third_party/tensor2tensor/avg_checkpoints.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Script to average values of variables in a list of checkpoint files.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | # Dependency imports 21 | 22 | import numpy as np 23 | import six 24 | from six.moves import zip # pylint: disable=redefined-builtin 25 | import tensorflow as tf 26 | 27 | flags = tf.flags 28 | FLAGS = flags.FLAGS 29 | 30 | flags.DEFINE_string("checkpoints", "", 31 | "Comma-separated list of checkpoints to average.") 32 | flags.DEFINE_string("prefix", "", 33 | "Prefix (e.g., directory) to append to each checkpoint.") 34 | flags.DEFINE_string("output_path", "/tmp/averaged.ckpt", 35 | "Path to output the averaged checkpoint to.") 36 | 37 | 38 | def checkpoint_exists(path): 39 | return (tf.gfile.Exists(path) or tf.gfile.Exists(path + ".meta") or 40 | tf.gfile.Exists(path + ".index")) 41 | 42 | 43 | def main(_): 44 | # Get the checkpoints list from flags and run some basic checks. 45 | checkpoints = [c.strip() for c in FLAGS.checkpoints.split(",")] 46 | checkpoints = [c for c in checkpoints if c] 47 | if not checkpoints: 48 | raise ValueError("No checkpoints provided for averaging.") 49 | if FLAGS.prefix: 50 | checkpoints = [FLAGS.prefix + c for c in checkpoints] 51 | checkpoints = [c for c in checkpoints if checkpoint_exists(c)] 52 | if not checkpoints: 53 | raise ValueError( 54 | "None of the provided checkpoints exist. %s" % FLAGS.checkpoints) 55 | 56 | # Read variables from all checkpoints and average them. 57 | tf.logging.info("Reading variables and averaging checkpoints:") 58 | for c in checkpoints: 59 | tf.logging.info("%s ", c) 60 | var_list = tf.contrib.framework.list_variables(checkpoints[0]) 61 | var_values, var_dtypes = {}, {} 62 | for (name, shape) in var_list: 63 | if not name.startswith("global_step"): 64 | var_values[name] = np.zeros(shape) 65 | for checkpoint in checkpoints: 66 | reader = tf.contrib.framework.load_checkpoint(checkpoint) 67 | for name in var_values: 68 | tensor = reader.get_tensor(name) 69 | var_dtypes[name] = tensor.dtype 70 | var_values[name] += tensor 71 | tf.logging.info("Read from checkpoint %s", checkpoint) 72 | for name in var_values: # Average. 73 | var_values[name] /= len(checkpoints) 74 | 75 | tf_vars = [ 76 | tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[name]) 77 | for v in var_values 78 | ] 79 | placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars] 80 | assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)] 81 | global_step = tf.Variable( 82 | 0, name="global_step", trainable=False, dtype=tf.int64) 83 | saver = tf.train.Saver(tf.all_variables()) 84 | 85 | # Build a model consisting only of variables, set them to the average values. 86 | with tf.Session() as sess: 87 | sess.run(tf.initialize_all_variables()) 88 | for p, assign_op, (name, value) in zip(placeholders, assign_ops, 89 | six.iteritems(var_values)): 90 | sess.run(assign_op, {p: value}) 91 | # Use the built saver to save the averaged checkpoint. 92 | saver.save(sess, FLAGS.output_path, global_step=global_step) 93 | 94 | tf.logging.info("Averaged checkpoints saved in %s", FLAGS.output_path) 95 | 96 | 97 | if __name__ == "__main__": 98 | tf.app.run() 99 | -------------------------------------------------------------------------------- /third_party/tensor2tensor/common_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities for attention.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import math 21 | 22 | import tensorflow as tf 23 | 24 | from third_party.tensor2tensor import common_layers 25 | 26 | 27 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 28 | """Adds a bunch of sinusoids of different frequencies to a Tensor. 29 | 30 | Each channel of the input Tensor is incremented by a sinusoid of a different 31 | frequency and phase. 32 | 33 | This allows attention to learn to use absolute and relative positions. 34 | Timing signals should be added to some precursors of both the query and the 35 | memory inputs to attention. 36 | 37 | The use of relative position is possible because sin(x+y) and cos(x+y) can be 38 | experessed in terms of y, sin(x) and cos(x). 39 | 40 | In particular, we use a geometric sequence of timescales starting with 41 | min_timescale and ending with max_timescale. The number of different 42 | timescales is equal to channels / 2. For each timescale, we 43 | generate the two sinusoidal signals sin(timestep/timescale) and 44 | cos(timestep/timescale). All of these sinusoids are concatenated in 45 | the channels dimension. 46 | 47 | Args: 48 | x: a Tensor with shape [batch, length, channels] 49 | min_timescale: a float 50 | max_timescale: a float 51 | 52 | Returns: 53 | a Tensor the same shape as x. 54 | """ 55 | length = tf.shape(x)[1] 56 | channels = tf.shape(x)[2] 57 | position = tf.to_float(tf.range(length)) 58 | num_timescales = channels // 2 59 | log_timescale_increment = ( 60 | math.log(float(max_timescale) / float(min_timescale)) / 61 | (tf.to_float(num_timescales) - 1)) 62 | inv_timescales = min_timescale * tf.exp( 63 | tf.to_float(tf.range(num_timescales)) * -log_timescale_increment) 64 | scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales, 0) 65 | signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1) 66 | signal = tf.pad(signal, [[0, 0], [0, tf.mod(channels, 2)]]) 67 | signal = tf.reshape(signal, [1, length, channels]) 68 | return x + signal 69 | 70 | 71 | def add_timing_signal_nd(x, min_timescale=1.0, max_timescale=1.0e4): 72 | """Adds a bunch of sinusoids of different frequencies to a Tensor. 73 | 74 | Each channel of the input Tensor is incremented by a sinusoid of a different 75 | frequency and phase in one of the positional dimensions. 76 | 77 | This allows attention to learn to use absolute and relative positions. 78 | Timing signals should be added to some precursors of both the query and the 79 | memory inputs to attention. 80 | 81 | The use of relative position is possible because sin(a+b) and cos(a+b) can be 82 | experessed in terms of b, sin(a) and cos(a). 83 | 84 | x is a Tensor with n "positional" dimensions, e.g. 
one dimension for a 85 | sequence or two dimensions for an image 86 | 87 | We use a geometric sequence of timescales starting with 88 | min_timescale and ending with max_timescale. The number of different 89 | timescales is equal to channels // (n * 2). For each timescale, we 90 | generate the two sinusoidal signals sin(timestep/timescale) and 91 | cos(timestep/timescale). All of these sinusoids are concatenated in 92 | the channels dimension. 93 | 94 | Args: 95 | x: a Tensor with shape [batch, d1 ... dn, channels] 96 | min_timescale: a float 97 | max_timescale: a float 98 | 99 | Returns: 100 | a Tensor the same shape as x. 101 | """ 102 | static_shape = x.get_shape().as_list() 103 | num_dims = len(static_shape) - 2 104 | channels = tf.shape(x)[-1] 105 | num_timescales = channels // (num_dims * 2) 106 | log_timescale_increment = ( 107 | math.log(float(max_timescale) / float(min_timescale)) / 108 | (tf.to_float(num_timescales) - 1)) 109 | inv_timescales = min_timescale * tf.exp( 110 | tf.to_float(tf.range(num_timescales)) * -log_timescale_increment) 111 | for dim in xrange(num_dims): 112 | length = tf.shape(x)[dim + 1] 113 | position = tf.to_float(tf.range(length)) 114 | scaled_time = tf.expand_dims(position, 1) * tf.expand_dims( 115 | inv_timescales, 0) 116 | signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1) 117 | prepad = dim * 2 * num_timescales 118 | postpad = channels - (dim + 1) * 2 * num_timescales 119 | signal = tf.pad(signal, [[0, 0], [prepad, postpad]]) 120 | for _ in xrange(1 + dim): 121 | signal = tf.expand_dims(signal, 0) 122 | for _ in xrange(num_dims - 1 - dim): 123 | signal = tf.expand_dims(signal, -2) 124 | x += signal 125 | return x 126 | 127 | 128 | def add_positional_embedding_nd(x, max_length, name): 129 | """Add n-dimensional positional embedding. 130 | 131 | Adds embeddings to represent the positional dimensions of the tensor. 132 | The input tensor has n positional dimensions - i.e. 1 for text, 2 for images, 133 | 3 for video, etc. 134 | 135 | Args: 136 | x: a Tensor with shape [batch, p1 ... pn, depth] 137 | max_length: an integer. static maximum size of any dimension. 138 | name: a name for this layer. 139 | 140 | Returns: 141 | a Tensor the same shape as x. 142 | """ 143 | static_shape = x.get_shape().as_list() 144 | dynamic_shape = tf.shape(x) 145 | num_dims = len(static_shape) - 2 146 | depth = static_shape[-1] 147 | base_shape = [1] * (num_dims + 1) + [depth] 148 | base_start = [0] * (num_dims + 2) 149 | base_size = [-1] + [1] * num_dims + [depth] 150 | for i in xrange(num_dims): 151 | shape = base_shape[:] 152 | start = base_start[:] 153 | size = base_size[:] 154 | shape[i + 1] = max_length 155 | size[i + 1] = dynamic_shape[i + 1] 156 | var = (tf.get_variable( 157 | name + "_%d" % i, shape, 158 | initializer=tf.random_normal_initializer(0, depth ** -0.5)) 159 | * (depth ** 0.5)) 160 | x += tf.slice(var, start, size) 161 | return x 162 | 163 | 164 | def embedding_to_padding(emb): 165 | """Input embeddings -> is_padding. 166 | 167 | We have hacked symbol_modality to return all-zero embeddings for padding. 168 | 169 | Args: 170 | emb: a Tensor with shape [..., depth]. 171 | Returns: 172 | a boolean Tensor with shape [...]. 173 | """ 174 | emb_sum = tf.reduce_sum(tf.abs(emb), axis=-1) 175 | return tf.equal(emb_sum, 0.0) 176 | 177 | 178 | def attention_bias_lower_triangle(length): 179 | """Create an bias tensor to be added to attention logits. 180 | 181 | Args: 182 | length: a Scalar. 
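  Example (editorial sketch): for length=3 the returned bias is
      [[[[  0., -1e9, -1e9],
         [  0.,   0., -1e9],
         [  0.,   0.,   0.]]]]
  so adding it to the attention logits drives the softmax weight on
  future positions to approximately zero.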
183 | 184 | Returns: 185 | a `Tensor` with shape [1, 1, length, length]. 186 | """ 187 | lower_triangle = tf.matrix_band_part(tf.ones([length, length]), -1, 0) 188 | ret = -1e9 * (1.0 - lower_triangle) 189 | return tf.reshape(ret, [1, 1, length, length]) 190 | 191 | 192 | def attention_bias_ignore_padding(memory_padding): 193 | """Create an bias tensor to be added to attention logits. 194 | 195 | Args: 196 | memory_padding: a boolean `Tensor` with shape [batch, memory_length]. 197 | 198 | Returns: 199 | a `Tensor` with shape [batch, 1, 1, memory_length]. 200 | """ 201 | ret = tf.to_float(memory_padding) * -1e9 202 | return tf.expand_dims(tf.expand_dims(ret, 1), 1) 203 | 204 | 205 | def split_last_dimension(x, n): 206 | """Reshape x so that the last dimension becomes two dimensions. 207 | 208 | The first of these two dimensions is n. 209 | 210 | Args: 211 | x: a Tensor with shape [..., m] 212 | n: an integer. 213 | 214 | Returns: 215 | a Tensor with shape [..., n, m/n] 216 | """ 217 | old_shape = x.get_shape().dims 218 | last = old_shape[-1] 219 | new_shape = old_shape[:-1] + [n] + [last // n if last else None] 220 | ret = tf.reshape(x, tf.concat([tf.shape(x)[:-1], [n, -1]], 0)) 221 | ret.set_shape(new_shape) 222 | return ret 223 | 224 | 225 | def combine_last_two_dimensions(x): 226 | """Reshape x so that the last two dimension become one. 227 | 228 | Args: 229 | x: a Tensor with shape [..., a, b] 230 | 231 | Returns: 232 | a Tensor with shape [..., ab] 233 | """ 234 | old_shape = x.get_shape().dims 235 | a, b = old_shape[-2:] 236 | new_shape = old_shape[:-2] + [a * b if a and b else None] 237 | ret = tf.reshape(x, tf.concat([tf.shape(x)[:-2], [-1]], 0)) 238 | ret.set_shape(new_shape) 239 | return ret 240 | 241 | 242 | def split_heads(x, num_heads): 243 | """Split channels (dimension 3) into multiple heads (becomes dimension 1). 244 | 245 | Args: 246 | x: a Tensor with shape [batch, length, channels] 247 | num_heads: an integer 248 | 249 | Returns: 250 | a Tensor with shape [batch, num_heads, length, channels / num_heads] 251 | """ 252 | return tf.transpose(split_last_dimension(x, num_heads), [0, 2, 1, 3]) 253 | 254 | 255 | def combine_heads(x): 256 | """Inverse of split_heads. 257 | 258 | Args: 259 | x: a Tensor with shape [batch, num_heads, length, channels / num_heads] 260 | 261 | Returns: 262 | a Tensor with shape [batch, length, channels] 263 | """ 264 | return combine_last_two_dimensions(tf.transpose(x, [0, 2, 1, 3])) 265 | 266 | 267 | def attention_image_summary(attn, image_shapes=None): 268 | """Compute color image summary. 269 | 270 | Args: 271 | attn: a Tensor with shape [batch, num_heads, query_length, memory_length] 272 | image_shapes: optional quadruple of integer scalars. 273 | If the query positions and memory positions represent the 274 | pixels of a flattened image, then pass in their dimensions: 275 | (query_rows, query_cols, memory_rows, memory_cols). 276 | """ 277 | num_heads = attn.get_shape().as_list()[1] 278 | # [batch, query_length, memory_length, num_heads] 279 | image = tf.transpose(attn, [0, 2, 3, 1]) 280 | image = tf.pow(image, 0.2) # for high-dynamic-range 281 | # Each head will correspond to one of RGB. 
282 | # pad the heads to be a multiple of 3 283 | image = tf.pad(image, [[0, 0], [0, 0], [0, 0], [0, -num_heads % 3]]) 284 | image = split_last_dimension(image, 3) 285 | image = tf.reduce_max(image, 4) 286 | if image_shapes is not None: 287 | q_rows, q_cols, m_rows, m_cols = list(image_shapes) 288 | image = tf.reshape(image, [-1, q_rows, q_cols, m_rows, m_cols, 3]) 289 | image = tf.transpose(image, [0, 1, 3, 2, 4, 5]) 290 | image = tf.reshape(image, [-1, q_rows * m_rows, q_cols * m_cols, 3]) 291 | tf.summary.image("attention", image, max_outputs=1) 292 | 293 | 294 | def dot_product_attention(q, 295 | k, 296 | v, 297 | bias, 298 | dropout_rate=0.0, 299 | summaries=False, 300 | image_shapes=None, 301 | name=None): 302 | """dot-product attention. 303 | 304 | Args: 305 | q: a Tensor with shape [batch, heads, length_q, depth_k] 306 | k: a Tensor with shape [batch, heads, length_kv, depth_k] 307 | v: a Tensor with shape [batch, heads, length_kv, depth_v] 308 | bias: bias Tensor (see attention_bias()) 309 | dropout_rate: a floating point number 310 | summaries: a boolean 311 | image_shapes: optional quadruple of integer scalars for image summary. 312 | If the query positions and memory positions represent the 313 | pixels of a flattened image, then pass in their dimensions: 314 | (query_rows, query_cols, memory_rows, memory_cols). 315 | name: an optional string 316 | 317 | Returns: 318 | A Tensor. 319 | """ 320 | with tf.variable_scope( 321 | name, default_name="dot_product_attention", values=[q, k, v]): 322 | # [batch, num_heads, query_length, memory_length] 323 | logits = tf.matmul(q, k, transpose_b=True) 324 | if bias is not None: 325 | logits += bias 326 | weights = tf.nn.softmax(logits, name="attention_weights") 327 | # dropping out the attention links for each of the heads 328 | weights = tf.nn.dropout(weights, 1.0 - dropout_rate) 329 | if summaries and not tf.get_variable_scope().reuse: 330 | attention_image_summary(weights, image_shapes) 331 | return tf.matmul(weights, v) 332 | 333 | 334 | def multihead_attention(query_antecedent, 335 | memory_antecedent, 336 | bias, 337 | total_key_depth, 338 | total_value_depth, 339 | output_depth, 340 | num_heads, 341 | dropout_rate, 342 | summaries=False, 343 | image_shapes=None, 344 | name=None): 345 | """Multihead scaled-dot-product attention with input/output transformations. 346 | 347 | Args: 348 | query_antecedent: a Tensor with shape [batch, length_q, channels] 349 | memory_antecedent: a Tensor with shape [batch, length_m, channels] 350 | bias: bias Tensor (see attention_bias()) 351 | total_key_depth: an integer 352 | total_value_depth: an integer 353 | output_depth: an integer 354 | num_heads: an integer dividing total_key_depth and total_value_depth 355 | dropout_rate: a floating point number 356 | summaries: a boolean 357 | image_shapes: optional quadruple of integer scalars for image summary. 358 | If the query positions and memory positions represent the 359 | pixels of a flattened image, then pass in their dimensions: 360 | (query_rows, query_cols, memory_rows, memory_cols). 361 | name: an optional string 362 | 363 | Returns: 364 | A Tensor. 
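  Shape sketch (editorial note; B = batch, L = length, h = num_heads):
    q, k, v: [B, L, depth] --split_heads--> [B, h, L, depth/h]
    weights = softmax(q * k^T / sqrt(total_key_depth / h) + bias): [B, h, L_q, L_kv]
    output: combine_heads -> [B, L_q, total_value_depth] --conv1d--> [B, L_q, output_depth]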
365 | """ 366 | with tf.variable_scope( 367 | name, 368 | default_name="multihead_attention", 369 | values=[query_antecedent, memory_antecedent]): 370 | if memory_antecedent is None: 371 | # self attention 372 | combined = common_layers.conv1d( 373 | query_antecedent, 374 | total_key_depth * 2 + total_value_depth, 375 | 1, 376 | name="qkv_transform") 377 | q, k, v = tf.split( 378 | combined, [total_key_depth, total_key_depth, total_value_depth], 379 | axis=2) 380 | else: 381 | q = common_layers.conv1d( 382 | query_antecedent, total_key_depth, 1, name="q_transform") 383 | combined = common_layers.conv1d( 384 | memory_antecedent, 385 | total_key_depth + total_value_depth, 386 | 1, 387 | name="kv_transform") 388 | k, v = tf.split(combined, [total_key_depth, total_value_depth], axis=2) 389 | q = split_heads(q, num_heads) 390 | k = split_heads(k, num_heads) 391 | v = split_heads(v, num_heads) 392 | key_depth_per_head = total_key_depth // num_heads 393 | q *= key_depth_per_head**-0.5 394 | x = dot_product_attention( 395 | q, k, v, bias, dropout_rate, summaries, image_shapes) 396 | x = combine_heads(x) 397 | x = common_layers.conv1d(x, output_depth, 1, name="output_transform") 398 | return x 399 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import logging 4 | from argparse import ArgumentParser 5 | import tensorflow as tf 6 | import yaml 7 | 8 | from evaluate import Evaluator 9 | from models import * 10 | from utils import DataReader, AttrDict, available_variables, expand_feed_dict 11 | 12 | 13 | class BreakLoopException(Exception): 14 | pass 15 | 16 | 17 | def train(config): 18 | """Train a model with a config file.""" 19 | logger = logging.getLogger('') 20 | data_reader = DataReader(config=config) 21 | model = eval(config.model)(config=config, num_gpus=config.train.num_gpus) 22 | model.build_train_model(test=config.train.eval_on_dev) 23 | 24 | train_op, loss_op = model.get_train_op(name=None) 25 | global_saver = tf.train.Saver() 26 | 27 | sess_config = tf.ConfigProto() 28 | sess_config.gpu_options.allow_growth = True 29 | sess_config.allow_soft_placement = True 30 | 31 | summary_writer = tf.summary.FileWriter(config.model_dir) 32 | 33 | with tf.Session(config=sess_config) as sess: 34 | # Initialize all variables. 35 | sess.run(tf.global_variables_initializer()) 36 | # Reload variables from disk. 
37 | if tf.train.latest_checkpoint(config.model_dir): 38 | available_vars = available_variables(config.model_dir) 39 | if available_vars: 40 | saver = tf.train.Saver(var_list=available_vars) 41 | saver.restore(sess, tf.train.latest_checkpoint(config.model_dir)) 42 | for v in available_vars: 43 | logger.info('Reload {} from disk.'.format(v.name)) 44 | else: 45 | logger.info('Nothing to be reload from disk.') 46 | else: 47 | logger.info('Nothing to be reload from disk.') 48 | 49 | evaluator = Evaluator() 50 | evaluator.init_from_existed(model, sess, data_reader) 51 | 52 | global dev_bleu, toleration 53 | dev_bleu = evaluator.evaluate(**config.dev) if config.train.eval_on_dev else 0 54 | toleration = config.train.toleration 55 | 56 | def train_one_step(batch, loss_op, train_op): 57 | feed_dict = expand_feed_dict({model.src_pls: batch[0], model.dst_pls: batch[1]}) 58 | step, lr, loss, _ = sess.run( 59 | [model.global_step, model.learning_rate, 60 | loss_op, train_op], 61 | feed_dict=feed_dict) 62 | if step % config.train.summary_freq == 0: 63 | summary = sess.run(model.summary_op, feed_dict=feed_dict) 64 | summary_writer.add_summary(summary, global_step=step) 65 | return step, lr, loss 66 | 67 | def maybe_save_model(): 68 | global dev_bleu, toleration 69 | 70 | def save(): 71 | mp = config.model_dir + '/model_step_{}'.format(step) 72 | global_saver.save(sess, mp) 73 | logger.info('Save model in %s.' % mp) 74 | 75 | if config.train.eval_on_dev: 76 | new_dev_bleu = evaluator.evaluate(**config.dev) 77 | if config.train.toleration is None: 78 | save() 79 | else: 80 | if new_dev_bleu >= dev_bleu: 81 | save() 82 | toleration = config.train.toleration 83 | dev_bleu = new_dev_bleu 84 | else: 85 | toleration -= 1 86 | else: 87 | save() 88 | 89 | try: 90 | step = 0 91 | for epoch in range(1, config.train.num_epochs+1): 92 | for batch in data_reader.get_training_batches(epoches=1): 93 | 94 | # Train normal instances. 95 | start_time = time.time() 96 | step, lr, loss = train_one_step(batch, loss_op, train_op) 97 | logger.info( 98 | 'epoch: {0}\tstep: {1}\tlr: {2:.6f}\tloss: {3:.4f}\ttime: {4:.4f}'. 
99 | format(epoch, step, lr, loss, time.time() - start_time)) 100 | # Save model 101 | if config.train.save_freq > 0 \ 102 | and step > 0 \ 103 | and step % config.train.save_freq == 0: 104 | maybe_save_model() 105 | 106 | if config.train.num_steps is not None and step >= config.train.num_steps: 107 | raise BreakLoopException("BreakLoop") 108 | 109 | if toleration is not None and toleration <= 0: 110 | raise BreakLoopException("BreakLoop") 111 | 112 | # Save model per epoch if config.train.save_freq is less or equal than zero 113 | if config.train.save_freq <= 0: 114 | maybe_save_model() 115 | except BreakLoopException as e: 116 | logger.info(e) 117 | 118 | logger.info("Finish training.") 119 | 120 | 121 | if __name__ == '__main__': 122 | parser = ArgumentParser() 123 | parser.add_argument('-c', '--config', dest='config') 124 | args = parser.parse_args() 125 | # Read config 126 | config = AttrDict(yaml.load(open(args.config))) 127 | # Logger 128 | if not os.path.exists(config.model_dir): 129 | os.makedirs(config.model_dir) 130 | logging.basicConfig(filename=config.model_dir + '/train.log', level=logging.INFO) 131 | console = logging.StreamHandler() 132 | console.setLevel(logging.INFO) 133 | logging.getLogger('').addHandler(console) 134 | # Train 135 | train(config) 136 | -------------------------------------------------------------------------------- /train_wkd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import logging 4 | from argparse import ArgumentParser 5 | import tensorflow as tf 6 | import yaml 7 | 8 | from evaluate import Evaluator 9 | from models import * 10 | from utils import DataReader, AttrDict, available_variables, expand_feed_dict 11 | 12 | 13 | class BreakLoopException(Exception): 14 | pass 15 | 16 | 17 | def wrap_scope(input_ckpt_path, output_ckpt_path, scope): 18 | with tf.Graph().as_default(): 19 | with tf.Session() as sess: 20 | with tf.variable_scope(scope): 21 | var_list = tf.contrib.framework.list_variables(input_ckpt_path) 22 | var_names, var_shapes = zip(*var_list) 23 | reader = tf.contrib.framework.load_checkpoint(input_ckpt_path) 24 | var_values = [reader.get_tensor(name) for name in var_names] 25 | new_var_list = [tf.get_variable(name, initializer=value) 26 | for name, value in zip(var_names, var_values)] 27 | sess.run(tf.global_variables_initializer()) 28 | saver = tf.train.Saver(new_var_list) 29 | saver.save(sess, output_ckpt_path) 30 | 31 | 32 | def train(config, teacher_config): 33 | """Train a model with a config file.""" 34 | logger = logging.getLogger('') 35 | data_reader = DataReader(config=config) 36 | model = eval(config.model)(config=config, num_gpus=config.train.num_gpus) 37 | with tf.variable_scope('teacher'): 38 | teacher_model = eval(teacher_config.model)(config=teacher_config, num_gpus=0) 39 | model.build_train_model(test=config.train.eval_on_dev, teacher_model=teacher_model) 40 | 41 | train_op, loss_op = model.get_train_op(name=None) 42 | global_saver = tf.train.Saver([v for v in tf.global_variables() if not v.name.startswith('teacher')]) 43 | 44 | sess_config = tf.ConfigProto() 45 | sess_config.gpu_options.allow_growth = True 46 | sess_config.allow_soft_placement = True 47 | 48 | summary_writer = tf.summary.FileWriter(config.model_dir) 49 | 50 | with tf.Session(config=sess_config) as sess: 51 | # Initialize all variables. 52 | sess.run(tf.global_variables_initializer()) 53 | 54 | # Reload teacher variables from disk. 
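        # (Editorial note) the teacher checkpoint on disk was saved without the
        # 'teacher/' name scope, so the wrap_scope() helper defined above is
        # used below to rewrite it into a temporary checkpoint whose variable
        # names carry that prefix; only then can teacher_saver.restore() match
        # the variables built under tf.variable_scope('teacher').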
55 | logger.info('Load teacher model parameters...') 56 | teacher_vars = tf.global_variables('teacher') 57 | teacher_saver = tf.train.Saver(var_list=teacher_vars) 58 | tmp_ckpt = '/tmp/teacher-{}.ckpt'.format(os.getpid()) 59 | wrap_scope(tf.train.latest_checkpoint(teacher_config.model_dir), tmp_ckpt, 'teacher') 60 | teacher_saver.restore(sess, tmp_ckpt) 61 | for v in teacher_vars: 62 | logger.info('Reload {} from disk.'.format(v.name)) 63 | 64 | # Reload student variables from disk. 65 | logger.info('Load student model parameters...') 66 | if tf.train.latest_checkpoint(config.model_dir): 67 | available_vars = available_variables(config.model_dir) 68 | if available_vars: 69 | saver = tf.train.Saver(var_list=available_vars) 70 | saver.restore(sess, tf.train.latest_checkpoint(config.model_dir)) 71 | for v in available_vars: 72 | logger.info('Reload {} from disk.'.format(v.name)) 73 | else: 74 | logger.info('Nothing to be reload from disk.') 75 | else: 76 | logger.info('Nothing to be reload from disk.') 77 | 78 | evaluator = Evaluator() 79 | evaluator.init_from_existed(model, sess, data_reader) 80 | 81 | global dev_bleu, toleration 82 | dev_bleu = evaluator.evaluate(**config.dev) if config.train.eval_on_dev else 0 83 | toleration = config.train.toleration 84 | 85 | def train_one_step(batch, loss_op, train_op): 86 | feed_dict = expand_feed_dict({model.src_pls: batch[0], model.dst_pls: batch[1]}) 87 | step, lr, loss, _ = sess.run( 88 | [model.global_step, model.learning_rate, 89 | loss_op, train_op], 90 | feed_dict=feed_dict) 91 | if step % config.train.summary_freq == 0: 92 | summary = sess.run(model.summary_op, feed_dict=feed_dict) 93 | summary_writer.add_summary(summary, global_step=step) 94 | return step, lr, loss 95 | 96 | def maybe_save_model(): 97 | global dev_bleu, toleration 98 | 99 | def save(): 100 | mp = config.model_dir + '/model_step_{}'.format(step) 101 | global_saver.save(sess, mp) 102 | logger.info('Save model in %s.' % mp) 103 | 104 | if config.train.eval_on_dev: 105 | new_dev_bleu = evaluator.evaluate(**config.dev) 106 | if config.train.toleration is None: 107 | save() 108 | else: 109 | if new_dev_bleu >= dev_bleu: 110 | save() 111 | toleration = config.train.toleration 112 | dev_bleu = new_dev_bleu 113 | else: 114 | toleration -= 1 115 | else: 116 | save() 117 | 118 | try: 119 | step = 0 120 | for epoch in range(1, config.train.num_epochs+1): 121 | for batch in data_reader.get_training_batches(epoches=1): 122 | 123 | # Train normal instances. 124 | start_time = time.time() 125 | step, lr, loss = train_one_step(batch, loss_op, train_op) 126 | logger.info( 127 | 'epoch: {0}\tstep: {1}\tlr: {2:.6f}\tloss: {3:.4f}\ttime: {4:.4f}'. 
128 | format(epoch, step, lr, loss, time.time() - start_time)) 129 | # Save model 130 | if config.train.save_freq > 0 \ 131 | and step > 0 \ 132 | and step % config.train.save_freq == 0: 133 | maybe_save_model() 134 | 135 | if config.train.num_steps is not None and step >= config.train.num_steps: 136 | raise BreakLoopException("BreakLoop") 137 | 138 | if toleration is not None and toleration <= 0: 139 | raise BreakLoopException("BreakLoop") 140 | 141 | # Save model per epoch if config.train.save_freq is less or equal than zero 142 | if config.train.save_freq <= 0: 143 | maybe_save_model() 144 | except BreakLoopException as e: 145 | logger.info(e) 146 | 147 | logger.info("Finish training.") 148 | 149 | 150 | if __name__ == '__main__': 151 | parser = ArgumentParser() 152 | parser.add_argument('-c', '--config', dest='config') 153 | parser.add_argument('-t', '--teacher_config', dest='teacher_config') 154 | args = parser.parse_args() 155 | # Read config 156 | config = AttrDict(yaml.load(open(args.config))) 157 | teacher_config = AttrDict(yaml.load(open(args.teacher_config))) 158 | # Logger 159 | if not os.path.exists(config.model_dir): 160 | os.makedirs(config.model_dir) 161 | logging.basicConfig(filename=config.model_dir + '/train.log', level=logging.INFO) 162 | console = logging.StreamHandler() 163 | console.setLevel(logging.INFO) 164 | logging.getLogger('').addHandler(console) 165 | # Train 166 | train(config, teacher_config) 167 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import codecs 4 | import logging 5 | import os 6 | import time 7 | from itertools import izip 8 | from tempfile import mkstemp 9 | 10 | import numpy as np 11 | import tensorflow as tf 12 | import tensorflow.contrib.framework as tff 13 | from tensorflow.python.layers import base as base_layer 14 | 15 | from third_party.tensor2tensor import common_layers, common_attention 16 | common_layers.allow_defun = False 17 | 18 | 19 | class AttrDict(dict): 20 | """ 21 | Dictionary whose keys can be accessed as attributes. 22 | """ 23 | 24 | def __init__(self, *args, **kwargs): 25 | super(AttrDict, self).__init__(*args, **kwargs) 26 | 27 | def __getattr__(self, item): 28 | if item not in self: 29 | logging.warning('{} is not in the dict. None is returned as default.'.format(item)) 30 | return None 31 | if type(self[item]) is dict: 32 | self[item] = AttrDict(self[item]) 33 | return self[item] 34 | 35 | 36 | class DataReader(object): 37 | """ 38 | Read data and create batches for training and testing. 39 | """ 40 | 41 | def __init__(self, config): 42 | self._config = config 43 | self._tmps = set() 44 | self.load_vocab() 45 | 46 | def __del__(self): 47 | for fname in self._tmps: 48 | if os.path.exists(fname): 49 | os.remove(fname) 50 | 51 | def load_vocab(self): 52 | """ 53 | Load vocab from disk. 54 | The first four items in the vocab should be , , , 55 | """ 56 | 57 | def load_vocab_(path, vocab_size): 58 | vocab = [line.split()[0] for line in codecs.open(path, 'r', 'utf-8')] 59 | vocab = vocab[:vocab_size] 60 | assert len(vocab) == vocab_size 61 | word2idx = {word: idx for idx, word in enumerate(vocab)} 62 | idx2word = {idx: word for idx, word in enumerate(vocab)} 63 | return word2idx, idx2word 64 | 65 | logging.debug('Load vocabularies %s and %s.' 
% (self._config.src_vocab, self._config.dst_vocab)) 66 | self.src2idx, self.idx2src = load_vocab_(self._config.src_vocab, self._config.src_vocab_size) 67 | self.dst2idx, self.idx2dst = load_vocab_(self._config.dst_vocab, self._config.dst_vocab_size) 68 | 69 | def get_training_batches(self, shuffle=True, epoches=None): 70 | """ 71 | Generate batches according to bucket setting. 72 | """ 73 | buckets = [(i, i) for i in range(5, 1000000, 3)] 74 | 75 | def select_bucket(sl, dl): 76 | for l1, l2 in buckets: 77 | if sl < l1 and dl < l2: 78 | return l1, l2 79 | raise Exception("The sequence is too long: ({}, {})".format(sl, dl)) 80 | 81 | # Shuffle the training files. 82 | src_path = self._config.train.src_path 83 | dst_path = self._config.train.dst_path 84 | max_length = self._config.train.max_length 85 | 86 | epoch = [0] 87 | 88 | def stop_condition(): 89 | if epoches is None: 90 | return True 91 | else: 92 | epoch[0] += 1 93 | return epoch[0] < epoches + 1 94 | 95 | while stop_condition(): 96 | if shuffle: 97 | logging.debug('Shuffle files %s and %s.' % (src_path, dst_path)) 98 | src_shuf_path, dst_shuf_path = self.shuffle([src_path, dst_path]) 99 | self._tmps.add(src_shuf_path) 100 | self._tmps.add(dst_shuf_path) 101 | else: 102 | src_shuf_path = src_path 103 | dst_shuf_path = dst_path 104 | 105 | caches = {} 106 | for bucket in buckets: 107 | caches[bucket] = [[], [], 0, 0] # src sentences, dst sentences, src tokens, dst tokens 108 | 109 | for src_sent, dst_sent in izip(open(src_shuf_path, 'r'), open(dst_shuf_path, 'r')): 110 | src_sent, dst_sent = src_sent.decode('utf8'), dst_sent.decode('utf8') 111 | 112 | src_sent = src_sent.split() 113 | dst_sent = dst_sent.split() 114 | 115 | # A special data augment method for training PTransformer model. 116 | # if self._config.model == 'PTransformer' and self._config.data_augment: 117 | # s = np.random.randint(2-self._config.num_parallel, self._config.num_parallel) 118 | # s = max(0, s) 119 | # s = [''] * s 120 | # src_sent = s + src_sent 121 | # dst_sent = s + dst_sent 122 | 123 | if len(src_sent) > max_length or len(dst_sent) > max_length: 124 | continue 125 | 126 | bucket = select_bucket(len(src_sent), len(dst_sent)) 127 | if bucket is None: # No bucket is selected when the sentence length exceed the max length. 128 | continue 129 | 130 | caches[bucket][0].append(src_sent) 131 | caches[bucket][1].append(dst_sent) 132 | caches[bucket][2] += len(src_sent) 133 | caches[bucket][3] += len(dst_sent) 134 | 135 | if max(caches[bucket][2], caches[bucket][3]) >= self._config.train.tokens_per_batch: 136 | batch = (self.create_batch(caches[bucket][0], o='src'), self.create_batch(caches[bucket][1], o='dst')) 137 | logging.debug( 138 | 'Yield batch with source shape %s and target shape %s.' % (batch[0].shape, batch[1].shape)) 139 | yield batch 140 | caches[bucket] = [[], [], 0, 0] 141 | 142 | # Clean remain sentences. 143 | for bucket in buckets: 144 | # Ensure each device at least get one sample. 145 | if len(caches[bucket][0]) >= max(1, self._config.train.num_gpus): 146 | batch = (self.create_batch(caches[bucket][0], o='src'), self.create_batch(caches[bucket][1], o='dst')) 147 | logging.debug( 148 | 'Yield batch with source shape %s and target shape %s.' % (batch[0].shape, batch[1].shape)) 149 | yield batch 150 | 151 | # Remove shuffled files when epoch finished. 
152 | if shuffle: 153 | os.remove(src_shuf_path) 154 | os.remove(dst_shuf_path) 155 | self._tmps.remove(src_shuf_path) 156 | self._tmps.remove(dst_shuf_path) 157 | 158 | @staticmethod 159 | def shuffle(list_of_files): 160 | tf_os, tpath = mkstemp() 161 | tf = open(tpath, 'w') 162 | 163 | fds = [open(ff) for ff in list_of_files] 164 | 165 | for l in fds[0]: 166 | lines = [l.strip()] + [ff.readline().strip() for ff in fds[1:]] 167 | print("".join(lines), file=tf) 168 | 169 | [ff.close() for ff in fds] 170 | tf.close() 171 | 172 | os.system('shuf %s > %s' % (tpath, tpath + '.shuf')) 173 | 174 | fnames = ['/tmp/{}.{}.{}.shuf'.format(i, os.getpid(), time.time()) for i, ff in enumerate(list_of_files)] 175 | fds = [open(fn, 'w') for fn in fnames] 176 | 177 | for l in open(tpath + '.shuf'): 178 | s = l.strip().split('') 179 | for i, fd in enumerate(fds): 180 | print(s[i], file=fd) 181 | 182 | [ff.close() for ff in fds] 183 | 184 | os.remove(tpath) 185 | os.remove(tpath + '.shuf') 186 | 187 | return fnames 188 | 189 | def get_test_batches(self, src_path, batch_size): 190 | # Read batches for testing. 191 | src_sents = [] 192 | for src_sent in open(src_path, 'r'): 193 | src_sent = src_sent.decode('utf8') 194 | src_sent = src_sent.split() 195 | src_sents.append(src_sent) 196 | # Create a padded batch. 197 | if len(src_sents) >= batch_size: 198 | yield self.create_batch(src_sents, o='src') 199 | src_sents = [] 200 | if src_sents: 201 | # We ensure batch size not small than gpu number by padding redundant samples. 202 | if len(src_sents) < self._config.test.num_gpus: 203 | src_sents.extend([src_sents[-1]] * self._config.test.num_gpus) 204 | yield self.create_batch(src_sents, o='src') 205 | 206 | def get_test_batches_with_target(self, src_path, dst_path, batch_size): 207 | """ 208 | Usually we don't need target sentences for test unless we want to compute PPl. 209 | Returns: 210 | Paired source and target batches. 211 | """ 212 | 213 | src_sents, dst_sents = [], [] 214 | for src_sent, dst_sent in izip(open(src_path, 'r'), open(dst_path, 'r')): 215 | src_sent, dst_sent = src_sent.decode('utf8'), dst_sent.decode('utf8') 216 | src_sent = src_sent.split() 217 | dst_sent = dst_sent.split() 218 | src_sents.append(src_sent) 219 | dst_sents.append(dst_sent) 220 | # Create a padded batch. 221 | if len(src_sents) >= batch_size: 222 | yield self.create_batch(src_sents, o='src'), self.create_batch(dst_sents, o='dst') 223 | src_sents, dst_sents = [], [] 224 | if src_sents: 225 | yield self.create_batch(src_sents, o='src'), self.create_batch(dst_sents, o='dst') 226 | 227 | def create_batch(self, sents, o): 228 | # Convert words to indices. 229 | assert o in ('src', 'dst') 230 | word2idx = self.src2idx if o == 'src' else self.dst2idx 231 | indices = [] 232 | for sent in sents: 233 | x = [word2idx.get(word, 1) for word in (sent + [u""])] # 1: OOV,
: End of Text 234 | indices.append(x) 235 | 236 | # Pad to the same length. 237 | maxlen = max([len(s) for s in indices]) 238 | X = np.zeros([len(indices), maxlen], np.int32) 239 | for i, x in enumerate(indices): 240 | X[i, :len(x)] = x 241 | 242 | return X 243 | 244 | def indices_to_words(self, Y, o='dst'): 245 | assert o in ('src', 'dst') 246 | idx2word = self.idx2src if o == 'src' else self.idx2dst 247 | sents = [] 248 | for y in Y: # for each sentence 249 | sent = [] 250 | for i in y: # For each word 251 | if i == 3: #
252 | break 253 | w = idx2word[i] 254 | sent.append(w) 255 | sents.append(' '.join(sent)) 256 | return sents 257 | 258 | 259 | def expand_feed_dict(feed_dict): 260 | """If the key is a tuple of placeholders, 261 | split the input data then feed them into these placeholders. 262 | """ 263 | new_feed_dict = {} 264 | for k, v in feed_dict.items(): 265 | if type(k) is not tuple: 266 | new_feed_dict[k] = v 267 | else: 268 | # Split v along the first dimension. 269 | n = len(k) 270 | batch_size = v.shape[0] 271 | assert batch_size > 0 272 | span = batch_size // n 273 | remainder = batch_size % n 274 | base = 0 275 | for i, p in enumerate(k): 276 | if i < remainder: 277 | end = base + span + 1 278 | else: 279 | end = base + span 280 | new_feed_dict[p] = v[base: end] 281 | base = end 282 | return new_feed_dict 283 | 284 | 285 | def available_variables(checkpoint_dir): 286 | all_vars = tf.global_variables() 287 | all_available_vars = tff.list_variables(checkpoint_dir=checkpoint_dir) 288 | all_available_vars = dict(all_available_vars) 289 | available_vars = [] 290 | for v in all_vars: 291 | vname = v.name.split(':')[0] 292 | if vname in all_available_vars and v.get_shape() == all_available_vars[vname]: 293 | available_vars.append(v) 294 | return available_vars 295 | 296 | 297 | def average_gradients(tower_grads): 298 | """Calculate the average gradient for each shared variable across all towers. 299 | Note that this function provides a synchronization point across all towers. 300 | Args: 301 | tower_grads: List of lists of (gradient, variable) tuples. The outer list 302 | is over individual gradients. The inner list is over the gradient 303 | calculation for each tower. 304 | Returns: 305 | List of pairs of (gradient, variable) where the gradient has been averaged 306 | across all towers. 307 | """ 308 | average_grads = [] 309 | for grad_and_vars in zip(*tower_grads): 310 | # Note that each grad_and_vars looks like the following: 311 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) 312 | grads = [] 313 | for g, _ in grad_and_vars: 314 | # Add 0 dimension to the gradients to represent the tower. 315 | expanded_g = tf.expand_dims(g, 0) 316 | 317 | # Append on a 'tower' dimension which we will average over below. 318 | grads.append(expanded_g) 319 | else: 320 | # Average over the 'tower' dimension. 321 | grad = tf.concat(axis=0, values=grads) 322 | grad = tf.reduce_mean(grad, 0) 323 | 324 | # Keep in mind that the Variables are redundant because they are shared 325 | # across towers. So .. we will just return the first tower's pointer to 326 | # the Variable. 327 | v = grad_and_vars[0][1] 328 | grad_and_var = (grad, v) 329 | average_grads.append(grad_and_var) 330 | return average_grads 331 | 332 | 333 | def residual(inputs, outputs, dropout_rate): 334 | """Residual connection. 335 | 336 | Args: 337 | inputs: A Tensor. 338 | outputs: A Tensor. 339 | dropout_rate: A float range from [0, 1). 340 | 341 | Returns: 342 | A Tensor. 
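    Equivalently (editorial note): outputs = layer_norm(inputs + dropout(outputs)),
    i.e. the post-layer-norm residual form used in the original Transformer sublayers.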
343 | """ 344 | outputs = inputs + tf.nn.dropout(outputs, 1 - dropout_rate) 345 | outputs = common_layers.layer_norm(outputs) 346 | return outputs 347 | 348 | 349 | def learning_rate_decay(config, global_step): 350 | """Inverse-decay learning rate until warmup_steps, then decay.""" 351 | warmup_steps = tf.to_float(config.train.warmup_steps) 352 | global_step = tf.to_float(global_step) 353 | return config.hidden_units ** -0.5 * tf.minimum( 354 | (global_step + 1.0) * warmup_steps ** -1.5, (global_step + 1.0) ** -0.5) 355 | 356 | 357 | def shift_right(input, pad=2): 358 | """Shift input tensor right to create decoder input. '2' denotes """ 359 | return tf.concat((tf.ones_like(input[:, :1]) * pad, input[:, :-1]), 1) 360 | 361 | 362 | def embedding(x, vocab_size, dense_size, name=None, reuse=None, kernel=None, multiplier=1.0): 363 | """Embed x of type int64 into dense vectors.""" 364 | with tf.variable_scope( 365 | name, default_name="embedding", values=[x], reuse=reuse): 366 | if kernel is not None: 367 | embedding_var = kernel 368 | else: 369 | embedding_var = tf.get_variable("kernel", [vocab_size, dense_size]) 370 | output = tf.gather(embedding_var, x) 371 | if multiplier != 1.0: 372 | output *= multiplier 373 | return output 374 | 375 | 376 | def dense(inputs, 377 | output_size, 378 | activation=tf.identity, 379 | use_bias=True, 380 | kernel=None, 381 | reuse=None, 382 | name=None): 383 | argcount = activation.func_code.co_argcount 384 | if activation.func_defaults: 385 | argcount -= len(activation.func_defaults) 386 | assert argcount in (1, 2) 387 | with tf.variable_scope(name, "dense", reuse=reuse): 388 | if argcount == 1: 389 | input_size = inputs.get_shape().as_list()[-1] 390 | inputs_shape = tf.unstack(tf.shape(inputs)) 391 | inputs = tf.reshape(inputs, [-1, input_size]) 392 | if kernel is not None: 393 | assert kernel.get_shape().as_list()[0] == output_size 394 | w = kernel 395 | else: 396 | with tf.variable_scope(tf.get_variable_scope()): 397 | w = tf.get_variable("kernel", [output_size, input_size]) 398 | outputs = tf.matmul(inputs, w, transpose_b=True) 399 | if use_bias: 400 | b = tf.get_variable("bias", [output_size], initializer=tf.zeros_initializer) 401 | outputs += b 402 | outputs = activation(outputs) 403 | return tf.reshape(outputs, inputs_shape[:-1] + [output_size]) 404 | else: 405 | arg1 = dense(inputs, output_size, tf.identity, use_bias, name='arg1') 406 | arg2 = dense(inputs, output_size, tf.identity, use_bias, name='arg2') 407 | return activation(arg1, arg2) 408 | 409 | 410 | def ff_hidden(inputs, hidden_size, output_size, activation, use_bias=True, reuse=None, name=None): 411 | with tf.variable_scope(name, "ff_hidden", reuse=reuse): 412 | hidden_outputs = dense(inputs, hidden_size, activation, use_bias) 413 | outputs = dense(hidden_outputs, output_size, tf.identity, use_bias) 414 | return outputs 415 | 416 | 417 | def multihead_attention(query_antecedent, 418 | memory_antecedent, 419 | bias, 420 | total_key_depth, 421 | total_value_depth, 422 | output_depth, 423 | num_heads, 424 | dropout_rate, 425 | num_queries=None, 426 | query_eq_key=False, 427 | summaries=False, 428 | image_shapes=None, 429 | name=None): 430 | """Multihead scaled-dot-product attention with input/output transformations. 
431 | 432 | Args: 433 | query_antecedent: a Tensor with shape [batch, length_q, channels] 434 | memory_antecedent: a Tensor with shape [batch, length_m, channels] 435 | bias: bias Tensor (see attention_bias()) 436 | total_key_depth: an integer 437 | total_value_depth: an integer 438 | output_depth: an integer 439 | num_heads: an integer dividing total_key_depth and total_value_depth 440 | dropout_rate: a floating point number 441 | num_queries: a int or None 442 | query_eq_key: a boolean 443 | summaries: a boolean 444 | image_shapes: optional quadruple of integer scalars for image summary. 445 | If the query positions and memory positions represent the 446 | pixels of a flattened image, then pass in their dimensions: 447 | (query_rows, query_cols, memory_rows, memory_cols). 448 | name: an optional string 449 | 450 | Returns: 451 | A Tensor. 452 | """ 453 | with tf.variable_scope( 454 | name, 455 | default_name="multihead_attention", 456 | values=[query_antecedent, memory_antecedent]): 457 | 458 | if not query_eq_key: 459 | if memory_antecedent is None: 460 | # Q = K = V 461 | # self attention 462 | combined = dense(query_antecedent, total_key_depth * 2 + total_value_depth, name="qkv_transform") 463 | q, k, v = tf.split( 464 | combined, [total_key_depth, total_key_depth, total_value_depth], 465 | axis=2) 466 | else: 467 | # Q != K = V 468 | q = dense(query_antecedent, total_key_depth, name="q_transform") 469 | combined = dense(memory_antecedent, total_key_depth + total_value_depth, name="kv_transform") 470 | k, v = tf.split(combined, [total_key_depth, total_value_depth], axis=2) 471 | else: 472 | # In this setting, we use query_antecedent as the query and key, 473 | # and use memory_antecedent as the value. 474 | assert memory_antecedent is not None 475 | combined = dense(query_antecedent, total_key_depth * 2, name="qk_transform") 476 | q, k = tf.split( 477 | combined, [total_key_depth, total_key_depth], 478 | axis=2) 479 | v = dense(memory_antecedent, total_value_depth, name='v_transform') 480 | 481 | if num_queries: 482 | q = q[:, -num_queries:, :] 483 | 484 | q = common_attention.split_heads(q, num_heads) 485 | k = common_attention.split_heads(k, num_heads) 486 | v = common_attention.split_heads(v, num_heads) 487 | key_depth_per_head = total_key_depth // num_heads 488 | q *= key_depth_per_head**-0.5 489 | x = common_attention.dot_product_attention( 490 | q, k, v, bias, dropout_rate, summaries, image_shapes) 491 | x = common_attention.combine_heads(x) 492 | x = dense(x, output_depth, name="output_transform") 493 | return x 494 | 495 | 496 | class AttentionGRUCell(tf.nn.rnn_cell.GRUCell): 497 | def __init__(self, 498 | num_units, 499 | attention_memories, 500 | attention_bias=None, 501 | activation=None, 502 | reuse=None, 503 | kernel_initializer=None, 504 | bias_initializer=None, 505 | name=None): 506 | super(AttentionGRUCell, self).__init__( 507 | num_units=num_units, 508 | activation=activation, 509 | reuse=reuse, 510 | kernel_initializer=kernel_initializer, 511 | bias_initializer=bias_initializer, 512 | name=name) 513 | with tf.variable_scope(name, "AttentionGRUCell", reuse=reuse): 514 | self._attention_keys = dense(attention_memories, num_units, name='attention_key') 515 | self._attention_values = dense(attention_memories, num_units, name='attention_value') 516 | self._attention_bias = attention_bias 517 | 518 | def attention(self, inputs, state): 519 | attention_query = tf.matmul( 520 | tf.concat([inputs, state], 1), self._attention_query_kernel) 521 | attention_query = 
tf.nn.bias_add(attention_query, self._attention_query_bias) 522 | 523 | alpha = tf.tanh(attention_query[:, None, :] + self._attention_keys) 524 | alpha = dense(alpha, 1, kernel=self._alpha_kernel, name='attention') 525 | if self._attention_bias is not None: 526 | alpha += self._attention_bias 527 | alpha = tf.nn.softmax(alpha, axis=1) 528 | 529 | context = tf.multiply(self._attention_values, alpha) 530 | context = tf.reduce_sum(context, axis=1) 531 | 532 | return context 533 | 534 | def call(self, inputs, state): 535 | context = self.attention(inputs, state) 536 | inputs = tf.concat([inputs, context], axis=1) 537 | return super(AttentionGRUCell, self).call(inputs, state) 538 | 539 | def build(self, inputs_shape): 540 | if inputs_shape[-1].value is None: 541 | raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" 542 | % inputs_shape) 543 | 544 | input_depth = inputs_shape[1].value 545 | self._gate_kernel = self.add_variable( 546 | "gates/weights", 547 | shape=[input_depth + 2 * self._num_units, 2 * self._num_units], 548 | initializer=self._kernel_initializer) 549 | self._gate_bias = self.add_variable( 550 | "gates/bias", 551 | shape=[2 * self._num_units], 552 | initializer=( 553 | self._bias_initializer 554 | if self._bias_initializer is not None 555 | else tf.constant_initializer(1.0, dtype=self.dtype))) 556 | self._candidate_kernel = self.add_variable( 557 | "candidate/weights", 558 | shape=[input_depth + 2 * self._num_units, self._num_units], 559 | initializer=self._kernel_initializer) 560 | self._candidate_bias = self.add_variable( 561 | "candidate/bias", 562 | shape=[self._num_units], 563 | initializer=( 564 | self._bias_initializer 565 | if self._bias_initializer is not None 566 | else tf.zeros_initializer(dtype=self.dtype))) 567 | 568 | self._attention_query_kernel = self.add_variable( 569 | "attention_query/weight", 570 | shape=[input_depth + self._num_units, self._num_units], 571 | initializer=self._kernel_initializer) 572 | self._attention_query_bias = self.add_variable( 573 | "attention_query/bias", 574 | shape=[self._num_units], 575 | initializer=( 576 | self._bias_initializer 577 | if self._bias_initializer is not None 578 | else tf.constant_initializer(1.0, dtype=self.dtype))) 579 | self._alpha_kernel = self.add_variable( 580 | 'alpha_kernel', 581 | shape=[1, self._num_units], 582 | initializer=self._kernel_initializer) 583 | self.built = True 584 | 585 | 586 | class IndRNNCell(tf.nn.rnn_cell.RNNCell): 587 | """The independent RNN cell.""" 588 | 589 | def __init__(self, num_units, recurrent_initializer=None, reuse=None, name=None): 590 | super(IndRNNCell, self).__init__(_reuse=reuse, name=name) 591 | 592 | # Inputs must be 2-dimensional. 
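        # (Editorial note) per build() and call() below, this cell computes
        #   h_t = relu(x_t @ W_in + u * h_{t-1} + b)
        # with an element-wise recurrent vector u clipped to
        # [-2**(1/50), 2**(1/50)], so the recurrent factor cannot grow
        # activations by more than 2x over 50 steps.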
593 | self.input_spec = base_layer.InputSpec(ndim=2) 594 | 595 | self._num_units = num_units 596 | self._recurrent_initializer = recurrent_initializer 597 | self._kernel_initializer = None 598 | self._bias_initializer = tf.constant_initializer(0.0, dtype=tf.float32) 599 | 600 | @property 601 | def state_size(self): 602 | return self._num_units 603 | 604 | @property 605 | def output_size(self): 606 | return self._num_units 607 | 608 | def build(self, inputs_shape): 609 | if inputs_shape[1].value is None: 610 | raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" 611 | % inputs_shape) 612 | 613 | input_depth = inputs_shape[1].value 614 | self._recurrent_kernel = self.add_variable( 615 | "recurrent/weights", 616 | shape=[self._num_units], 617 | initializer=self._recurrent_initializer) 618 | epsilon = np.power(2.0, 1.0/50.0) 619 | self._recurrent_kernel = tf.clip_by_value(self._recurrent_kernel, -epsilon, epsilon) 620 | 621 | self._input_kernel = self.add_variable( 622 | "input/weights", 623 | shape=[input_depth, self._num_units], 624 | initializer=self._kernel_initializer) 625 | 626 | self._bias = self.add_variable( 627 | "bias", 628 | shape=[self._num_units], 629 | initializer=self._bias_initializer) 630 | 631 | self.built = True 632 | 633 | def call(self, inputs, state): 634 | inputs = tf.matmul(inputs, self._input_kernel) 635 | state = tf.multiply(state, self._recurrent_kernel) 636 | output = inputs + state 637 | output = tf.nn.bias_add(output, self._bias) 638 | output = tf.nn.relu(output) 639 | return output, output 640 | 641 | 642 | class AttentionIndRNNCell(IndRNNCell): 643 | def __init__(self, 644 | num_units, 645 | attention_memories, 646 | attention_bias=None, 647 | recurrent_initializer=None, 648 | reuse=None, 649 | name=None): 650 | super(AttentionIndRNNCell, self).__init__(num_units, 651 | recurrent_initializer=recurrent_initializer, 652 | reuse=reuse, name=name) 653 | with tf.variable_scope(name, "AttentionIndRNNCell", reuse=reuse): 654 | self._attention_keys = dense(attention_memories, num_units, name='attention_key') 655 | self._attention_values = dense(attention_memories, num_units, name='attention_value') 656 | self._attention_bias = attention_bias 657 | 658 | def attention(self, inputs, state): 659 | attention_query = tf.matmul( 660 | tf.concat([inputs, state], 1), self._attention_query_kernel) 661 | attention_query = tf.nn.bias_add(attention_query, self._attention_query_bias) 662 | 663 | alpha = tf.tanh(attention_query[:, None, :] + self._attention_keys) 664 | alpha = dense(alpha, 1, kernel=self._alpha_kernel, name='attention') 665 | if self._attention_bias is not None: 666 | alpha += self._attention_bias 667 | alpha = tf.nn.softmax(alpha, axis=1) 668 | self._alpha = alpha 669 | 670 | context = tf.multiply(self._attention_values, alpha) 671 | context = tf.reduce_sum(context, axis=1) 672 | 673 | return context 674 | 675 | def build(self, inputs_shape): 676 | if inputs_shape[1].value is None: 677 | raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" 678 | % inputs_shape) 679 | 680 | input_depth = inputs_shape[1].value 681 | self._recurrent_kernel = self.add_variable( 682 | "recurrent/weights", 683 | shape=[self._num_units], 684 | initializer=self._recurrent_initializer) 685 | epsilon = np.power(2.0, 1.0/50.0) 686 | self._recurrent_kernel = tf.clip_by_value(self._recurrent_kernel, -epsilon, epsilon) 687 | 688 | self._input_kernel = self.add_variable( 689 | "input/weights", 690 | shape=[input_depth + self._num_units, self._num_units], 691 
| initializer=self._kernel_initializer) 692 | 693 | self._bias = self.add_variable( 694 | "bias", 695 | shape=[self._num_units], 696 | initializer=self._bias_initializer) 697 | 698 | self._attention_query_kernel = self.add_variable( 699 | "attention_query/weight", 700 | shape=[input_depth + self._num_units, self._num_units], 701 | initializer=self._kernel_initializer) 702 | self._attention_query_bias = self.add_variable( 703 | "attention_query/bias", 704 | shape=[self._num_units], 705 | initializer=self._bias_initializer) 706 | self._alpha_kernel = self.add_variable( 707 | 'alpha_kernel', 708 | shape=[1, self._num_units], 709 | initializer=self._kernel_initializer) 710 | 711 | self.built = True 712 | 713 | def call(self, inputs, state): 714 | context = self.attention(inputs, state) 715 | inputs = tf.concat([inputs, context], axis=1) 716 | return super(AttentionIndRNNCell, self).call(inputs, state) 717 | 718 | def get_attention_weights(self): 719 | return self._alpha 720 | -------------------------------------------------------------------------------- /vocab.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import logging 3 | import os 4 | from argparse import ArgumentParser 5 | from collections import Counter 6 | 7 | import yaml 8 | 9 | from utils import AttrDict 10 | 11 | 12 | def make_vocab(fpath, fname): 13 | """Constructs vocabulary. 14 | 15 | Args: 16 | fpath: A string. Input file path. 17 | fname: A string. Output file name. 18 | 19 | Writes vocabulary line by line to `fname`. 20 | """ 21 | word2cnt = Counter() 22 | for l in codecs.open(fpath, 'r', 'utf-8'): 23 | words = l.split() 24 | word2cnt.update(Counter(words)) 25 | word2cnt.update({"": 10000000000000, 26 | "": 1000000000000, 27 | "": 100000000000, 28 | "": 10000000000}) 29 | with codecs.open(fname, 'w', 'utf-8') as fout: 30 | for word, cnt in word2cnt.most_common(): 31 | fout.write(u"{}\t{}\n".format(word, cnt)) 32 | logging.info('Vocab path: {}\t size: {}'.format(fname, len(word2cnt))) 33 | 34 | 35 | if __name__ == '__main__': 36 | parser = ArgumentParser() 37 | parser.add_argument('-c', '--config', dest='config') 38 | args = parser.parse_args() 39 | # Read config 40 | config = AttrDict(yaml.load(open(args.config))) 41 | logging.basicConfig(level=logging.INFO) 42 | if os.path.exists(config.src_vocab): 43 | logging.info('Source vocab already exists at {}'.format(config.src_vocab)) 44 | else: 45 | make_vocab(config.train.src_path, config.src_vocab) 46 | if os.path.exists(config.dst_vocab): 47 | logging.info('Destination vocab already exists at {}'.format(config.dst_vocab)) 48 | else: 49 | make_vocab(config.train.dst_path, config.dst_vocab) 50 | logging.info("Done") 51 | --------------------------------------------------------------------------------
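For reference, the warmup-then-decay schedule implemented by learning_rate_decay in utils.py can be reproduced outside the TensorFlow graph. A minimal sketch with illustrative values for hidden_units and warmup_steps (in practice both come from the YAML config):

# Minimal sketch of utils.learning_rate_decay (illustrative hyper-parameters).
hidden_units = 512
warmup_steps = 4000.0

def noam_lr(global_step):
    step = float(global_step) + 1.0
    return hidden_units ** -0.5 * min(step * warmup_steps ** -1.5, step ** -0.5)

for s in (0, 1000, 4000, 40000):
    print(s, noam_lr(s))  # rises roughly linearly during warmup, then decays as step ** -0.5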