├── .gitignore ├── LICENSE ├── README.md ├── data.py ├── docs ├── afs_speech_translation │ ├── README.md │ ├── afs_training.png │ └── example.png ├── colactc │ ├── README.md │ ├── colactc.png │ └── mt.png ├── conditional_language_specific_routing │ └── README.md ├── context_aware_st │ ├── README.md │ ├── cast.png │ └── training.png ├── depth_scale_init_and_merged_attention │ ├── README.md │ ├── dsinit.png │ └── grad.png ├── interleaved_bidirectional_transformer │ ├── README.md │ └── overview.png ├── iwslt2021_uoe_submission │ ├── README.md │ └── overview.png ├── l0drop │ ├── README.md │ ├── l0drop-att.png │ ├── l0drop.png │ └── mt_ende.png ├── multilingual_laln_lalt │ ├── README.md │ ├── many-to-many-full-results-per-language.md │ └── many-to-many.xlsx ├── rela_sparse_attention │ ├── README.md │ ├── aer.png │ ├── null.png │ └── rela.png └── usage │ └── README.md ├── evalu.py ├── func.py ├── lrs ├── __init__.py ├── cosinelr.py ├── epochlr.py ├── gnmtplr.py ├── lr.py ├── noamlr.py ├── scorelr.py └── vanillalr.py ├── main.py ├── models ├── __init__.py ├── deepnmt.py ├── model.py ├── rnnsearch.py ├── rnnsearch_deepatt.py ├── transformer.py ├── transformer_aan.py ├── transformer_fixup.py ├── transformer_fuse.py ├── transformer_l0drop.py ├── transformer_rela.py └── transformer_rpr.py ├── modules ├── __init__.py ├── fixup.py ├── initializer.py ├── l0norm.py ├── rela.py └── rpr.py ├── rnns ├── __init__.py ├── atr.py ├── cell.py ├── gru.py ├── lrn.py ├── lstm.py ├── olrn.py ├── rnn.py └── sru.py ├── run.py ├── scripts ├── bleu_over_length.py ├── checkpoint_averaging.py ├── chrF.py ├── evaluate_pos_translation_rate.py ├── multi-bleu-detok.perl ├── multi-bleu.perl └── shuffle_corpus.py ├── search.py ├── utils ├── __init__.py ├── cycle.py ├── dtype.py ├── metric.py ├── parallel.py ├── queuer.py ├── recorder.py ├── saver.py └── util.py └── vocab.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2018, Biao Zhang 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Zero 2 | A neural machine translation system implemented in Python 2 + TensorFlow. 3 | 4 | ## Features 5 | 1. Multi-Process Data Loading/Processing (*Problems Exist*) 6 | 2. Multi-GPU Training/Decoding 7 | 3. Gradient Aggregation 8 | 9 | ## Papers 10 | 11 | We associate each paper below with a readme file link. Please click the link of the paper you are interested in for more details. 12 | 13 | * [Efficient CTC Regularization via Coarse Labels for End-to-End Speech Translation, EACL2023](docs/colactc) 14 | * [Revisiting End-to-End Speech-to-Text Translation From Scratch, ICML2022](https://github.com/bzhangGo/st_from_scratch) 15 | * [Sparse Attention with Linear Units, EMNLP2021](docs/rela_sparse_attention) 16 | * [Edinburgh's End-to-End Multilingual Speech Translation System for IWSLT 2021, IWSLT2021 System submission](docs/iwslt2021_uoe_submission) 17 | * [Beyond Sentence-Level End-to-End Speech Translation: Context Helps, ACL2021](docs/context_aware_st) 18 | * [On Sparsifying Encoder Outputs in Sequence-to-Sequence Models, ACL2021 Findings](docs/l0drop) 19 | * [Share or Not? Learning to Schedule Language-Specific Capacity for Multilingual Translation, ICLR2021](docs/conditional_language_specific_routing) 20 | * [Fast Interleaved Bidirectional Sequence Generation, WMT2020](docs/interleaved_bidirectional_transformer) 21 | * [Adaptive Feature Selection for End-to-End Speech Translation, EMNLP2020 Findings](docs/afs_speech_translation) 22 | * [Improving Massively Multilingual Neural Machine Translation and Zero-Shot Translation, ACL2020](docs/multilingual_laln_lalt) 23 | * [Improving Deep Transformer with Depth-Scaled Initialization and Merged Attention, EMNLP2019](docs/depth_scale_init_and_merged_attention) 24 | 25 | ## Supported Models 26 | * RNNSearch: supports LSTM, GRU, SRU, [ATR, EMNLP2018](https://github.com/bzhangGo/ATR), and [LRN, ACL2019](https://github.com/bzhangGo/lrn) 27 | models. 28 | * Deep attention: [Neural Machine Translation with Deep Attention, TPAMI](https://ieeexplore.ieee.org/document/8493282) 29 | * CAEncoder: the context-aware recurrent encoder; see [the paper, TASLP](https://ieeexplore.ieee.org/document/8031316) 30 | and the original [source code](https://github.com/DeepLearnXMU/CAEncoder-NMT) (in Theano). 31 | * Transformer: [Attention Is All You Need](https://arxiv.org/abs/1706.03762) 32 | * AAN: the [average attention model, ACL2018](https://github.com/bzhangGo/transformer-aan) that accelerates decoding!
33 | * Fixup: [Fixup Initialization: Residual Learning Without Normalization](https://arxiv.org/abs/1901.09321) 34 | * Relative position representation: [Self-Attention with Relative Position Representations](https://arxiv.org/abs/1803.02155) 35 | 36 | ## Requirements 37 | * python2.7 38 | * tensorflow <= 1.13.2 39 | 40 | ## Usage 41 | [How to use this toolkit for machine translation?](docs/usage) 42 | 43 | ## TODO: 44 | 1. organize the parameters and interpretations in config. 45 | 2. reformat and fulfill code comments 46 | 3. simplify and remove unecessary coding 47 | 4. improve rnn models 48 | 49 | ## Citation 50 | 51 | If you use the source code, please consider citing the follow paper: 52 | ``` 53 | @InProceedings{D18-1459, 54 | author = "Zhang, Biao 55 | and Xiong, Deyi 56 | and su, jinsong 57 | and Lin, Qian 58 | and Zhang, Huiji", 59 | title = "Simplifying Neural Machine Translation with Addition-Subtraction Twin-Gated Recurrent Networks", 60 | booktitle = "Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing", 61 | year = "2018", 62 | publisher = "Association for Computational Linguistics", 63 | pages = "4273--4283", 64 | location = "Brussels, Belgium", 65 | url = "http://aclweb.org/anthology/D18-1459" 66 | } 67 | ``` 68 | 69 | If you are interested in the CAEncoder model, please consider citing our TASLP paper: 70 | ``` 71 | @article{Zhang:2017:CRE:3180104.3180106, 72 | author = {Zhang, Biao and Xiong, Deyi and Su, Jinsong and Duan, Hong}, 73 | title = {A Context-Aware Recurrent Encoder for Neural Machine Translation}, 74 | journal = {IEEE/ACM Trans. Audio, Speech and Lang. Proc.}, 75 | issue_date = {December 2017}, 76 | volume = {25}, 77 | number = {12}, 78 | month = dec, 79 | year = {2017}, 80 | issn = {2329-9290}, 81 | pages = {2424--2432}, 82 | numpages = {9}, 83 | url = {https://doi.org/10.1109/TASLP.2017.2751420}, 84 | doi = {10.1109/TASLP.2017.2751420}, 85 | acmid = {3180106}, 86 | publisher = {IEEE Press}, 87 | address = {Piscataway, NJ, USA}, 88 | } 89 | ``` 90 | 91 | ## Reference 92 | When developing this repository, I referred to the following projects: 93 | 94 | * [Nematus](https://github.com/EdinburghNLP/nematus) 95 | * [THUMT](https://github.com/thumt/THUMT) 96 | * [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor) 97 | * [Keras](https://github.com/keras-team/keras) 98 | 99 | ## Contact 100 | For any questions or suggestions, please feel free to contact [Biao Zhang](mailto:biaojiaxing@gmail.com) 101 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import numpy as np 8 | from utils.util import batch_indexer, token_indexer 9 | 10 | 11 | class Dataset(object): 12 | def __init__(self, src_file, tgt_file, 13 | src_vocab, tgt_vocab, max_len=100, 14 | batch_or_token='batch', 15 | data_leak_ratio=0.5): 16 | self.source = src_file 17 | self.target = tgt_file 18 | self.src_vocab = src_vocab 19 | self.tgt_vocab = tgt_vocab 20 | self.max_len = max_len 21 | self.batch_or_token = batch_or_token 22 | self.data_leak_ratio = data_leak_ratio 23 | 24 | self.leak_buffer = [] 25 | 26 | def load_data(self): 27 | with open(self.source, 'r') as src_reader, \ 28 | open(self.target, 'r') as tgt_reader: 29 | while True: 30 | src_line = src_reader.readline() 31 | tgt_line = 
tgt_reader.readline() 32 | 33 | if src_line == "" or tgt_line == "": 34 | break 35 | 36 | src_line = src_line.strip() 37 | tgt_line = tgt_line.strip() 38 | 39 | if src_line == "" or tgt_line == "": 40 | continue 41 | 42 | yield ( 43 | self.src_vocab.to_id(src_line.strip().split()[:self.max_len]), 44 | self.tgt_vocab.to_id(tgt_line.strip().split()[:self.max_len]) 45 | ) 46 | 47 | def to_matrix(self, batch): 48 | batch_size = len(batch) 49 | 50 | src_lens = [len(sample[1]) for sample in batch] 51 | tgt_lens = [len(sample[2]) for sample in batch] 52 | 53 | src_len = min(self.max_len, max(src_lens)) 54 | tgt_len = min(self.max_len, max(tgt_lens)) 55 | 56 | s = np.zeros([batch_size, src_len], dtype=np.int32) 57 | t = np.zeros([batch_size, tgt_len], dtype=np.int32) 58 | x = [] 59 | for eidx, sample in enumerate(batch): 60 | x.append(sample[0]) 61 | src_ids, tgt_ids = sample[1], sample[2] 62 | 63 | s[eidx, :min(src_len, len(src_ids))] = src_ids[:src_len] 64 | t[eidx, :min(tgt_len, len(tgt_ids))] = tgt_ids[:tgt_len] 65 | return x, s, t 66 | 67 | def batcher(self, size, buffer_size=1000, shuffle=True, train=True): 68 | def _handle_buffer(_buffer): 69 | sorted_buffer = sorted( 70 | _buffer, key=lambda xx: max(len(xx[1]), len(xx[2]))) 71 | 72 | if self.batch_or_token == 'batch': 73 | buffer_index = batch_indexer(len(sorted_buffer), size) 74 | else: 75 | buffer_index = token_indexer( 76 | [[len(sample[1]), len(sample[2])] for sample in sorted_buffer], size) 77 | 78 | index_over_index = batch_indexer(len(buffer_index), 1) 79 | if shuffle: np.random.shuffle(index_over_index) 80 | 81 | for ioi in index_over_index: 82 | index = buffer_index[ioi[0]] 83 | batch = [sorted_buffer[ii] for ii in index] 84 | x, s, t = self.to_matrix(batch) 85 | yield { 86 | 'src': s, 87 | 'tgt': t, 88 | 'index': x, 89 | 'raw': batch, 90 | } 91 | 92 | buffer = self.leak_buffer 93 | self.leak_buffer = [] 94 | for i, (src_ids, tgt_ids) in enumerate(self.load_data()): 95 | buffer.append((i, src_ids, tgt_ids)) 96 | if len(buffer) >= buffer_size: 97 | for data in _handle_buffer(buffer): 98 | # check whether the data is tailed 99 | batch_size = len(data['raw']) if self.batch_or_token == 'batch' \ 100 | else max(np.sum(data['tgt'] > 0), np.sum(data['src'] > 0)) 101 | if batch_size < size * self.data_leak_ratio: 102 | self.leak_buffer += data['raw'] 103 | else: 104 | yield data 105 | buffer = self.leak_buffer 106 | self.leak_buffer = [] 107 | 108 | # deal with data in the buffer 109 | if len(buffer) > 0: 110 | for data in _handle_buffer(buffer): 111 | # check whether the data is tailed 112 | batch_size = len(data['raw']) if self.batch_or_token == 'batch' \ 113 | else max(np.sum(data['tgt'] > 0), np.sum(data['src'] > 0)) 114 | if train and batch_size < size * self.data_leak_ratio: 115 | self.leak_buffer += data['raw'] 116 | else: 117 | yield data 118 | -------------------------------------------------------------------------------- /docs/afs_speech_translation/afs_training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/afs_speech_translation/afs_training.png -------------------------------------------------------------------------------- /docs/afs_speech_translation/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/afs_speech_translation/example.png 
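Stepping back to `data.py` above: the sketch below shows how its `Dataset` class is typically driven. The `Dataset` constructor and `batcher` signature come from the code above, but the vocabulary construction is an assumption (the exact class name and constructor exposed by `vocab.py` may differ), and the file names are placeholders.

```python
# Hypothetical usage sketch for the Dataset class defined in data.py above.
# The Vocab construction is an assumption; check vocab.py for the real interface.
from data import Dataset
from vocab import Vocab  # assumed class name

src_vocab = Vocab("vocab.zero.en")   # placeholder vocabulary files
tgt_vocab = Vocab("vocab.zero.de")

dataset = Dataset("train.32k.en.shuf", "train.32k.de.shuf",
                  src_vocab, tgt_vocab,
                  max_len=100, batch_or_token="token")

# batcher() sorts a buffer by length, packs it into batches of roughly 6250 tokens,
# and yields padded numpy matrices under the keys 'src', 'tgt', 'index' and 'raw'.
for batch in dataset.batcher(6250, buffer_size=1000, shuffle=True, train=True):
    src_matrix, tgt_matrix = batch["src"], batch["tgt"]  # [batch, src_len], [batch, tgt_len]
    break
```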
-------------------------------------------------------------------------------- /docs/colactc/README.md: -------------------------------------------------------------------------------- 1 | ## Efficient CTC Regularization via Coarse Labels for End-to-End Speech Translation, EACL 2023 2 | 3 | - [paper link]() 4 | - source code is available at [st_from_scratch](https://github.com/bzhangGo/st_from_scratch) 5 | 6 | ### **Why CTC Regularization?** 7 | 8 | Speech translation (ST) requires the model to capture the semantics of an audio input, 9 | but audio carries much content-irrelevant information, such as emotion and pauses, which increases the 10 | difficulty of translation modeling. 11 | 12 | CTC Regularization offers a mechanism to dynamically align speech representations with their 13 | discrete labels, down-weighting content-irrelevant information and encouraging the learning of speech semantics. In the 14 | literature, many studies have confirmed its effectiveness on ST. 15 | 16 | 17 | ### **Why NOT CTC Regularization?** 18 | 19 | CTC Regularization requires an extra projection layer similar to the word prediction layer, which brings in 20 | many model parameters and slows down training. 21 | 22 | We are particularly interested in improving the efficiency of CTC Regularization. 23 | 24 | 25 | ### **CoLaCTC** 26 | 27 | *Why use genuine labels for CTC regularization?* Particularly considering that the CTC regularization layer will be dropped 28 | after training. 29 | 30 | Following this idea, we propose to use pseudo CTC labels at a coarser grain for CTC regularization, which offers direct 31 | control over the CTC label space and decouples this space from the genuine word vocabulary space. 32 | We only use simple operations to produce CoLaCTC labels, as follows: 33 | 34 | 35 | 36 | 37 | ### **How does it work?** 38 | 39 | 40 | | System | Params | BLEU | Speedup | 41 | |-------------------------------------------------------|--------|------|---------| 42 | | Baseline (no CTC) | 46.1M | 21.8 | 1.39x | 43 | | CTC Regularization + translation labels | 47.9M | 22.7 | 1.00x | 44 | | CTC Regularization + translation-based CoLaCTC labels | 46.2M | 22.7 | 1.39x | 45 | | CTC Regularization + transcript labels | 47.5M | 23.8 | 1.00x | 46 | | CTC Regularization + transcript-based CoLaCTC labels | 46.2M | 24.3 | 1.31x | 47 | 48 | (Quality on MuST-C En-De) 49 | 50 | ### **Why does it work?** 51 | 52 | We still lack a full understanding of why it works so well on ST. One observation is that the CoLaCTC label sequence is 53 | still quite informative. Using it as the source input for machine translation could achieve decent 54 | performance: 55 | 56 | 57 | 58 | 59 | ### Model Training & Evaluation 60 | 61 | We added the implementation to [st_from_scratch](https://github.com/bzhangGo/st_from_scratch). The implementation is 62 | quite simple. 63 | We change 64 | ```python 65 | seq_values.extend(sequence) 66 | ``` 67 | to 68 | ```python 69 | seq_values.extend([v % self.p.cola_ctc_L for v in sequence]) 70 | ``` 71 | as in https://github.com/bzhangGo/st_from_scratch/blob/master/data.py#L152. For example, with `cola_ctc_L=100`, the genuine label ids 3, 103 and 203 all map to the same coarse label 3. 72 | 73 | 74 | ### Citation 75 | 76 | Please consider citing our paper as follows: 77 | >Biao Zhang; Barry Haddow; Rico Sennrich (2023). Efficient CTC Regularization via Coarse Labels for End-to-End Speech Translation. In EACL 2023.
78 | ``` 79 | @inproceedings{zhang-etal-2023-colactc, 80 | title = "Efficient CTC Regularization via Coarse Labels for End-to-End Speech Translation", 81 | author = "Zhang, Biao and 82 | Haddow, Barry and 83 | Sennrich, Rico", 84 | booktitle = "Proceedings of the 17th Conference of the European Chapter of the Association for Computational Linguistics: Main Volume", 85 | month = may, 86 | year = "2023", 87 | publisher = "Association for Computational Linguistics", 88 | } 89 | ``` 90 | -------------------------------------------------------------------------------- /docs/colactc/colactc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/colactc/colactc.png -------------------------------------------------------------------------------- /docs/colactc/mt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/colactc/mt.png -------------------------------------------------------------------------------- /docs/conditional_language_specific_routing/README.md: -------------------------------------------------------------------------------- 1 | ## Share or Not? Learning to Schedule Language-Specific Capacity for Multilingual Translation 2 | 3 | 4 | We host the source code for our ICLR paper here. 5 | 6 | Please go to the [iclr2021_clsr branch](https://github.com/bzhangGo/zero/tree/iclr2021_clsr) for more details. 7 | -------------------------------------------------------------------------------- /docs/context_aware_st/README.md: -------------------------------------------------------------------------------- 1 | ## Beyond Sentence-Level End-to-End Speech Translation: Context Helps 2 | 3 | [**Paper**](https://aclanthology.org/2021.acl-long.200/) | 4 | [**Highlights**](#paper-highlights) | 5 | [**Overview**](#context-aware-st) | 6 | [**Results**](#results) | 7 | [**Training&Eval**](#training-and-evaluation) | 8 | [**Citation**](#citation) 9 | 10 | ### Paper highlights 11 | 12 | Contextual information carries valuable clues for translation. So far, studies on text-based context-aware translation have shown 13 | success, but whether and how context helps end-to-end speech translation (ST) is still under-studied. 14 | 15 | We believe that context would be more helpful to ST, because speech signals often contain more ambiguous expressions apart 16 | from the ones that commonly occur in text. For example, homophones, like flower and flour, are almost indistinguishable without context. 17 | 18 | We study context-aware ST in this project using a simple concatenation-based model. Our main findings are as follows: 19 | * Incorporating context improves overall translation quality (+0.18-2.61 BLEU) and benefits pronoun translation across different language pairs. 20 | * Context also improves the translation of homophones. 21 | * ST models with context suffer less from (artificial) audio segmentation errors. 22 | * Contextual modeling improves translation quality and reduces latency and flicker for simultaneous translation under the re-translation strategy. 23 | 24 | 25 | ### Context Aware ST 26 | 27 | We use AFS to reduce the audio feature length and improve training efficiency. The figure below shows our overall framework: 28 | 29 | 30 | 31 | Note that designing novel context-aware ST architectures is not the focus of this study; we leave that to future work.
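To make the concatenation idea concrete, the sketch below shows one way context-aware training examples could be assembled from consecutive segments of a talk. This is only an illustration of the idea, not the repository's actual preprocessing: the `<BREAK>` separator token, the `context_size` argument and the segment layout are assumptions, and the real pipeline additionally applies AFS to shorten the audio features.

```python
# Illustrative sketch of concatenation-based context construction for ST.
# `segments` is a document-ordered list of (audio_features, target_tokens) pairs;
# the <BREAK> token and context_size are hypothetical names, not fixed by this repo.
def build_context_examples(segments, context_size=1, break_token="<BREAK>"):
    examples = []
    for i, (audio, target) in enumerate(segments):
        context = segments[max(0, i - context_size):i]
        # source side: prepend the audio features of the preceding segment(s)
        src = [frame for ctx_audio, _ in context for frame in ctx_audio] + list(audio)
        # target side: prepend the preceding translation(s), separated by a break token
        tgt = []
        for _, ctx_target in context:
            tgt.extend(list(ctx_target) + [break_token])
        tgt.extend(target)
        examples.append((src, tgt))
    return examples
```

At decoding time, one would typically keep only the tokens after the last break token as the translation of the current segment.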
32 | 33 | 34 | ### Training and Evaluation 35 | 36 | - We implement the model in [context-aware speech_translation branch](https://github.com/bzhangGo/zero/tree/context_aware_speech_translation) 37 | 38 | Our training involves two phrases, as shown below: 39 | 40 | 41 | 42 | Please refer to [our paper](https://aclanthology.org/2021.acl-long.200/) for more details. 43 | 44 | 45 | ### Results 46 | 47 | We mainly experiment with MuST-C corpus and below we show our model outputs (also BLEU) in all languages. 48 | 49 | | Model | De | Es | Fr | It | Nl | Pt | Ro | Ru | 50 | |---------------|---------------------------------------------------------------------------------|---------------------------------------------------------------------------------|---------------------------------------------------------------------------------|---------------------------------------------------------------------------------|---------------------------------------------------------------------------------|---------------------------------------------------------------------------------|---------------------------------------------------------------------------------|---------------------------------------------------------------------------------| 51 | | Baseline | [22.38](http://data.statmt.org/bzhang/acl2021_context_aware_st/baseline/de.txt) | [27.04](http://data.statmt.org/bzhang/acl2021_context_aware_st/baseline/es.txt) | [33.43](http://data.statmt.org/bzhang/acl2021_context_aware_st/baseline/fr.txt) | [23.35](http://data.statmt.org/bzhang/acl2021_context_aware_st/baseline/it.txt) | [25.05](http://data.statmt.org/bzhang/acl2021_context_aware_st/baseline/nl.txt) | [26.55](http://data.statmt.org/bzhang/acl2021_context_aware_st/baseline/pt.txt) | [21.87](http://data.statmt.org/bzhang/acl2021_context_aware_st/baseline/ro.txt) | [14.92](http://data.statmt.org/bzhang/acl2021_context_aware_st/baseline/ru.txt) | 52 | | CA ST w/ SWBD | [22.7](http://data.statmt.org/bzhang/acl2021_context_aware_st/swbd/de.txt) | [27.12](http://data.statmt.org/bzhang/acl2021_context_aware_st/swbd/es.txt) | [34.23](http://data.statmt.org/bzhang/acl2021_context_aware_st/swbd/fr.txt) | [23.46](http://data.statmt.org/bzhang/acl2021_context_aware_st/swbd/it.txt) | [25.84](http://data.statmt.org/bzhang/acl2021_context_aware_st/swbd/nl.txt) | [26.63](http://data.statmt.org/bzhang/acl2021_context_aware_st/swbd/pt.txt) | [23.7](http://data.statmt.org/bzhang/acl2021_context_aware_st/swbd/ro.txt) | [15.53](http://data.statmt.org/bzhang/acl2021_context_aware_st/swbd/ru.txt) | 53 | | CA ST w/ IMED | [22.86](http://data.statmt.org/bzhang/acl2021_context_aware_st/imed/de.txt) | [27.5](http://data.statmt.org/bzhang/acl2021_context_aware_st/imed/es.txt) | [34.28](http://data.statmt.org/bzhang/acl2021_context_aware_st/imed/fr.txt) | [23.53](http://data.statmt.org/bzhang/acl2021_context_aware_st/imed/it.txt) | [26.12](http://data.statmt.org/bzhang/acl2021_context_aware_st/imed/nl.txt) | [27.37](http://data.statmt.org/bzhang/acl2021_context_aware_st/imed/pt.txt) | [24.48](http://data.statmt.org/bzhang/acl2021_context_aware_st/imed/ro.txt) | [15.95](http://data.statmt.org/bzhang/acl2021_context_aware_st/imed/ru.txt) | 54 | 55 | 56 | ### Citation 57 | 58 | Please consider cite our paper as follows: 59 | >Biao Zhang; Ivan Titov; Barry Haddow; Rico Sennrich (2021). Beyond Sentence-Level End-to-End Speech Translation: Context Helps. 
In Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers). 60 | ``` 61 | @inproceedings{zhang-etal-2021-beyond, 62 | title = "Beyond Sentence-Level End-to-End Speech Translation: Context Helps", 63 | author = "Zhang, Biao and 64 | Titov, Ivan and 65 | Haddow, Barry and 66 | Sennrich, Rico", 67 | booktitle = "Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers)", 68 | month = aug, 69 | year = "2021", 70 | address = "Online", 71 | publisher = "Association for Computational Linguistics", 72 | url = "https://aclanthology.org/2021.acl-long.200", 73 | doi = "10.18653/v1/2021.acl-long.200", 74 | pages = "2566--2578", 75 | abstract = "Document-level contextual information has shown benefits to text-based machine translation, but whether and how context helps end-to-end (E2E) speech translation (ST) is still under-studied. We fill this gap through extensive experiments using a simple concatenation-based context-aware ST model, paired with adaptive feature selection on speech encodings for computational efficiency. We investigate several decoding approaches, and introduce in-model ensemble decoding which jointly performs document- and sentence-level translation using the same model. Our results on the MuST-C benchmark with Transformer demonstrate the effectiveness of context to E2E ST. Compared to sentence-level ST, context-aware ST obtains better translation quality (+0.18-2.61 BLEU), improves pronoun and homophone translation, shows better robustness to (artificial) audio segmentation errors, and reduces latency and flicker to deliver higher quality for simultaneous translation.", 76 | } 77 | ``` -------------------------------------------------------------------------------- /docs/context_aware_st/cast.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/context_aware_st/cast.png -------------------------------------------------------------------------------- /docs/context_aware_st/training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/context_aware_st/training.png -------------------------------------------------------------------------------- /docs/depth_scale_init_and_merged_attention/README.md: -------------------------------------------------------------------------------- 1 | ## Improving Deep Transformer with Depth-Scaled Initialization and Merged Attention, EMNLP2019 2 | 3 | - [paper link](https://www.aclweb.org/anthology/D19-1083/) 4 | 5 | This paper focus on improving Deep Transformer. 6 | Our empirical observation suggests that simply stacking more Transformer layers makes training divergent. 7 | Rather than resorting to the pre-norm structure which shifts the layer normalization before modeling blocks, 8 | we analyze the reason why a vanilla deep Transformer suffers from poor convergence. 9 | 10 | 11 | 12 | Our evidence shows that it's because of *gradient vanishing* (shown above) caused by the interaction between residual connection 13 | and layer normalization. 
**In short, the residual connection increases the variance of its output, which decreases the gradient 14 | backpropagated from layer normalization. (empirically)** 15 | 16 | We solve this problem by proposing depth-scaled initialization (DS-Init), which decreases 17 | parameter variance at the initialization stage. DS-Init reduces output variance of residual connections so as to 18 | ease gradient back-propagation through normalization layers. In practice, DS-Init often produces slightly better 19 | translation quality than the pre-norm structure. 20 | 21 | We also care about the computational overhead raised by deep models. To settle this issue, we propose the merged 22 | attention network which combines a simplified average attention model and the encoder-decoder attention model on 23 | the target side. Merged attention model enables the deep Transformer matching the decoding speed of its baseline 24 | with a clear higher BLEU score. 25 | 26 | ### Approach 27 | 28 | To train a deep Transformer model for machine translation, scale your initialization for each layer as follows: 29 | 30 | 31 | 32 | where `\alpha` and `\gamma` are hyperparameters for the uniform distribution. `l` denotes the depth of the layer. 33 | 34 | 35 | ### Model Training 36 | 37 | Train 12-layer Transformer model with the following settings: 38 | >The model class is: `transformer_fuse`, the merged attention is enabled by giving `fuse_mask` into `dot_attention` function. 39 | ``` 40 | python run.py --mode train --parameters=hidden_size=512,embed_size=512,filter_size=2048,\ 41 | initializer="uniform_unit_scaling",initializer_gain=1.,\ 42 | model_name="transformer_fuse",scope_name="transformer_fuse",\ 43 | deep_transformer_init=True,\ 44 | num_encoder_layer=12,\ 45 | num_decoder_layer=12,\ 46 | ``` 47 | 48 | Other details can be found [here](../usage). 49 | 50 | ### Performance and Download 51 | 52 | We offer a range of [pretrained models](http://data.statmt.org/bzhang/emnlp19_deep_transformer/) for further study. 
53 | 54 | 55 | | Task | Model | BLEU | Download | 56 | |---------------|-------------------------------------|-------| -------- | 57 | | WMT14 En-Fr | Base Transformer + 6 Layers | 39.09 | | 58 | | | Base Transformer + Ours + 12 Layers | 40.58 | | 59 | | IWSLT14 De-En | Base Transformer + 6 Layers | 34.41 | | 60 | | | Base Transformer + Ours + 12 Layers | 35.63 | | 61 | | WMT18 En-Fr | Base Transformer + 6 Layers | 15.5 | | 62 | | | Base Transformer + Ours + 12 Layers | 15.8 | | 63 | | WMT18 Zh-En | Base Transformer + 6 Layers | 21.1 | | 64 | | | Base Transformer + Ours + 12 Layers | 22.3 | | 65 | | WMT14 En-De | Base Transformer + 6 Layers | 27.59 | [download](http://data.statmt.org/bzhang/emnlp19_deep_transformer/model/base.tar.gz) | 66 | | | Base Transformer + Ours + 12 Layers | 28.55 | | 67 | | | Big Transformer + 6 Layers | 29.07 | [download](http://data.statmt.org/bzhang/emnlp19_deep_transformer/model/big.tar.gz) | 68 | | | Big Transformer + Ours + 12 Layers | 29.47 | | 69 | | | Base Transformer + Ours + 20 Layers | 28.67 | [download](http://data.statmt.org/bzhang/emnlp19_deep_transformer/model/base+fuse_init20.tar.gz) | 70 | | | Base Transformer + Ours + 30 Layers | 28.86 | [download](http://data.statmt.org/bzhang/emnlp19_deep_transformer/model/base+fuse_init30.tar.gz) | 71 | | | Big Transformer + Ours + 20 Layers | 29.62 | [download](http://data.statmt.org/bzhang/emnlp19_deep_transformer/model/big+fuse_init20.tar.gz) | 72 | 73 | Please go to [pretrained models](http://data.statmt.org/bzhang/emnlp19_deep_transformer/) for more details. 74 | 75 | ### Citation 76 | 77 | Please consider cite our paper as follows: 78 | >Biao Zhang; Ivan Titov; Rico Sennrich (2019). Improving Deep Transformer with Depth-Scaled Initialization and Merged Attention. In Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP). Hong Kong, China, pp. 898-909. 79 | ``` 80 | @inproceedings{zhang-etal-2019-improving-deep, 81 | title = "Improving Deep Transformer with Depth-Scaled Initialization and Merged Attention", 82 | author = "Zhang, Biao and 83 | Titov, Ivan and 84 | Sennrich, Rico", 85 | booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)", 86 | month = nov, 87 | year = "2019", 88 | address = "Hong Kong, China", 89 | publisher = "Association for Computational Linguistics", 90 | url = "https://www.aclweb.org/anthology/D19-1083", 91 | doi = "10.18653/v1/D19-1083", 92 | pages = "898--909", 93 | abstract = "The general trend in NLP is towards increasing model capacity and performance via deeper neural networks. However, simply stacking more layers of the popular Transformer architecture for machine translation results in poor convergence and high computational overhead. Our empirical analysis suggests that convergence is poor due to gradient vanishing caused by the interaction between residual connection and layer normalization. We propose depth-scaled initialization (DS-Init), which decreases parameter variance at the initialization stage, and reduces output variance of residual connections so as to ease gradient back-propagation through normalization layers. 
To address computational cost, we propose a merged attention sublayer (MAtt) which combines a simplified average-based self-attention sublayer and the encoder-decoder attention sublayer on the decoder side. Results on WMT and IWSLT translation tasks with five translation directions show that deep Transformers with DS-Init and MAtt can substantially outperform their base counterpart in terms of BLEU (+1.1 BLEU on average for 12-layer models), while matching the decoding speed of the baseline model thanks to the efficiency improvements of MAtt. Source code for reproduction will be released soon.", 94 | } 95 | ``` 96 | -------------------------------------------------------------------------------- /docs/depth_scale_init_and_merged_attention/dsinit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/depth_scale_init_and_merged_attention/dsinit.png -------------------------------------------------------------------------------- /docs/depth_scale_init_and_merged_attention/grad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/depth_scale_init_and_merged_attention/grad.png -------------------------------------------------------------------------------- /docs/interleaved_bidirectional_transformer/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/interleaved_bidirectional_transformer/overview.png -------------------------------------------------------------------------------- /docs/iwslt2021_uoe_submission/README.md: -------------------------------------------------------------------------------- 1 | ## Edinburgh's End-to-End Multilingual Speech Translation System for IWSLT 2021 2 | 3 | [**Paper**](https://aclanthology.org/2021.iwslt-1.19/) | 4 | [**Highlights**](#paper-highlights) | 5 | [**Results**](#results) | 6 | [**Citation**](#citation) 7 | 8 | ### Paper highlights 9 | 10 | We participated IWSLT21 multilingual speech translation with an integrated end-to-end system. Below shows the overview of our 11 | system: 12 | 13 | 14 | 15 | For more details, please check out [our paper](https://aclanthology.org/2021.iwslt-1.19/). 16 | 17 | 18 | ### Results 19 | 20 | We release our system and outputs at [here](http://data.statmt.org/bzhang/iwslt2021_uoe_system/) to facilitate other researches, particularly for those who are interested in a 21 | reproduction and comparision. 22 | 23 | 24 | ### Citation 25 | 26 | Please consider cite our paper as follows: 27 | >Biao Zhang; Rico Sennrich (2021). Edinburgh's End-to-End Multilingual Speech Translation System for IWSLT 2021. In Proceedings of the 18th International Conference on Spoken Language Translation (IWSLT 2021). 
28 | ``` 29 | @inproceedings{zhang-sennrich-2021-edinburghs, 30 | title = "{E}dinburgh{'}s End-to-End Multilingual Speech Translation System for {IWSLT} 2021", 31 | author = "Zhang, Biao and 32 | Sennrich, Rico", 33 | booktitle = "Proceedings of the 18th International Conference on Spoken Language Translation (IWSLT 2021)", 34 | month = aug, 35 | year = "2021", 36 | address = "Bangkok, Thailand (online)", 37 | publisher = "Association for Computational Linguistics", 38 | url = "https://aclanthology.org/2021.iwslt-1.19", 39 | doi = "10.18653/v1/2021.iwslt-1.19", 40 | pages = "160--168", 41 | abstract = "This paper describes Edinburgh{'}s submissions to the IWSLT2021 multilingual speech translation (ST) task. We aim at improving multilingual translation and zero-shot performance in the constrained setting (without using any extra training data) through methods that encourage transfer learning and larger capacity modeling with advanced neural components. We build our end-to-end multilingual ST model based on Transformer, integrating techniques including adaptive speech feature selection, language-specific modeling, multi-task learning, deep and big Transformer, sparsified linear attention and root mean square layer normalization. We adopt data augmentation using machine translation models for ST which converts the zero-shot problem into a zero-resource one. Experimental results show that these methods deliver substantial improvements, surpassing the official baseline by {\textgreater} 15 average BLEU and outperforming our cascading system by {\textgreater} 2 average BLEU. Our final submission achieves competitive performance (runner up).", 42 | } 43 | ``` -------------------------------------------------------------------------------- /docs/iwslt2021_uoe_submission/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/iwslt2021_uoe_submission/overview.png -------------------------------------------------------------------------------- /docs/l0drop/README.md: -------------------------------------------------------------------------------- 1 | ## On Sparsifying Encoder Outputs in Sequence-to-Sequence Models 2 | 3 | [**Paper**](https://aclanthology.org/2021.findings-acl.255/) | 4 | [**Highlights**](#paper-highlights) | 5 | [**Overview**](#l0drop) | 6 | [**CaseStudy**](#examples) | 7 | [**Training&Eval**](#training-and-evaluation) | 8 | [**Citation**](#citation) 9 | 10 | ### Paper Highlights 11 | 12 | Information is not uniformly distributed in sentences, while 13 | 14 | Standard encoder-decoder models for sequence-to-sequence learning always feed all encoder outputs to the decoder for 15 | generation. However, information in a sequence is not uniformly distributed over tokens. In translation, we have null 16 | alignments where some source tokens are not translated at all; and in summarization, source document often contains many 17 | redundant tokens. The research questions we are interested in: 18 | 19 | * Are encoder outputs compressible? 20 | * Can we identify those uninformative outputs and prune them out automatically? 21 | * Can we obtain higher inference speed with shortened encoding sequence? 22 | 23 | We propose L0Drop to this end, and our main findings are as follows: 24 | 25 | * We confirm that the encoder outputs can be compressed, around 40-70% of them can be dropped without large effects on 26 | the generation quality. 
27 | * The resulting sparsity level differs across word types, the encodings corresponding to function words 28 | (such as determiners, prepositions) are more frequently pruned than those of content words (e.g., verbs and nouns). 29 | * L0Drop can improve decoding efficiency particularly for lengthy source inputs. We achieve a decoding speedup of up to 30 | 1.65x on document summarization tasks and 1.20x on character-based machine translation task. 31 | * Filtering out source encodings with rule-based sparse patterns is also feasible. 32 | 33 | 34 | ### L0Drop 35 | 36 | L0Drop forces model to route information through a subset of the encoder outputs, and the subset is learned automatically. 37 | 38 | 39 | 40 | L0Drop is different from sparse attention, which is comparably shown below. 41 | 42 | 43 | 44 | Note that L0Drop is data-driven and task-agnostic. We applied it to machine translation as well as 45 | document summarization tasks. Results on WMT14 En-De translation tasks are shown below: 46 | 47 | 48 | 49 | Please refer to [our paper](https://aclanthology.org/2021.findings-acl.255/) for more details. 50 | 51 | ### Examples 52 | 53 | Here, we show some examples learned by L0Drop on machine translation tasks (highlighted source words are dropped after encoding): 54 | 55 | | Task | Sample | 56 | |----------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| 57 | | WMT18 Zh-En | Source `这` 一年 `来` `,` 中国 电@@ 商 出 `了` `一份` 怎样 `的` 成绩@@ 单 ?
Reference: what sort of report card did China 's e @-@ commerce industry receive this year ?
Translation: what kind of report card was produced by Chinese telec@@ om dealers during the year ? | 58 | | | Source: 中国 `在` 地区 合作 `中` `发挥` `的` 作用 一贯 `是` 积极 正面 `的` `,` `受到` 地区 国家 高度 认可 `。`
Reference: China has always played an active and positive role in regional cooperation , which is well recognized by regional countries .
Translation: China 's role in regional cooperation has always been positive and highly recognized by the countries of the region . | 59 | | WMT14 En-De | Source: `The` cause `of` `the` b@@ `last` `was` not known `,` he said `.`
Reference: Die Ursache der Explo@@ sion sei nicht bekannt , erklärte er .
Translation: Die Ursache der Explo@@ sion war nicht bekannt , sagte er . | 60 | | | Source: `The` night `was` long , `the` music loud and `the` atmosphere good `,` but `at` some `point` everyone has `to` go home `.`
Reference: Die Nacht war lang , die Musik laut und die Stimmung gut , aber irgendwann geht es nach Hause .
Translation: Die Nacht war lang , die Musik laut und die Atmosphäre gut , aber irgendwann muss jeder nach Hause gehen . | 61 | 62 | 63 | ### Training and Evaluation 64 | 65 | - We implement the model in [transformer_l0drop](../../models/transformer_l0drop.py) and [l0norm](../../modules/l0norm.py) 66 | 67 | #### Training 68 | 69 | It's possible to train Transformer with L0Drop from scratch by setting proper schedulers for `\lambda`, 70 | a hyper-parameter loosely controlling the sparsity rate of L0Drop. Unfortunately, the optimal scheduler is 71 | data&task-dependent. 72 | 73 | We suggest first pre-train a normal Transformer model, and then finetune the Transformer+L0Drop. This could 74 | save a lot of efforts. 75 | 76 | * Step 1. train a normal Transformer model as described [here](../../docs/usage/README.md). Below is 77 | an example on WMT14 En-De for reference: 78 | ``` 79 | data_dir=the preprocessed data diretory 80 | zero=the path of this code base 81 | python $zero/run.py --mode train --parameters=hidden_size=512,embed_size=512,filter_size=2048,\ 82 | dropout=0.1,label_smooth=0.1,attention_dropout=0.1,\ 83 | max_len=256,batch_size=80,eval_batch_size=32,\ 84 | token_size=6250,batch_or_token='token',\ 85 | initializer="uniform_unit_scaling",initializer_gain=1.,\ 86 | model_name="transformer",scope_name="transformer",buffer_size=60000,\ 87 | clip_grad_norm=0.0,\ 88 | num_heads=8,\ 89 | lrate=1.0,\ 90 | process_num=3,\ 91 | num_encoder_layer=6,\ 92 | num_decoder_layer=6,\ 93 | warmup_steps=4000,\ 94 | lrate_strategy="noam",\ 95 | epoches=5000,\ 96 | update_cycle=4,\ 97 | gpus=[0],\ 98 | disp_freq=1,\ 99 | eval_freq=5000,\ 100 | sample_freq=1000,\ 101 | checkpoints=5,\ 102 | max_training_steps=300000,\ 103 | beta1=0.9,\ 104 | beta2=0.98,\ 105 | epsilon=1e-8,\ 106 | random_seed=1234,\ 107 | src_vocab_file="$data_dir/vocab.zero.en",\ 108 | tgt_vocab_file="$data_dir/vocab.zero.de",\ 109 | src_train_file="$data_dir/train.32k.en.shuf",\ 110 | tgt_train_file="$data_dir/train.32k.de.shuf",\ 111 | src_dev_file="$data_dir/dev.32k.en",\ 112 | tgt_dev_file="$data_dir/dev.32k.de",\ 113 | src_test_file="$data_dir/newstest2014.32k.en",\ 114 | tgt_test_file="$data_dir/newstest2014.de",\ 115 | output_dir="train" 116 | ``` 117 | 118 | * Step 2. finetune L0Drop using the following command: 119 | ``` 120 | data_dir=the preprocessed data directory 121 | zero=the path of this code base 122 | python $zero/run.py --mode train --parameters=\ 123 | l0_norm_reg_scalar=0.3,\ 124 | l0_norm_warm_up=False,\ 125 | model_name="transformer_l0drop",scope_name="transformer",\ 126 | pretrained_model="path-to-pretrained-transformer",\ 127 | max_training_steps=320000,\ 128 | src_vocab_file="$data_dir/vocab.zero.en",\ 129 | tgt_vocab_file="$data_dir/vocab.zero.de",\ 130 | src_train_file="$data_dir/train.32k.en.shuf",\ 131 | tgt_train_file="$data_dir/train.32k.de.shuf",\ 132 | src_dev_file="$data_dir/dev.32k.en",\ 133 | tgt_dev_file="$data_dir/dev.32k.de",\ 134 | src_test_file="$data_dir/newstest2014.32k.en",\ 135 | tgt_test_file="$data_dir/newstest2014.de",\ 136 | output_dir="train" 137 | ``` 138 | where `l0_norm_reg_scalar` is the `\lambda`, and `0.2 or 0.3` is a nice hyperparameter in our experiments. 139 | 140 | #### Evaluation 141 | 142 | The evaluation follows the same procedure as the baseline Transformer. 143 | 144 | ### Citation 145 | 146 | Please consider cite our paper as follows: 147 | >Biao Zhang; Ivan Titov; Rico Sennrich (2021). On Sparsifying Encoder Outputs in Sequence-to-Sequence Models. 
Findings of the Association for Computational Linguistics: ACL-IJCNLP 2021. 148 | ``` 149 | @inproceedings{zhang-etal-2021-sparsifying, 150 | title = "On Sparsifying Encoder Outputs in Sequence-to-Sequence Models", 151 | author = "Zhang, Biao and 152 | Titov, Ivan and 153 | Sennrich, Rico", 154 | booktitle = "Findings of the Association for Computational Linguistics: ACL-IJCNLP 2021", 155 | month = aug, 156 | year = "2021", 157 | address = "Online", 158 | publisher = "Association for Computational Linguistics", 159 | url = "https://aclanthology.org/2021.findings-acl.255", 160 | doi = "10.18653/v1/2021.findings-acl.255", 161 | pages = "2888--2900", 162 | } 163 | 164 | ``` 165 | 166 | -------------------------------------------------------------------------------- /docs/l0drop/l0drop-att.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/l0drop/l0drop-att.png -------------------------------------------------------------------------------- /docs/l0drop/l0drop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/l0drop/l0drop.png -------------------------------------------------------------------------------- /docs/l0drop/mt_ende.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/l0drop/mt_ende.png -------------------------------------------------------------------------------- /docs/multilingual_laln_lalt/many-to-many.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/multilingual_laln_lalt/many-to-many.xlsx -------------------------------------------------------------------------------- /docs/rela_sparse_attention/README.md: -------------------------------------------------------------------------------- 1 | ## Sparse Attention with Linear Units, EMNLP2021 2 | 3 | 4 | [**Paper**](https://arxiv.org/abs/2104.07012/) | 5 | [**Highlights**](#paper-highlights) | 6 | [**Training**](#model-training) | 7 | [**Results**](#results) | 8 | [**Citation**](#citation) 9 | 10 | 11 | ### Paper highlights 12 | 13 | Attention contributes substantially to NMT/NLP. It relies on `SoftMax` activation to produce a dense, categorical 14 | distribution as an estimation of the relevance between input query and contexts. 15 | 16 | **But why dense? and Why distribution?** 17 | 18 | 19 | * Low attention score in dense attention doesn't mean low relevance. By contrast, sparse attention often does better. 20 | * Attention scores estimate the relevance, which doesn't necessarily follow distribution albeit such normalization 21 | often stabilize the training. 22 | 23 | 24 | In this work, we propose rectified linear attention (ReLA), which directly uses ReLU rather than 25 | softmax as an activation function for attention scores. ReLU naturally leads to sparse attention, and we 26 | apply [RMSNorm](https://openreview.net/pdf?id=BylmcHHgIB) to attention outputs to stabilize model training. Below 27 | shows the difference between ReLA and the vanilla attention. 
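In code, the contrast boils down to replacing the softmax over attention scores with a ReLU and re-normalizing the attention output with RMSNorm instead. The following is a minimal, self-contained numpy sketch of that idea; the `gain`/`eps` names are illustrative, and the actual implementation lives in [rela.py](../../modules/rela.py).

```python
import numpy as np

def vanilla_attention(q, k, v):
    # q: [m, d], k, v: [n, d]
    scores = q @ k.T / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max(-1, keepdims=True))
    weights /= weights.sum(-1, keepdims=True)          # softmax: dense, normalized weights
    return weights @ v

def rela_attention(q, k, v, gain=1.0, eps=1e-8):
    # ReLU keeps only positively-scored key positions -> sparse, unnormalized weights
    weights = np.maximum(q @ k.T / np.sqrt(q.shape[-1]), 0.0)
    out = weights @ v
    # RMSNorm on the output takes over the stabilizing role of the softmax normalization
    rms = np.sqrt((out ** 2).mean(-1, keepdims=True) + eps)
    return gain * out / rms
```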
28 | 29 | 30 | 31 | We find that: 32 | * ReLA achieves comparable translation performance to the softmax-based attention on five translation tasks, with 33 | similar running efficiency, but faster than other sparse baselines. 34 | * ReLA delivers high sparsity rate, high head diversity, and better accuracy than all baselines with respect to 35 | word alignment. 36 | * We also observe the emergence of attention heads with a high rate of null attention, only activating for certain queries. 37 | 38 | 39 | ### Model Training 40 | 41 | - We implement the model in [transformer_rela](../../models/transformer_rela.py) and [ReLA](../../modules/rela.py) 42 | 43 | The training of `ReLA` follows the baseline. You just need to change the `model_name` to `transformer_rela` as below 44 | (take WMT14 En-De as example): 45 | ``` 46 | data_dir=the preprocessed data diretory 47 | zero=the path of this code base 48 | python $zero/run.py --mode train --parameters=hidden_size=512,embed_size=512,filter_size=2048,\ 49 | dropout=0.1,label_smooth=0.1,attention_dropout=0.1,\ 50 | max_len=256,batch_size=80,eval_batch_size=32,\ 51 | token_size=6250,batch_or_token='token',\ 52 | initializer="uniform_unit_scaling",initializer_gain=1.,\ 53 | model_name="transformer_rela",scope_name="transformer_rela",buffer_size=60000,\ 54 | clip_grad_norm=0.0,\ 55 | num_heads=8,\ 56 | lrate=1.0,\ 57 | process_num=3,\ 58 | num_encoder_layer=6,\ 59 | num_decoder_layer=6,\ 60 | warmup_steps=4000,\ 61 | lrate_strategy="noam",\ 62 | epoches=5000,\ 63 | update_cycle=4,\ 64 | gpus=[0],\ 65 | disp_freq=1,\ 66 | eval_freq=5000,\ 67 | sample_freq=1000,\ 68 | checkpoints=5,\ 69 | max_training_steps=300000,\ 70 | beta1=0.9,\ 71 | beta2=0.98,\ 72 | epsilon=1e-8,\ 73 | random_seed=1234,\ 74 | src_vocab_file="$data_dir/vocab.zero.en",\ 75 | tgt_vocab_file="$data_dir/vocab.zero.de",\ 76 | src_train_file="$data_dir/train.32k.en.shuf",\ 77 | tgt_train_file="$data_dir/train.32k.de.shuf",\ 78 | src_dev_file="$data_dir/dev.32k.en",\ 79 | tgt_dev_file="$data_dir/dev.32k.de",\ 80 | src_test_file="$data_dir/newstest2014.32k.en",\ 81 | tgt_test_file="$data_dir/newstest2014.de",\ 82 | output_dir="train" 83 | ``` 84 | 85 | 86 | ### Results 87 | 88 | * Translation performance (SacreBLEU Scores) on different WMT tasks 89 | 90 | | Model | WMT14 En-Fr | WMT18 En-Fi | WMT18 Zh-En | WMT16 Ro-En | 91 | |:----------:|:-----------:|:-----------:|:-----------:|:-----------:| 92 | | softmax | 37.2 | 15.5 | 21.1 | 32.7 | 93 | | sparsemax | 37.3 | 15.1 | 19.2 | 33.5 | 94 | | 1.5-entmax | 37.9 | 15.5 | 20.8 | 33.2 | 95 | | ReLA | 37.9 | 15.4 | 20.8 | 32.9 | 96 | 97 | * Training and decoding efficiency of ReLA (based on tensorflow 1.13) 98 | 99 | | Model | Params | Train Speedup | Decode Speedup | 100 | |:----------:|:------:|:-------------:|:--------------:| 101 | | softmax | 72.31M | 1.00x | 1.00x | 102 | | sparsemax | 72.31M | 0.26x | 0.54x | 103 | | 1.5-entmax | 72.31M | 0.27x | 0.49x | 104 | | ReLA | 72.34M | 0.93x | 0.98x | 105 | 106 | 107 | * Source-target attention of ReLA aligns better with word alignment 108 | 109 | 110 | 111 | Note solid curves are for best head per layer, while dashed curves are average results over heads. 112 | 113 | * ReLA enables null-attention: attend to nothing 114 | 115 | 116 | 117 | 118 | ### Citation 119 | 120 | Please consider cite our paper as follows: 121 | >Biao Zhang; Ivan Titov; Rico Sennrich (2021). Sparse Attention with Linear Units. In The 2021 Conference on Empirical Methods in Natural Language Processing. 
Punta Cana, Dominican Republic 122 | ``` 123 | @inproceedings{zhang-etal-2021-sparse, 124 | title = "Sparse Attention with Linear Units", 125 | author = "Zhang, Biao and 126 | Titov, Ivan and 127 | Sennrich, Rico", 128 | booktitle = "The 2021 Conference on Empirical Methods in Natural Language Processing", 129 | month = nov, 130 | year = "2021", 131 | address = "Punta Cana, Dominican Republic", 132 | publisher = "Association for Computational Linguistics", 133 | eprint = "2104.07012" 134 | } 135 | ``` 136 | -------------------------------------------------------------------------------- /docs/rela_sparse_attention/aer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/rela_sparse_attention/aer.png -------------------------------------------------------------------------------- /docs/rela_sparse_attention/null.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/rela_sparse_attention/null.png -------------------------------------------------------------------------------- /docs/rela_sparse_attention/rela.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/rela_sparse_attention/rela.png -------------------------------------------------------------------------------- /docs/usage/README.md: -------------------------------------------------------------------------------- 1 | ## Questions 2 | 1. What's the effective batch size for training and decoding? 3 | 4 | * When using `token-based` training (batch_or_token=token), the effective token number equals `number_gpus * update_cycles * token_size`. For example, with 2 GPUs, `update_cycle=4` and `token_size=6250`, each update covers roughly 2 * 4 * 6250 = 50,000 tokens. 5 | * When using `batch-based` training (batch_or_token=batch), the effective batch size equals `number_gpus * update_cycles * batch_size`. 6 | * At the decoding phase, we only use batch-based decoding with a size of `eval_batch_size`. 7 | 8 | 2. What's the difference between `model_name` and `scope_name`? 9 | 10 | The `model_name` specifies which model you want to train. The model name should be a registered model, which is 11 | under the folder `models`. The `scope_name` denotes the tensorflow scope name for the model's weights and variables. 12 | 13 | For example, when you want to train a Transformer model, you should set `model_name=transformer`. But you can use 14 | any valid scope name as you want, such as transformer, nmtmodel, transformer_exp1, etc. 15 | 16 | ## How to use it? 17 | 18 | Below is a rough procedure for WMT14 En-De translation tasks. 19 | 20 | 1. Prepare your training, development and test data.
21 | 22 | For example, you can download the preprocessed WMT14 En-De dataset from [Stanford NMT](https://nlp.stanford.edu/projects/nmt/) 23 | * The training file: [train.en](https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/train.en), 24 | [train.de](https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/train.de) 25 | * The development file: [newstest12.en](https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2012.en), 26 | [newstest12.de](https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2012.de), 27 | [newstest13.en](https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2013.en), 28 | [newstest13.de](https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2013.de) 29 | 30 | Then, concate the `newstest12.en` and `newstest13.en` into `dev.en` using command like `cat newstest12.en newstest13.en > dev.en`. 31 | The same is for German language: `cat newstest12.de newstest13.de > dev.de` 32 | * The test file: [newstest14.en](https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2014.en), 33 | [newstest14.de](https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2014.de) 34 | 35 | 2. Preprocess your dataset. 36 | 37 | Generally, you can process your dataset with a standard pipeline as given in the WMT official site. 38 | Full content can be found at the preprocessed [datasets](http://data.statmt.org/wmt17/translation-task/preprocessed/). 39 | See the `prepare.sh` for more details. 40 | 41 | In our case, for WMT14 En-De translation, however, the dataset has already been pre-processeed. 42 | So, you can do nothing at this stage. 43 | 44 | *For some languages, such as Chinese, you need to perform word segmentation (Chinese-version Tokenize) first. 45 | You can find more information about segmentation [here](https://nlp.stanford.edu/software/segmenter.shtml)* 46 | 47 | 3. Optional but strongly suggested, Perform BPE decoding 48 | 49 | BPE algorithm is the most popular and currently standard way to handle rare words, or OOVs. It iteratively 50 | merges the most frequent patterns until the maximum merging number is reached. It splits rare words into 51 | `sub-words`, such as `Bloom` => `Blo@@ om`. Another benefit of BPE is that you can control the size of vocabulary. 52 | 53 | - download the [subword project](https://github.com/rsennrich/subword-nmt) 54 | - learn the subword model: 55 | ``` 56 | python subword-nmt/learn_joint_bpe_and_vocab.py --input train.en train.de -s 32000 -o bpe32k --write-vocabulary vocab.en vocab.de 57 | ``` 58 | Notice that the 32000 indicates 32k pieces, or you can simply understand it as your vocabulary size. 59 | - Apply the subword model to all your datasets. 60 | - To training data 61 | ``` 62 | python subword-nmt/apply_bpe.py --vocabulary vocab.en --vocabulary-threshold 50 -c bpe32k < train.en > train.32k.en 63 | python subword-nmt/apply_bpe.py --vocabulary vocab.de --vocabulary-threshold 50 -c bpe32k < train.de > train.32k.de 64 | ``` 65 | - To dev data 66 | ``` 67 | python subword-nmt/apply_bpe.py --vocabulary vocab.en --vocabulary-threshold 50 -c bpe32k < dev.en > dev.32k.en 68 | python subword-nmt/apply_bpe.py --vocabulary vocab.de --vocabulary-threshold 50 -c bpe32k < dev.de > dev.32k.de 69 | ``` 70 | Notice that you do not have to apply bpe to the `dev.de`, but we use it in our model. 71 | - To test data 72 | ``` 73 | python subword-nmt/apply_bpe.py --vocabulary vocab.en --vocabulary-threshold 50 -c bpe32k < newstest14.en > newstest14.32k.en 74 | ``` 75 | 76 | 4. 
Extract Vocabulary 77 | 78 | You still need to prepare the vocabulary using our code, because there are some special symbols in our vocabulary. 79 | - download our project: 80 | ```git clone https://github.com/bzhangGo/zero.git``` 81 | - Run the code as follows: 82 | ``` 83 | python zero/vocab.py train.en vocab.en 84 | python zero/vocab.py train.de vocab.de 85 | ``` 86 | Roughly, the vocabulary size would be 32000, more or less. 87 | 88 | 5. Training your model. 89 | 90 | train your model with the following settings: 91 | ``` 92 | data_dir=the preprocessed data directory 93 | python zero/run.py --mode train --parameters=hidden_size=1024,embed_size=512,\ 94 | dropout=0.1,label_smooth=0.1,\ 95 | max_len=80,batch_size=80,eval_batch_size=240,\ 96 | token_size=3000,batch_or_token='token',\ 97 | model_name="rnnsearch",scope_name="rnnsearch",buffer_size=3200,\ 98 | clip_grad_norm=5.0,\ 99 | lrate=5e-4,\ 100 | epoches=10,\ 101 | update_cycle=1,\ 102 | gpus=[3],\ 103 | disp_freq=100,\ 104 | eval_freq=10000,\ 105 | sample_freq=1000,\ 106 | checkpoints=5,\ 107 | caencoder=True,\ 108 | cell='atr',\ 109 | max_training_steps=100000000,\ 110 | nthreads=8,\ 111 | swap_memory=True,\ 112 | layer_norm=True,\ 113 | max_queue_size=100,\ 114 | random_seed=1234,\ 115 | src_vocab_file="$data_dir/vocab.en",\ 116 | tgt_vocab_file="$data_dir/vocab.de",\ 117 | src_train_file="$data_dir/train.32k.en.shuf",\ 118 | tgt_train_file="$data_dir/train.32k.de.shuf",\ 119 | src_dev_file="$data_dir/dev.32k.en",\ 120 | tgt_dev_file="$data_dir/dev.32k.de",\ 121 | src_test_file="",\ 122 | tgt_test_file="",\ 123 | output_dir="train",\ 124 | test_output="" 125 | ``` 126 | Model would be saved into directory `train` 127 | 128 | 6. Testing your model 129 | 130 | - Average your checkpoints which can give you better results. 131 | ``` 132 | python zero/scripts/checkpoint_averaging.py --checkpoints 5 --output avg --path ../train --gpu 0 133 | ``` 134 | - Then test your model with the following code 135 | ``` 136 | data_dir=the preprocessed data directory 137 | python zero/run.py --mode test --parameters=hidden_size=1024,embed_size=512,\ 138 | dropout=0.1,label_smooth=0.1,\ 139 | max_len=80,batch_size=80,eval_batch_size=240,\ 140 | token_size=3000,batch_or_token='token',\ 141 | model_name="rnnsearch",scope_name="rnnsearch",buffer_size=3200,\ 142 | clip_grad_norm=5.0,\ 143 | lrate=5e-4,\ 144 | epoches=10,\ 145 | update_cycle=1,\ 146 | gpus=[3],\ 147 | disp_freq=100,\ 148 | eval_freq=10000,\ 149 | sample_freq=1000,\ 150 | checkpoints=5,\ 151 | caencoder=True,\ 152 | cell='atr',\ 153 | max_training_steps=100000000,\ 154 | nthreads=8,\ 155 | swap_memory=True,\ 156 | layer_norm=True,\ 157 | max_queue_size=100,\ 158 | random_seed=1234,\ 159 | src_vocab_file="$data_dir/vocab.en",\ 160 | tgt_vocab_file="$data_dir/vocab.de",\ 161 | src_train_file="$data_dir/train.32k.en.shuf",\ 162 | tgt_train_file="$data_dir/train.32k.de.shuf",\ 163 | src_dev_file="$data_dir/dev.32k.en",\ 164 | tgt_dev_file="$data_dir/dev.32k.de",\ 165 | src_test_file="$data_dir/newstest14.32k.en",\ 166 | tgt_test_file="$data_dir/newstest14.de",\ 167 | output_dir="avg",\ 168 | test_output="newstest14.trans.bpe" 169 | ``` 170 | The final translation will be dumped into `newstest14.trans.bpe`. 
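    Note that the test command above reuses the training hyper-parameters; compared with the training command in step 5, only a handful of fields change (collected here from the two commands for clarity):
    ```
    --mode test                                     # instead of --mode train
    src_test_file="$data_dir/newstest14.32k.en"     # BPE-segmented test source (empty during training)
    tgt_test_file="$data_dir/newstest14.de"         # raw test reference (empty during training)
    output_dir="avg"                                # points to the averaged checkpoint from the previous step
    test_output="newstest14.trans.bpe"              # file that receives the translations
    ```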
171 | 172 | You need to remove the BPE separators as follows: `sed -r 's/(@@ )|(@@ ?$)//g' < newstest14.trans.bpe > newstest14.trans.txt` 173 | 174 | Then evaluate the BLEU score using [multi-bleu.perl](https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/generic/multi-bleu.perl): 175 | ```perl multi-bleu.perl $data_dir/newstest14.de < newstest14.trans.txt``` 176 | 177 | > Note that the official WMT evaluation clearly states that researchers should no longer use multi-bleu.perl, because it 178 | relies heavily on the tokenization scheme. Tokenization can have a strong influence on the 179 | final BLEU score, particularly when aggressive tokenization is used. At the current stage, however, multi-bleu.perl is still 180 | the most widely used evaluation script. 181 | 182 | 7. Command line or separate configuration file 183 | 184 | If you dislike the long command-line style, you can move the parameters into a 185 | separate `config.py`. For the training example above, the command then becomes: 186 | ``` 187 | python zero/run.py --mode train --config config.py 188 | ``` 189 | where `config.py` has the following structure: 190 | ``` 191 | dict( 192 | hidden_size=1024, 193 | embed_size=512, 194 | dropout=0.1, 195 | label_smooth=0.1, 196 | max_len=80, 197 | batch_size=80, 198 | eval_batch_size=240, 199 | token_size=3000, 200 | batch_or_token='token', 201 | model_name="rnnsearch", 202 | scope_name="rnnsearch", 203 | buffer_size=3200, 204 | clip_grad_norm=5.0, 205 | lrate=5e-4, 206 | epoches=10, 207 | update_cycle=1, 208 | gpus=[3], 209 | disp_freq=100, 210 | eval_freq=10000, 211 | sample_freq=1000, 212 | checkpoints=5, 213 | caencoder=True, 214 | cell='atr', 215 | max_training_steps=100000000, 216 | nthreads=8, 217 | swap_memory=True, 218 | layer_norm=True, 219 | max_queue_size=100, 220 | random_seed=1234, 221 | src_vocab_file="$data_dir/vocab.en", 222 | tgt_vocab_file="$data_dir/vocab.de", 223 | src_train_file="$data_dir/train.32k.en.shuf", 224 | tgt_train_file="$data_dir/train.32k.de.shuf", 225 | src_dev_file="$data_dir/dev.32k.en", 226 | tgt_dev_file="$data_dir/dev.32k.de", 227 | src_test_file="", 228 | tgt_test_file="", 229 | output_dir="train", 230 | test_output="", 231 | ) 232 | ``` 233 | 234 | 235 | And that's it! 236 | -------------------------------------------------------------------------------- /evalu.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import time 8 | import numpy as np 9 | import tensorflow as tf 10 | 11 | from utils import queuer, util, metric 12 | 13 | 14 | def decode_target_token(id_seq, vocab): 15 | """Convert sequence ids into tokens""" 16 | valid_id_seq = [] 17 | for tok_id in id_seq: 18 | if tok_id == vocab.eos() \ 19 | or tok_id == vocab.pad(): 20 | break 21 | valid_id_seq.append(tok_id) 22 | return vocab.to_tokens(valid_id_seq) 23 | 24 | 25 | def decode_hypothesis(seqs, scores, params, mask=None): 26 | """Generate decoded sequence from seqs""" 27 | if mask is None: 28 | mask = [1.]
* len(seqs) 29 | 30 | hypoes = [] 31 | marks = [] 32 | for _seqs, _scores, _m in zip(seqs, scores, mask): 33 | if _m < 1.: continue 34 | 35 | for seq, score in zip(_seqs, _scores): 36 | # Temporarily, Use top-1 decoding 37 | best_seq = seq[0] 38 | best_score = score[0] 39 | 40 | hypo = decode_target_token(best_seq, params.tgt_vocab) 41 | mark = best_score 42 | 43 | hypoes.append(hypo) 44 | marks.append(mark) 45 | 46 | return hypoes, marks 47 | 48 | 49 | def decoding(session, features, out_seqs, out_scores, dataset, params): 50 | """Performing decoding with exising information""" 51 | translations = [] 52 | scores = [] 53 | indices = [] 54 | 55 | eval_queue = queuer.EnQueuer( 56 | dataset.batcher(params.eval_batch_size, 57 | buffer_size=params.buffer_size, 58 | shuffle=False, 59 | train=False), 60 | lambda x: x, 61 | worker_processes_num=params.process_num, 62 | input_queue_size=params.input_queue_size, 63 | output_queue_size=params.output_queue_size, 64 | ) 65 | 66 | def _predict_one_batch(_data_on_gpu): 67 | feed_dicts = {} 68 | 69 | _step_indices = [] 70 | for fidx, shard_data in enumerate(_data_on_gpu): 71 | # define feed_dict 72 | _feed_dict = { 73 | features[fidx]["source"]: shard_data['src'], 74 | } 75 | feed_dicts.update(_feed_dict) 76 | 77 | # collect data indices 78 | _step_indices.extend(shard_data['index']) 79 | 80 | # pick up valid outputs 81 | data_size = len(_data_on_gpu) 82 | valid_out_seqs = out_seqs[:data_size] 83 | valid_out_scores = out_scores[:data_size] 84 | 85 | _decode_seqs, _decode_scores = session.run( 86 | [valid_out_seqs, valid_out_scores], feed_dict=feed_dicts) 87 | 88 | _step_translations, _step_scores = decode_hypothesis( 89 | _decode_seqs, _decode_scores, params 90 | ) 91 | 92 | return _step_translations, _step_scores, _step_indices 93 | 94 | very_begin_time = time.time() 95 | data_on_gpu = [] 96 | for bidx, data in enumerate(eval_queue): 97 | if bidx == 0: 98 | # remove the data reading time 99 | very_begin_time = time.time() 100 | 101 | data_on_gpu.append(data) 102 | # use multiple gpus, and data samples is not enough 103 | if len(params.gpus) > 0 and len(data_on_gpu) < len(params.gpus): 104 | continue 105 | 106 | start_time = time.time() 107 | step_outputs = _predict_one_batch(data_on_gpu) 108 | data_on_gpu = [] 109 | 110 | translations.extend(step_outputs[0]) 111 | scores.extend(step_outputs[1]) 112 | indices.extend(step_outputs[2]) 113 | 114 | tf.logging.info( 115 | "Decoding Batch {} using {:.3f} s, translating {} " 116 | "sentences using {:.3f} s in total".format( 117 | bidx, time.time() - start_time, 118 | len(translations), time.time() - very_begin_time 119 | ) 120 | ) 121 | 122 | if len(data_on_gpu) > 0: 123 | 124 | start_time = time.time() 125 | step_outputs = _predict_one_batch(data_on_gpu) 126 | 127 | translations.extend(step_outputs[0]) 128 | scores.extend(step_outputs[1]) 129 | indices.extend(step_outputs[2]) 130 | 131 | tf.logging.info( 132 | "Decoding Batch {} using {:.3f} s, translating {} " 133 | "sentences using {:.3f} s in total".format( 134 | 'final', time.time() - start_time, 135 | len(translations), time.time() - very_begin_time 136 | ) 137 | ) 138 | 139 | return translations, scores, indices 140 | 141 | 142 | def scoring(session, features, out_scores, dataset, params): 143 | """Performing decoding with exising information""" 144 | scores = [] 145 | indices = [] 146 | 147 | eval_queue = queuer.EnQueuer( 148 | dataset.batcher(params.eval_batch_size, 149 | buffer_size=params.buffer_size, 150 | shuffle=False, 151 | train=False), 152 | 
lambda x: x, 153 | worker_processes_num=params.process_num, 154 | input_queue_size=params.input_queue_size, 155 | output_queue_size=params.output_queue_size, 156 | ) 157 | 158 | total_entropy = 0. 159 | total_tokens = 0. 160 | 161 | def _predict_one_batch(_data_on_gpu): 162 | feed_dicts = {} 163 | 164 | _step_indices = [] 165 | for fidx, shard_data in enumerate(_data_on_gpu): 166 | # define feed_dict 167 | _feed_dict = { 168 | features[fidx]["source"]: shard_data['src'], 169 | features[fidx]["target"]: shard_data['tgt'], 170 | } 171 | feed_dicts.update(_feed_dict) 172 | 173 | # collect data indices 174 | _step_indices.extend(shard_data['index']) 175 | 176 | # pick up valid outputs 177 | data_size = len(_data_on_gpu) 178 | valid_out_scores = out_scores[:data_size] 179 | 180 | _decode_scores = session.run( 181 | valid_out_scores, feed_dict=feed_dicts) 182 | 183 | _batch_entropy = sum([s * float((d > 0).sum()) 184 | for shard_data, shard_scores in zip(_data_on_gpu, _decode_scores) 185 | for d, s in zip(shard_data['tgt'], shard_scores.tolist())]) 186 | _batch_tokens = sum([(shard_data['tgt'] > 0).sum() for shard_data in _data_on_gpu]) 187 | 188 | _decode_scores = [s for _scores in _decode_scores for s in _scores] 189 | 190 | return _decode_scores, _step_indices, _batch_entropy, _batch_tokens 191 | 192 | very_begin_time = time.time() 193 | data_on_gpu = [] 194 | for bidx, data in enumerate(eval_queue): 195 | if bidx == 0: 196 | # remove the data reading time 197 | very_begin_time = time.time() 198 | 199 | data_on_gpu.append(data) 200 | # use multiple gpus, and data samples is not enough 201 | if len(params.gpus) > 0 and len(data_on_gpu) < len(params.gpus): 202 | continue 203 | 204 | start_time = time.time() 205 | step_outputs = _predict_one_batch(data_on_gpu) 206 | data_on_gpu = [] 207 | 208 | scores.extend(step_outputs[0]) 209 | indices.extend(step_outputs[1]) 210 | 211 | total_entropy += step_outputs[2] 212 | total_tokens += step_outputs[3] 213 | 214 | tf.logging.info( 215 | "Decoding Batch {} using {:.3f} s, translating {} " 216 | "sentences using {:.3f} s in total".format( 217 | bidx, time.time() - start_time, 218 | len(scores), time.time() - very_begin_time 219 | ) 220 | ) 221 | 222 | if len(data_on_gpu) > 0: 223 | 224 | start_time = time.time() 225 | step_outputs = _predict_one_batch(data_on_gpu) 226 | 227 | scores.extend(step_outputs[0]) 228 | indices.extend(step_outputs[1]) 229 | 230 | total_entropy += step_outputs[2] 231 | total_tokens += step_outputs[3] 232 | 233 | tf.logging.info( 234 | "Decoding Batch {} using {:.3f} s, translating {} " 235 | "sentences using {:.3f} s in total".format( 236 | 'final', time.time() - start_time, 237 | len(scores), time.time() - very_begin_time 238 | ) 239 | ) 240 | 241 | scores = [data[1] for data in 242 | sorted(zip(indices, scores), key=lambda x: x[0])] 243 | 244 | ppl = np.exp(total_entropy / total_tokens) 245 | 246 | return scores, ppl 247 | 248 | 249 | def eval_metric(trans, target_file, indices=None): 250 | """BLEU Evaluate """ 251 | target_valid_files = util.fetch_valid_ref_files(target_file) 252 | if target_valid_files is None: 253 | return 0.0 254 | 255 | if indices is not None: 256 | trans = [data[1] for data in sorted(zip(indices, trans), key=lambda x: x[0])] 257 | 258 | references = [] 259 | for ref_file in target_valid_files: 260 | cur_refs = tf.gfile.Open(ref_file).readlines() 261 | cur_refs = [line.strip().split() for line in cur_refs] 262 | references.append(cur_refs) 263 | 264 | references = list(zip(*references)) 265 | 266 | return 
metric.bleu(trans, references) 267 | 268 | 269 | def dump_tanslation(tranes, output, indices=None): 270 | """save translation""" 271 | if indices is not None: 272 | tranes = [data[1] for data in 273 | sorted(zip(indices, tranes), key=lambda x: x[0])] 274 | with tf.gfile.Open(output, 'w') as writer: 275 | for hypo in tranes: 276 | if isinstance(hypo, list): 277 | writer.write(' '.join(hypo) + "\n") 278 | else: 279 | writer.write(str(hypo) + "\n") 280 | tf.logging.info("Saving translations into {}".format(output)) 281 | -------------------------------------------------------------------------------- /lrs/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from lrs import vanillalr, noamlr, scorelr, gnmtplr, epochlr, cosinelr 4 | 5 | 6 | def get_lr(params): 7 | 8 | strategy = params.lrate_strategy.lower() 9 | 10 | if strategy == "noam": 11 | return noamlr.NoamDecayLr( 12 | params.lrate, 13 | params.min_lrate, 14 | params.max_lrate, 15 | params.warmup_steps, 16 | params.hidden_size 17 | ) 18 | elif strategy == "gnmt+": 19 | return gnmtplr.GNMTPDecayLr( 20 | params.lrate, 21 | params.min_lrate, 22 | params.max_lrate, 23 | params.warmup_steps, 24 | params.nstable, 25 | params.lrdecay_start, 26 | params.lrdecay_end 27 | ) 28 | elif strategy == "epoch": 29 | return epochlr.EpochDecayLr( 30 | params.lrate, 31 | params.min_lrate, 32 | params.max_lrate, 33 | params.lrate_decay, 34 | ) 35 | elif strategy == "score": 36 | return scorelr.ScoreDecayLr( 37 | params.lrate, 38 | params.min_lrate, 39 | params.max_lrate, 40 | history_scores=[v[1] for v in params.recorder.valid_script_scores], 41 | decay=params.lrate_decay, 42 | patience=params.lrate_patience, 43 | ) 44 | elif strategy == "vanilla": 45 | return vanillalr.VanillaLR( 46 | params.lrate, 47 | params.min_lrate, 48 | params.max_lrate, 49 | ) 50 | elif strategy == "cosine": 51 | return cosinelr.CosineDecayLr( 52 | params.lrate, 53 | params.min_lrate, 54 | params.max_lrate, 55 | params.warmup_steps, 56 | params.lrate_decay, 57 | t_mult=params.cosine_factor, 58 | update_period=params.cosine_period 59 | ) 60 | else: 61 | raise NotImplementedError( 62 | "{} is not supported".format(strategy)) 63 | -------------------------------------------------------------------------------- /lrs/cosinelr.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import math 8 | 9 | from lrs import lr 10 | 11 | 12 | class CosineDecayLr(lr.Lr): 13 | """Decay the learning rate during each training step, follows FairSeq""" 14 | def __init__(self, 15 | init_lr, # initial learning rate => warmup_init_lr 16 | min_lr, # minimum learning rate 17 | max_lr, # maximum learning rate 18 | warmup_steps, # warmup step => warmup_updates 19 | decay, # learning rate shrink factor for annealing 20 | t_mult=1, # factor to grow the length of each period 21 | update_period=5000, # initial number of updates per period 22 | name="cosine_decay_lr" # model name, no use 23 | ): 24 | super(CosineDecayLr, self).__init__(init_lr, min_lr, max_lr, name=name) 25 | 26 | self.warmup_steps = warmup_steps 27 | 28 | self.warmup_init_lr = init_lr 29 | self.warmup_end_lr = max_lr 30 | self.t_mult = t_mult 31 | self.period = update_period 32 | 33 | if self.warmup_steps > 0: 34 | self.lr_step = (self.warmup_end_lr - self.warmup_init_lr) / self.warmup_steps 35 | 
else: 36 | self.lr_step = 1. 37 | 38 | self.decay = decay 39 | 40 | # initial learning rate 41 | self.lrate = init_lr 42 | 43 | def step(self, step): 44 | if step < self.warmup_steps: 45 | self.lrate = self.warmup_init_lr + step * self.lr_step 46 | else: 47 | curr_updates = step - self.warmup_steps 48 | if self.t_mult != 1: 49 | i = math.floor(math.log(1 - curr_updates / self.period * (1 - self.t_mult), self.t_mult)) 50 | t_i = self.t_mult ** i * self.period 51 | t_curr = curr_updates - (1 - self.t_mult ** i) / (1 - self.t_mult) * self.period 52 | else: 53 | i = math.floor(curr_updates / self.period) 54 | t_i = self.period 55 | t_curr = curr_updates - (self.period * i) 56 | 57 | lr_shrink = self.decay ** i 58 | min_lr = self.min_lrate * lr_shrink 59 | max_lr = self.max_lrate * lr_shrink 60 | 61 | self.lrate = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / t_i)) 62 | 63 | return self.lrate 64 | -------------------------------------------------------------------------------- /lrs/epochlr.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | 8 | from lrs import lr 9 | 10 | 11 | class EpochDecayLr(lr.Lr): 12 | """Decay the learning rate after each epoch""" 13 | def __init__(self, 14 | init_lr, 15 | min_lr, # minimum learning rate 16 | max_lr, # maximum learning rate 17 | decay=0.5, # learning rate decay rate 18 | name="epoch_decay_lr" 19 | ): 20 | super(EpochDecayLr, self).__init__(init_lr, min_lr, max_lr, name=name) 21 | 22 | self.decay = decay 23 | 24 | def after_epoch(self, eidx=None): 25 | if eidx is None: 26 | self.lrate = self.init_lrate * self.decay 27 | else: 28 | self.lrate = self.init_lrate * self.decay ** int(eidx) 29 | -------------------------------------------------------------------------------- /lrs/gnmtplr.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import numpy as np 8 | 9 | from lrs import lr 10 | 11 | 12 | class GNMTPDecayLr(lr.Lr): 13 | """Decay the learning rate during each training step, follows GNMT+""" 14 | def __init__(self, 15 | init_lr, # initial learning rate 16 | min_lr, # minimum learning rate 17 | max_lr, # maximum learning rate 18 | warmup_steps, # warmup step 19 | nstable, # number of replica 20 | lrdecay_start, # start of learning rate decay 21 | lrdecay_end, # end of learning rate decay 22 | name="gnmtp_decay_lr" # model name, no use 23 | ): 24 | super(GNMTPDecayLr, self).__init__(init_lr, min_lr, max_lr, name=name) 25 | 26 | self.warmup_steps = warmup_steps 27 | self.nstable = nstable 28 | self.lrdecay_start = lrdecay_start 29 | self.lrdecay_end = lrdecay_end 30 | 31 | if nstable < 1: 32 | raise Exception("Stabled Lrate Value should " 33 | "greater than 0, but is {}".format(nstable)) 34 | 35 | def step(self, step): 36 | t = float(step) 37 | p = float(self.warmup_steps) 38 | n = float(self.nstable) 39 | s = float(self.lrdecay_start) 40 | e = float(self.lrdecay_end) 41 | 42 | decay = np.minimum(1. 
+ t * (n - 1) / (n * p), n) 43 | decay = np.minimum(decay, n * (2 * n) ** ((s - n * t) / (e - s))) 44 | 45 | self.lrate = self.init_lrate * decay 46 | -------------------------------------------------------------------------------- /lrs/lr.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | 8 | # This is an abstract class that deals with 9 | # different learning rate decay strategy 10 | # Generally, we decay the learning rate with GPU computation 11 | # However, in this paper, we simply decay the learning rate 12 | # at CPU level, and feed the decayed lr into GPU for 13 | # optimization 14 | class Lr(object): 15 | def __init__(self, 16 | init_lrate, # initial learning rate 17 | min_lrate, # minimum learning rate 18 | max_lrate, # maximum learning rate 19 | name="lr", # learning rate name, no use 20 | ): 21 | self.name = name 22 | self.init_lrate = init_lrate # just record the init learning rate 23 | self.lrate = init_lrate # active learning rate, change with training 24 | self.min_lrate = min_lrate 25 | self.max_lrate = max_lrate 26 | 27 | assert self.max_lrate > self.min_lrate, "Minimum learning rate " \ 28 | "should less than maximum learning rate" 29 | 30 | # suppose the eidx starts from 1 31 | def before_epoch(self, eidx=None): 32 | pass 33 | 34 | def after_epoch(self, eidx=None): 35 | pass 36 | 37 | def step(self, step): 38 | pass 39 | 40 | def after_eval(self, eval_score): 41 | pass 42 | 43 | def get_lr(self): 44 | """Return the learning rate whenever you want""" 45 | return max(min(self.lrate, self.max_lrate), self.min_lrate) 46 | -------------------------------------------------------------------------------- /lrs/noamlr.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import numpy as np 8 | 9 | from lrs import lr 10 | 11 | 12 | class NoamDecayLr(lr.Lr): 13 | """Decay the learning rate during each training step, follows Transformer""" 14 | def __init__(self, 15 | init_lr, # initial learning rate 16 | min_lr, # minimum learning rate 17 | max_lr, # maximum learning rate 18 | warmup_steps, # warmup step 19 | hidden_size, # model hidden size 20 | name="noam_decay_lr" # model name, no use 21 | ): 22 | super(NoamDecayLr, self).__init__(init_lr, min_lr, max_lr, name=name) 23 | 24 | self.warmup_steps = warmup_steps 25 | self.hidden_size = hidden_size 26 | 27 | def step(self, step): 28 | step = float(step) 29 | warmup_steps = float(self.warmup_steps) 30 | 31 | multiplier = float(self.hidden_size) ** -0.5 32 | decay = multiplier * np.minimum((step + 1) * (warmup_steps ** -1.5), 33 | (step + 1) ** -0.5) 34 | self.lrate = self.init_lrate * decay 35 | -------------------------------------------------------------------------------- /lrs/scorelr.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | 8 | from lrs import lr 9 | 10 | 11 | class ScoreDecayLr(lr.Lr): 12 | """Decay the learning rate after each evaluation""" 13 | def __init__(self, 14 | init_lr, 15 | min_lr, # minimum learning rate 16 | max_lr, # maximum learning rate 17 | 
history_scores=None, # evaluation history metric scores, such as BLEU 18 | decay=0.5, # learning rate decay rate 19 | patience=1, # decay after this number of bad counter 20 | name="score_decay_lr" # model name, no use 21 | ): 22 | super(ScoreDecayLr, self).__init__(init_lr, min_lr, max_lr, name=name) 23 | 24 | self.decay = decay 25 | self.patience = patience 26 | self.bad_counter = 0 27 | self.best_score = -1e9 28 | 29 | if history_scores is not None: 30 | for score in history_scores: 31 | self.after_eval(score[1]) 32 | 33 | def after_eval(self, eval_score): 34 | if eval_score > self.best_score: 35 | self.best_score = eval_score 36 | self.bad_counter = 0 37 | else: 38 | self.bad_counter += 1 39 | if self.bad_counter >= self.patience: 40 | self.lrate = self.lrate * self.decay 41 | 42 | self.bad_counter = 0 43 | -------------------------------------------------------------------------------- /lrs/vanillalr.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | 8 | from lrs import lr 9 | 10 | 11 | class VanillaLR(lr.Lr): 12 | """Very basic learning rate, constant learning rate""" 13 | def __init__(self, 14 | init_lr, # learning rate 15 | min_lr, # minimum learning rate 16 | max_lr, # maximum learning rate 17 | name="vanilla_lr" 18 | ): 19 | super(VanillaLR, self).__init__(init_lr, min_lr, max_lr, name=name) 20 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import tensorflow as tf 8 | from collections import namedtuple 9 | 10 | # global models defined in Zero 11 | _total_models = {} 12 | 13 | 14 | class ModelWrapper(namedtuple("ModelTupleWrapper", 15 | ("train_fn", "score_fn", "infer_fn"))): 16 | pass 17 | 18 | 19 | # you need register your model by your self 20 | def model_register(model_name, train_fn, score_fn, infer_fn): 21 | model_name = model_name.lower() 22 | 23 | if model_name in _total_models: 24 | raise Exception("Conflict Model Name: {}".format(model_name)) 25 | 26 | tf.logging.info("Registering model: {}".format(model_name)) 27 | 28 | _total_models[model_name] = ModelWrapper( 29 | train_fn=train_fn, 30 | score_fn=score_fn, 31 | infer_fn=infer_fn, 32 | ) 33 | 34 | 35 | def get_model(model_name): 36 | model_name = model_name.lower() 37 | 38 | if model_name in _total_models: 39 | return _total_models[model_name] 40 | 41 | raise Exception("No supported model {}".format(model_name)) 42 | -------------------------------------------------------------------------------- /models/rnnsearch.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import copy 8 | import tensorflow as tf 9 | 10 | from func import linear 11 | from models import model 12 | from utils import util, dtype 13 | from rnns import rnn 14 | 15 | 16 | def encoder(source, 
params): 17 | mask = dtype.tf_to_float(tf.cast(source, tf.bool)) 18 | hidden_size = params.hidden_size 19 | 20 | source, mask = util.remove_invalid_seq(source, mask) 21 | 22 | embed_name = "embedding" if params.shared_source_target_embedding \ 23 | else "src_embedding" 24 | src_emb = tf.get_variable(embed_name, 25 | [params.src_vocab.size(), params.embed_size]) 26 | src_bias = tf.get_variable("bias", [params.embed_size]) 27 | 28 | inputs = tf.gather(src_emb, source) 29 | inputs = tf.nn.bias_add(inputs, src_bias) 30 | 31 | inputs = util.valid_apply_dropout(inputs, params.dropout) 32 | 33 | with tf.variable_scope("encoder"): 34 | # forward rnn 35 | with tf.variable_scope('forward'): 36 | outputs = rnn.rnn(params.cell, inputs, hidden_size, mask=mask, 37 | ln=params.layer_norm, sm=params.swap_memory) 38 | output_fw, state_fw = outputs[1] 39 | # backward rnn 40 | with tf.variable_scope('backward'): 41 | if not params.caencoder: 42 | outputs = rnn.rnn(params.cell, tf.reverse(inputs, [1]), 43 | hidden_size, mask=tf.reverse(mask, [1]), 44 | ln=params.layer_norm, sm=params.swap_memory) 45 | output_bw, state_bw = outputs[1] 46 | else: 47 | outputs = rnn.cond_rnn(params.cell, tf.reverse(inputs, [1]), 48 | tf.reverse(output_fw, [1]), hidden_size, 49 | mask=tf.reverse(mask, [1]), 50 | ln=params.layer_norm, 51 | sm=params.swap_memory, 52 | num_heads=params.num_heads, 53 | one2one=True) 54 | output_bw, state_bw = outputs[1] 55 | 56 | output_bw = tf.reverse(output_bw, [1]) 57 | 58 | if not params.caencoder: 59 | source_encodes = tf.concat([output_fw, output_bw], -1) 60 | source_feature = tf.concat([state_fw, state_bw], -1) 61 | else: 62 | source_encodes = output_bw 63 | source_feature = state_bw 64 | 65 | with tf.variable_scope("decoder_initializer"): 66 | decoder_init = rnn.get_cell( 67 | params.cell, hidden_size, ln=params.layer_norm 68 | ).get_init_state(x=source_feature) 69 | decoder_init = tf.tanh(decoder_init) 70 | 71 | return { 72 | "encodes": source_encodes, 73 | "decoder_initializer": decoder_init, 74 | "mask": mask 75 | } 76 | 77 | 78 | def decoder(target, state, params): 79 | mask = dtype.tf_to_float(tf.cast(target, tf.bool)) 80 | hidden_size = params.hidden_size 81 | 82 | is_training = ('decoder' not in state) 83 | 84 | if is_training: 85 | target, mask = util.remove_invalid_seq(target, mask) 86 | 87 | embed_name = "embedding" if params.shared_source_target_embedding \ 88 | else "tgt_embedding" 89 | tgt_emb = tf.get_variable(embed_name, 90 | [params.tgt_vocab.size(), params.embed_size]) 91 | tgt_bias = tf.get_variable("bias", [params.embed_size]) 92 | 93 | inputs = tf.gather(tgt_emb, target) 94 | inputs = tf.nn.bias_add(inputs, tgt_bias) 95 | 96 | # shift 97 | if is_training: 98 | inputs = tf.pad(inputs, [[0, 0], [1, 0], [0, 0]]) 99 | inputs = inputs[:, :-1, :] 100 | else: 101 | inputs = tf.cond(tf.reduce_all(tf.equal(target, params.tgt_vocab.pad())), 102 | lambda: tf.zeros_like(inputs), 103 | lambda: inputs) 104 | mask = tf.ones_like(mask) 105 | 106 | inputs = util.valid_apply_dropout(inputs, params.dropout) 107 | 108 | with tf.variable_scope("decoder"): 109 | init_state = state["decoder_initializer"] 110 | if not is_training: 111 | init_state = state["decoder"]["state"] 112 | returns = rnn.cond_rnn(params.cell, inputs, state["encodes"], hidden_size, 113 | init_state=init_state, mask=mask, 114 | mem_mask=state["mask"], ln=params.layer_norm, 115 | sm=params.swap_memory, one2one=False) 116 | (_, hidden_state), (outputs, _), contexts, attentions = returns 117 | 118 | feature = linear([outputs, 
contexts, inputs], params.embed_size, 119 | ln=params.layer_norm, scope="pre_logits") 120 | if 'dev_decode' in state: 121 | feature = feature[:, -1, :] 122 | 123 | feature = tf.tanh(feature) 124 | feature = util.valid_apply_dropout(feature, params.dropout) 125 | 126 | embed_name = "tgt_embedding" if params.shared_target_softmax_embedding \ 127 | else "softmax_embedding" 128 | embed_name = "embedding" if params.shared_source_target_embedding \ 129 | else embed_name 130 | softmax_emb = tf.get_variable(embed_name, 131 | [params.tgt_vocab.size(), params.embed_size]) 132 | feature = tf.reshape(feature, [-1, params.embed_size]) 133 | logits = tf.matmul(feature, softmax_emb, False, True) 134 | 135 | logits = tf.cast(logits, tf.float32) 136 | 137 | soft_label, normalizer = util.label_smooth( 138 | target, 139 | util.shape_list(logits)[-1], 140 | factor=params.label_smooth) 141 | centropy = tf.nn.softmax_cross_entropy_with_logits_v2( 142 | logits=logits, 143 | labels=soft_label 144 | ) 145 | centropy -= normalizer 146 | centropy = tf.reshape(centropy, tf.shape(target)) 147 | 148 | mask = tf.cast(mask, tf.float32) 149 | per_sample_loss = tf.reduce_sum(centropy * mask, -1) / tf.reduce_sum(mask, -1) 150 | loss = tf.reduce_mean(per_sample_loss) 151 | 152 | # these mask tricks mainly used to deal with zero shapes, such as [0, 1] 153 | loss = tf.cond(tf.equal(tf.shape(target)[0], 0), 154 | lambda: tf.constant(0, dtype=tf.float32), 155 | lambda: loss) 156 | 157 | if not is_training: 158 | state['decoder']['state'] = hidden_state 159 | 160 | return loss, logits, state, per_sample_loss 161 | 162 | 163 | def train_fn(features, params, initializer=None): 164 | with tf.variable_scope(params.scope_name or "model", 165 | initializer=initializer, 166 | reuse=tf.AUTO_REUSE, 167 | dtype=tf.as_dtype(dtype.floatx()), 168 | custom_getter=dtype.float32_variable_storage_getter): 169 | state = encoder(features['source'], params) 170 | loss, logits, state, _ = decoder(features['target'], state, params) 171 | 172 | return { 173 | "loss": loss 174 | } 175 | 176 | 177 | def score_fn(features, params, initializer=None): 178 | params = copy.copy(params) 179 | params = util.closing_dropout(params) 180 | params.label_smooth = 0.0 181 | with tf.variable_scope(params.scope_name or "model", 182 | initializer=initializer, 183 | reuse=tf.AUTO_REUSE, 184 | dtype=tf.as_dtype(dtype.floatx()), 185 | custom_getter=dtype.float32_variable_storage_getter): 186 | state = encoder(features['source'], params) 187 | _, _, _, scores = decoder(features['target'], state, params) 188 | 189 | return { 190 | "score": scores 191 | } 192 | 193 | 194 | def infer_fn(params): 195 | params = copy.copy(params) 196 | params = util.closing_dropout(params) 197 | 198 | def encoding_fn(source): 199 | with tf.variable_scope(params.scope_name or "model", 200 | reuse=tf.AUTO_REUSE, 201 | dtype=tf.as_dtype(dtype.floatx()), 202 | custom_getter=dtype.float32_variable_storage_getter): 203 | state = encoder(source, params) 204 | state["decoder"] = { 205 | "state": state["decoder_initializer"] 206 | } 207 | return state 208 | 209 | def decoding_fn(target, state, time): 210 | with tf.variable_scope(params.scope_name or "model", 211 | reuse=tf.AUTO_REUSE, 212 | dtype=tf.as_dtype(dtype.floatx()), 213 | custom_getter=dtype.float32_variable_storage_getter): 214 | if params.search_mode == "cache": 215 | step_loss, step_logits, step_state, _ = decoder( 216 | target, state, params) 217 | else: 218 | estate = encoder(state, params) 219 | estate['dev_decode'] = True 220 | _, 
step_logits, _, _ = decoder(target, estate, params) 221 | step_state = state 222 | 223 | return step_logits, step_state 224 | 225 | return encoding_fn, decoding_fn 226 | 227 | 228 | # register the model, with a unique name 229 | model.model_register("rnnsearch", train_fn, score_fn, infer_fn) 230 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | -------------------------------------------------------------------------------- /modules/fixup.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import math 8 | import tensorflow as tf 9 | 10 | import func 11 | from utils import util, dtype 12 | from modules import rpr, initializer 13 | 14 | 15 | def shift_layer(x, scope="shift"): 16 | with tf.variable_scope(scope or "shift"): 17 | offset = tf.get_variable("offset", [1], initializer=tf.zeros_initializer()) 18 | return x - offset 19 | 20 | 21 | def scale_layer(x, init=1., scope="scale"): 22 | with tf.variable_scope(scope or "scale"): 23 | scale = tf.get_variable( 24 | "scale", [1], 25 | initializer=initializer.scale_initializer(init, tf.ones_initializer())) 26 | return x * scale 27 | 28 | 29 | def ffn_layer(x, d, d_o, dropout=None, scope=None, numblocks=None): 30 | """ 31 | FFN layer in Transformer 32 | :param numblocks: size of 'L' in fixup paper 33 | :param scope: 34 | """ 35 | with tf.variable_scope(scope or "ffn_layer", 36 | dtype=tf.as_dtype(dtype.floatx())) as scope: 37 | assert numblocks is not None, 'Fixup requires the total model depth L' 38 | 39 | in_initializer = initializer.scale_initializer( 40 | math.pow(numblocks, -1. 
/ 2.), scope.initializer) 41 | 42 | x = shift_layer(x) 43 | hidden = func.linear(x, d, scope="enlarge", 44 | weight_initializer=in_initializer, bias=False) 45 | hidden = shift_layer(hidden) 46 | hidden = tf.nn.relu(hidden) 47 | 48 | hidden = util.valid_apply_dropout(hidden, dropout) 49 | 50 | hidden = shift_layer(hidden) 51 | output = func.linear(hidden, d_o, scope="output", bias=False, 52 | weight_initializer=tf.zeros_initializer()) 53 | output = scale_layer(output) 54 | 55 | return output 56 | 57 | 58 | def dot_attention(query, memory, mem_mask, hidden_size, 59 | ln=False, num_heads=1, cache=None, dropout=None, 60 | use_relative_pos=False, max_relative_position=16, 61 | out_map=True, scope=None, fuse_mask=None, 62 | decode_step=None, numblocks=None): 63 | """ 64 | dotted attention model 65 | :param query: [batch_size, qey_len, dim] 66 | :param memory: [batch_size, seq_len, mem_dim] or None 67 | :param mem_mask: [batch_size, seq_len] 68 | :param hidden_size: attention space dimension 69 | :param ln: whether use layer normalization 70 | :param num_heads: attention head number 71 | :param dropout: attention dropout, default disable 72 | :param out_map: output additional mapping 73 | :param cache: cache-based decoding 74 | :param fuse_mask: aan mask during training, and timestep for testing 75 | :param max_relative_position: maximum position considered for relative embedding 76 | :param use_relative_pos: whether use relative position information 77 | :param decode_step: the time step of current decoding, 0-based 78 | :param numblocks: size of 'L' in fixup paper 79 | :param scope: 80 | :return: a value matrix, [batch_size, qey_len, mem_dim] 81 | """ 82 | with tf.variable_scope(scope or "dot_attention", reuse=tf.AUTO_REUSE, 83 | dtype=tf.as_dtype(dtype.floatx())) as scope: 84 | if fuse_mask: 85 | assert memory is not None, 'Fuse mechanism only applied with cross-attention' 86 | if cache and use_relative_pos: 87 | assert decode_step is not None, 'Decode Step must provide when use relative position encoding' 88 | 89 | assert numblocks is not None, 'Fixup requires the total model depth L' 90 | 91 | scale_base = 6. if fuse_mask is None else 8. 92 | in_initializer = initializer.scale_initializer( 93 | math.pow(numblocks, -1. 
/ scale_base), scope.initializer) 94 | 95 | if memory is None: 96 | # suppose self-attention from queries alone 97 | h = func.linear(query, hidden_size * 3, ln=ln, scope="qkv_map", 98 | weight_initializer=in_initializer, bias=False) 99 | q, k, v = tf.split(h, 3, -1) 100 | 101 | if cache is not None: 102 | k = tf.concat([cache['k'], k], axis=1) 103 | v = tf.concat([cache['v'], v], axis=1) 104 | cache = { 105 | 'k': k, 106 | 'v': v, 107 | } 108 | else: 109 | q = func.linear(query, hidden_size, ln=ln, scope="q_map", 110 | weight_initializer=in_initializer, bias=False) 111 | if cache is not None and ('mk' in cache and 'mv' in cache): 112 | k, v = cache['mk'], cache['mv'] 113 | else: 114 | k = func.linear(memory, hidden_size, ln=ln, scope="k_map", 115 | weight_initializer=in_initializer, bias=False) 116 | v = func.linear(memory, hidden_size, ln=ln, scope="v_map", 117 | weight_initializer=in_initializer, bias=False) 118 | 119 | if cache is not None: 120 | cache['mk'] = k 121 | cache['mv'] = v 122 | 123 | q = func.split_heads(q, num_heads) 124 | k = func.split_heads(k, num_heads) 125 | v = func.split_heads(v, num_heads) 126 | 127 | q *= (hidden_size // num_heads) ** (-0.5) 128 | 129 | q_shp = util.shape_list(q) 130 | k_shp = util.shape_list(k) 131 | v_shp = util.shape_list(v) 132 | 133 | q_len = q_shp[2] if decode_step is None else decode_step + 1 134 | r_lst = None if decode_step is None else 1 135 | 136 | # q * k => attention weights 137 | if use_relative_pos: 138 | r = rpr.get_relative_positions_embeddings( 139 | q_len, k_shp[2], k_shp[3], 140 | max_relative_position, name="rpr_keys", last=r_lst) 141 | logits = rpr.relative_attention_inner(q, k, r, transpose=True) 142 | else: 143 | logits = tf.matmul(q, k, transpose_b=True) 144 | 145 | if mem_mask is not None: 146 | logits += mem_mask 147 | 148 | weights = tf.nn.softmax(logits) 149 | 150 | dweights = util.valid_apply_dropout(weights, dropout) 151 | 152 | # weights * v => attention vectors 153 | if use_relative_pos: 154 | r = rpr.get_relative_positions_embeddings( 155 | q_len, k_shp[2], v_shp[3], 156 | max_relative_position, name="rpr_values", last=r_lst) 157 | o = rpr.relative_attention_inner(dweights, v, r, transpose=False) 158 | else: 159 | o = tf.matmul(dweights, v) 160 | 161 | o = func.combine_heads(o) 162 | 163 | if fuse_mask is not None: 164 | # This is for AAN, the important part is sharing v_map 165 | v_q = func.linear(query, hidden_size, ln=ln, scope="v_map", 166 | weight_initializer=in_initializer, bias=False) 167 | 168 | if cache is not None and 'aan' in cache: 169 | aan_o = (v_q + cache['aan']) / dtype.tf_to_float(fuse_mask + 1) 170 | else: 171 | # Simplified Average Attention Network 172 | aan_o = tf.matmul(fuse_mask, v_q) 173 | 174 | if cache is not None: 175 | if 'aan' not in cache: 176 | cache['aan'] = v_q 177 | else: 178 | cache['aan'] = v_q + cache['aan'] 179 | 180 | # Directly sum both self-attention and cross attention 181 | o = o + aan_o 182 | 183 | if out_map: 184 | o = func.linear(o, hidden_size, ln=ln, scope="o_map", 185 | weight_initializer=tf.zeros_initializer(), bias=False) 186 | 187 | results = { 188 | 'weights': weights, 189 | 'output': o, 190 | 'cache': cache 191 | } 192 | 193 | return results 194 | -------------------------------------------------------------------------------- /modules/initializer.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import 
print_function 6 | 7 | import tensorflow as tf 8 | from utils import dtype 9 | 10 | 11 | def get_initializer(initializer, initializer_gain): 12 | tfdtype = tf.as_dtype(dtype.floatx()) 13 | 14 | if initializer == "uniform": 15 | max_val = initializer_gain 16 | return tf.random_uniform_initializer(-max_val, max_val, dtype=tfdtype) 17 | elif initializer == "normal": 18 | return tf.random_normal_initializer(0.0, initializer_gain, dtype=tfdtype) 19 | elif initializer == "normal_unit_scaling": 20 | return tf.variance_scaling_initializer(initializer_gain, 21 | mode="fan_avg", 22 | distribution="normal", 23 | dtype=tfdtype) 24 | elif initializer == "uniform_unit_scaling": 25 | return tf.variance_scaling_initializer(initializer_gain, 26 | mode="fan_avg", 27 | distribution="uniform", 28 | dtype=tfdtype) 29 | else: 30 | tf.logging.warn("Unrecognized initializer: %s" % initializer) 31 | tf.logging.warn("Return to default initializer: glorot_uniform_initializer") 32 | return tf.glorot_uniform_initializer(dtype=tfdtype) 33 | 34 | 35 | def scale_initializer(scale, initializer): 36 | """Rescale the value given by initializer""" 37 | tfdtype = tf.as_dtype(dtype.floatx()) 38 | 39 | def _initializer(shape, dtype=tfdtype, partition_info=None): 40 | value = initializer(shape, dtype=dtype, partition_info=partition_info) 41 | value *= scale 42 | 43 | return value 44 | 45 | return _initializer 46 | -------------------------------------------------------------------------------- /modules/l0norm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Defines common utilities for l0-regularization layers.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | # Small constant value to add when taking logs or sqrts to avoid NaNs 24 | EPSILON = 1e-8 25 | 26 | # The default hard-concrete distribution parameters 27 | BETA = 2.0 / 3.0 28 | GAMMA = -0.1 29 | ZETA = 1.1 30 | 31 | 32 | def hard_concrete_sample( 33 | log_alpha, 34 | beta=BETA, 35 | gamma=GAMMA, 36 | zeta=ZETA, 37 | eps=EPSILON): 38 | """Sample values from the hard concrete distribution. 39 | 40 | The hard concrete distribution is described in 41 | https://arxiv.org/abs/1712.01312. 42 | 43 | Args: 44 | log_alpha: The log alpha parameters that control the "location" of the 45 | distribution. 46 | beta: The beta parameter, which controls the "temperature" of 47 | the distribution. Defaults to 2/3 from the above paper. 48 | gamma: The gamma parameter, which controls the lower bound of the 49 | stretched distribution. Defaults to -0.1 from the above paper. 50 | zeta: The zeta parameters, which controls the upper bound of the 51 | stretched distribution. Defaults to 1.1 from the above paper. 
52 | eps: A small constant value to add to logs and sqrts to avoid NaNs. 53 | 54 | Returns: 55 | A tf.Tensor representing the output of the sampling operation. 56 | """ 57 | random_noise = tf.random_uniform( 58 | tf.shape(log_alpha), 59 | minval=0.0, 60 | maxval=1.0) 61 | 62 | # NOTE: We add a small constant value to the noise before taking the 63 | # log to avoid NaNs if a noise value is exactly zero. We sample values 64 | # in the range [0, 1), so the right log is not at risk of NaNs. 65 | gate_inputs = tf.log(random_noise + eps) - tf.log(1.0 - random_noise) 66 | gate_inputs = tf.sigmoid((gate_inputs + log_alpha) / beta) 67 | stretched_values = gate_inputs * (zeta - gamma) + gamma 68 | 69 | return tf.clip_by_value( 70 | stretched_values, 71 | clip_value_max=1.0, 72 | clip_value_min=0.0) 73 | 74 | 75 | def hard_concrete_mean(log_alpha, gamma=GAMMA, zeta=ZETA): 76 | """Calculate the mean of the hard concrete distribution. 77 | 78 | The hard concrete distribution is described in 79 | https://arxiv.org/abs/1712.01312. 80 | 81 | Args: 82 | log_alpha: The log alpha parameters that control the "location" of the 83 | distribution. 84 | gamma: The gamma parameter, which controls the lower bound of the 85 | stretched distribution. Defaults to -0.1 from the above paper. 86 | zeta: The zeta parameters, which controls the upper bound of the 87 | stretched distribution. Defaults to 1.1 from the above paper. 88 | 89 | Returns: 90 | A tf.Tensor representing the calculated means. 91 | """ 92 | stretched_values = tf.sigmoid(log_alpha) * (zeta - gamma) + gamma 93 | return tf.clip_by_value( 94 | stretched_values, 95 | clip_value_max=1.0, 96 | clip_value_min=0.0) 97 | 98 | 99 | def l0_norm( 100 | log_alpha, 101 | beta=BETA, 102 | gamma=GAMMA, 103 | zeta=ZETA): 104 | """Calculate the l0-regularization contribution to the loss. 105 | Args: 106 | log_alpha: Tensor of the log alpha parameters for the hard concrete 107 | distribution. 108 | beta: The beta parameter, which controls the "temperature" of 109 | the distribution. Defaults to 2/3 from the above paper. 110 | gamma: The gamma parameter, which controls the lower bound of the 111 | stretched distribution. Defaults to -0.1 from the above paper. 112 | zeta: The zeta parameters, which controls the upper bound of the 113 | stretched distribution. Defaults to 1.1 from the above paper. 114 | Returns: 115 | Scalar tensor containing the unweighted l0-regularization term contribution 116 | to the loss. 
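    Per weight, this value equals the probability that the corresponding hard-concrete
    gate is non-zero, i.e. sigmoid(log_alpha - beta * log(-gamma / zeta)), following
    the hard concrete distribution referenced above.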
117 | """ 118 | # Value of the CDF of the hard-concrete distribution evaluated at 0 119 | reg_per_weight = tf.sigmoid(log_alpha - beta * tf.log(-gamma / zeta)) 120 | return reg_per_weight 121 | 122 | 123 | def var_train( 124 | weight_parameters, 125 | beta=BETA, 126 | gamma=GAMMA, 127 | zeta=ZETA, 128 | eps=EPSILON): 129 | """Model training, sampling hard concrete variables""" 130 | theta, log_alpha = weight_parameters 131 | 132 | # Sample the z values from the hard-concrete distribution 133 | weight_noise = hard_concrete_sample( 134 | log_alpha, 135 | beta, 136 | gamma, 137 | zeta, 138 | eps) 139 | weights = theta * weight_noise 140 | 141 | return weights, weight_noise 142 | 143 | 144 | def l0_regularization_loss(l0_norm_loss, 145 | reg_scalar=1.0, 146 | start_reg_ramp_up=0, 147 | end_reg_ramp_up=1000, 148 | warm_up=True): 149 | """Calculate the l0-norm weight for this iteration""" 150 | step = tf.train.get_or_create_global_step() 151 | current_step_reg = tf.maximum( 152 | 0.0, 153 | tf.cast(step - start_reg_ramp_up, tf.float32)) 154 | 155 | fraction_ramp_up_completed = tf.minimum( 156 | current_step_reg / (end_reg_ramp_up - start_reg_ramp_up), 1.0) 157 | 158 | if warm_up: 159 | # regularizer intensifies over the course of ramp-up 160 | reg_scalar = fraction_ramp_up_completed * reg_scalar 161 | 162 | l0_norm_loss = reg_scalar * l0_norm_loss 163 | return l0_norm_loss 164 | 165 | 166 | def var_eval( 167 | weight_parameters, 168 | gamma=GAMMA, 169 | zeta=ZETA): 170 | """Model evaluation, obtain mean value""" 171 | theta, log_alpha = weight_parameters 172 | 173 | # Use the mean of the learned hard-concrete distribution as the 174 | # deterministic weight noise at evaluation time 175 | weight_noise = hard_concrete_mean(log_alpha, gamma, zeta) 176 | weights = theta * weight_noise 177 | return weights, weight_noise 178 | -------------------------------------------------------------------------------- /modules/rela.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import tensorflow as tf 8 | 9 | import func 10 | from utils import util, dtype 11 | 12 | 13 | def dot_attention(query, memory, mem_mask, hidden_size, 14 | ln=False, num_heads=1, cache=None, dropout=None, 15 | out_map=True, scope=None): 16 | """ 17 | dotted attention model 18 | :param query: [batch_size, qey_len, dim] 19 | :param memory: [batch_size, seq_len, mem_dim] or None 20 | :param mem_mask: [batch_size, seq_len] 21 | :param hidden_size: attention space dimension 22 | :param ln: whether use layer normalization 23 | :param num_heads: attention head number 24 | :param dropout: attention dropout, default disable 25 | :param out_map: output additional mapping 26 | :param cache: cache-based decoding 27 | :param scope: 28 | :return: a value matrix, [batch_size, qey_len, mem_dim] 29 | """ 30 | with tf.variable_scope(scope or "dot_attention", reuse=tf.AUTO_REUSE, 31 | dtype=tf.as_dtype(dtype.floatx())): 32 | if memory is None: 33 | # suppose self-attention from queries alone 34 | h = func.linear(query, hidden_size * 3, ln=ln, scope="qkv_map") 35 | q, k, v = tf.split(h, 3, -1) 36 | 37 | if cache is not None: 38 | k = tf.concat([cache['k'], k], axis=1) 39 | v = tf.concat([cache['v'], v], axis=1) 40 | cache = { 41 | 'k': k, 42 | 'v': v, 43 | } 44 | else: 45 | q = func.linear(query, hidden_size, ln=ln, scope="q_map") 46 | if cache is not None and ('mk' in 
cache and 'mv' in cache): 47 | k, v = cache['mk'], cache['mv'] 48 | else: 49 | k = func.linear(memory, hidden_size, ln=ln, scope="k_map") 50 | v = func.linear(memory, hidden_size, ln=ln, scope="v_map") 51 | 52 | if cache is not None: 53 | cache['mk'] = k 54 | cache['mv'] = v 55 | 56 | q = func.split_heads(q, num_heads) 57 | k = func.split_heads(k, num_heads) 58 | v = func.split_heads(v, num_heads) 59 | 60 | q *= (hidden_size // num_heads) ** (-0.5) 61 | 62 | # q * k => attention weights 63 | logits = tf.matmul(q, k, transpose_b=True) 64 | 65 | # convert the mask to 0-1 form and multiply to logits 66 | if mem_mask is not None: 67 | zero_one_mask = tf.to_float(tf.equal(mem_mask, 0.0)) 68 | logits *= zero_one_mask 69 | 70 | # replace softmax with relu 71 | # weights = tf.nn.softmax(logits) 72 | weights = tf.nn.relu(logits) 73 | 74 | dweights = util.valid_apply_dropout(weights, dropout) 75 | 76 | # weights * v => attention vectors 77 | o = tf.matmul(dweights, v) 78 | o = func.combine_heads(o) 79 | 80 | # perform RMSNorm to stabilize running 81 | o = gated_rms_norm(o, scope="post") 82 | 83 | if out_map: 84 | o = func.linear(o, hidden_size, ln=ln, scope="o_map") 85 | 86 | results = { 87 | 'weights': weights, 88 | 'output': o, 89 | 'cache': cache 90 | } 91 | 92 | return results 93 | 94 | 95 | def gated_rms_norm(x, eps=None, scope=None): 96 | """RMS-based Layer normalization layer""" 97 | if eps is None: 98 | eps = dtype.epsilon() 99 | with tf.variable_scope(scope or "rms_norm", 100 | dtype=tf.as_dtype(dtype.floatx())): 101 | layer_size = util.shape_list(x)[-1] 102 | 103 | scale = tf.get_variable("scale", [layer_size], initializer=tf.ones_initializer()) 104 | gate = tf.get_variable("gate", [layer_size], initializer=None) 105 | 106 | ms = tf.reduce_mean(x ** 2, -1, keep_dims=True) 107 | 108 | # adding gating here which slightly improves quality 109 | return scale * x * tf.rsqrt(ms + eps) * tf.nn.sigmoid(gate * x) 110 | -------------------------------------------------------------------------------- /modules/rpr.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import tensorflow as tf 8 | 9 | 10 | def relative_attention_inner(x, y, z=None, transpose=False): 11 | """Relative position-aware dot-product attention inner calculation. 12 | This batches matrix multiply calculations to avoid unnecessary broadcasting. 13 | 14 | Args: 15 | x: Tensor with shape [batch_size, heads, length, length or depth]. 16 | y: Tensor with shape [batch_size, heads, length, depth]. 17 | z: Tensor with shape [length, length, depth]. 18 | transpose: Whether to transpose inner matrices of y and z. Should be true if 19 | last dimension of x is depth, not length. 20 | 21 | Returns: 22 | A Tensor with shape [batch_size, heads, length, length or depth]. 
23 | """ 24 | batch_size = tf.shape(x)[0] 25 | heads = x.get_shape().as_list()[1] 26 | length = tf.shape(x)[2] 27 | 28 | # xy_matmul is [batch_size, heads, length, length or depth] 29 | xy_matmul = tf.matmul(x, y, transpose_b=transpose) 30 | if z is not None: 31 | # x_t is [length, batch_size, heads, length or depth] 32 | x_t = tf.transpose(x, [2, 0, 1, 3]) 33 | # x_t_r is [length, batch_size * heads, length or depth] 34 | x_t_r = tf.reshape(x_t, [length, heads * batch_size, -1]) 35 | # x_tz_matmul is [length, batch_size * heads, length or depth] 36 | x_tz_matmul = tf.matmul(x_t_r, z, transpose_b=transpose) 37 | # x_tz_matmul_r is [length, batch_size, heads, length or depth] 38 | x_tz_matmul_r = tf.reshape(x_tz_matmul, [length, batch_size, heads, -1]) 39 | # x_tz_matmul_r_t is [batch_size, heads, length, length or depth] 40 | x_tz_matmul_r_t = tf.transpose(x_tz_matmul_r, [1, 2, 0, 3]) 41 | return xy_matmul + x_tz_matmul_r_t 42 | else: 43 | return xy_matmul 44 | 45 | 46 | def get_relative_positions_embeddings(length_x, length_y, 47 | depth, max_relative_position, name=None, last=None): 48 | """Generates tensor of size [length_x, length_y, depth].""" 49 | with tf.variable_scope(name or "rpr"): 50 | relative_positions_matrix = get_relative_positions_matrix( 51 | length_x, length_y, max_relative_position) 52 | # to handle cached decoding, where target-token incrementally grows 53 | if last is not None: 54 | relative_positions_matrix = relative_positions_matrix[-last:] 55 | vocab_size = max_relative_position * 2 + 1 56 | # Generates embedding for each relative position of dimension depth. 57 | embeddings_table = tf.get_variable("embeddings", [vocab_size, depth]) 58 | embeddings = tf.gather(embeddings_table, relative_positions_matrix) 59 | return embeddings 60 | 61 | 62 | def get_relative_positions_matrix(length_x, length_y, max_relative_position): 63 | """Generates matrix of relative positions between inputs.""" 64 | range_vec_x = tf.range(length_x) 65 | range_vec_y = tf.range(length_y) 66 | 67 | # shape: [length_x, length_y] 68 | distance_mat = tf.expand_dims(range_vec_x, -1) - tf.expand_dims(range_vec_y, 0) 69 | distance_mat_clipped = tf.clip_by_value(distance_mat, -max_relative_position, 70 | max_relative_position) 71 | 72 | # Shift values to be >= 0. Each integer still uniquely identifies a relative 73 | # position difference. 
74 | final_mat = distance_mat_clipped + max_relative_position 75 | return final_mat 76 | -------------------------------------------------------------------------------- /rnns/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from rnns import gru, lstm, atr, sru, lrn, olrn 4 | 5 | 6 | def get_cell(cell_name, hidden_size, ln=False, scope=None): 7 | """Convert the cell_name into cell instance.""" 8 | cell_name = cell_name.lower() 9 | 10 | if cell_name == "gru": 11 | return gru.gru(hidden_size, ln=ln, scope=scope or "gru") 12 | elif cell_name == "lstm": 13 | return lstm.lstm(hidden_size, ln=ln, scope=scope or "lstm") 14 | elif cell_name == "atr": 15 | return atr.atr(hidden_size, ln=ln, scope=scope or "atr") 16 | elif cell_name == "sru": 17 | return sru.sru(hidden_size, ln=ln, scope=scope or "sru") 18 | elif cell_name == "lrn": 19 | return lrn.lrn(hidden_size, ln=ln, scope=scope or "lrn") 20 | elif cell_name == "olrn": 21 | return olrn.olrn(hidden_size, ln=ln, scope=scope or "olrn") 22 | else: 23 | raise NotImplementedError("{} is not supported".format(cell_name)) 24 | -------------------------------------------------------------------------------- /rnns/atr.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import tensorflow as tf 8 | 9 | from func import linear 10 | from rnns import cell as cell 11 | 12 | 13 | class atr(cell.Cell): 14 | """The Addition-Subtraction Twin-Gated Recurrent Unit.""" 15 | 16 | def __init__(self, d, ln=False, twin=True, scope='atr'): 17 | super(atr, self).__init__(d, ln=ln, scope=scope) 18 | 19 | self.twin = twin 20 | 21 | def get_init_state(self, shape=None, x=None, scope=None): 22 | return self._get_init_state( 23 | self.d, shape=shape, x=x, scope=scope) 24 | 25 | def fetch_states(self, x): 26 | with tf.variable_scope( 27 | "fetch_state_{}".format(self.scope or "atr")): 28 | h = linear(x, self.d, 29 | bias=False, ln=self.ln, scope="hide_x") 30 | return (h, ) 31 | 32 | def __call__(self, h_, x): 33 | # h_: the previous hidden state 34 | # x: the current input state 35 | """ 36 | p = W x 37 | q = U h_ 38 | i = sigmoid(p + q) 39 | f = sigmoid(p - q) 40 | h = i * p + f * h_ 41 | """ 42 | if isinstance(x, (list, tuple)): 43 | x = x[0] 44 | 45 | with tf.variable_scope( 46 | "cell_{}".format(self.scope or "atr")): 47 | q = linear(h_, self.d, 48 | ln=self.ln, scope="hide_h") 49 | p = x 50 | 51 | f = tf.sigmoid(p - q) 52 | if self.twin: 53 | i = tf.sigmoid(p + q) 54 | # we empirically find that the following simple form is more stable. 55 | else: 56 | i = 1. - f 57 | 58 | h = i * p + f * h_ 59 | 60 | return h 61 | -------------------------------------------------------------------------------- /rnns/cell.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import abc 8 | import tensorflow as tf 9 | from func import linear 10 | from utils import dtype 11 | 12 | 13 | # This is an abstract class that deals with 14 | # recurrent cells, e.g. 
GRU, LSTM, ATR 15 | class Cell(object): 16 | def __init__(self, 17 | d, # hidden state dimension 18 | ln=False, # whether use layer normalization 19 | scope=None, # the name scope for this cell 20 | ): 21 | self.d = d 22 | self.scope = scope 23 | self.ln = ln 24 | 25 | def _get_init_state(self, d, shape=None, x=None, scope=None): 26 | # gen init state vector 27 | # if no evidence x is provided, use zero initialization 28 | if x is None: 29 | assert shape is not None, "you should provide shape" 30 | if not isinstance(shape, (tuple, list)): 31 | shape = [shape] 32 | shape = shape + [d] 33 | return dtype.tf_to_float(tf.zeros(shape)) 34 | else: 35 | return linear( 36 | x, d, bias=True, ln=self.ln, 37 | scope="{}_init".format(scope or self.scope) 38 | ) 39 | 40 | def get_hidden(self, x): 41 | return x 42 | 43 | @abc.abstractmethod 44 | def get_init_state(self, shape=None, x=None, scope=None): 45 | raise NotImplementedError("Not Supported") 46 | 47 | @abc.abstractmethod 48 | def __call__(self, h_, x): 49 | raise NotImplementedError("Not Supported") 50 | 51 | @abc.abstractmethod 52 | def fetch_states(self, x): 53 | raise NotImplementedError("Not Supported") 54 | -------------------------------------------------------------------------------- /rnns/gru.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import tensorflow as tf 8 | 9 | from func import linear 10 | from rnns import cell as cell 11 | 12 | 13 | class gru(cell.Cell): 14 | """The Gated Recurrent Unit.""" 15 | 16 | def __init__(self, d, ln=False, scope='gru'): 17 | super(gru, self).__init__(d, ln=ln, scope=scope) 18 | 19 | def get_init_state(self, shape=None, x=None, scope=None): 20 | return self._get_init_state( 21 | self.d, shape=shape, x=x, scope=scope) 22 | 23 | def fetch_states(self, x): 24 | with tf.variable_scope( 25 | "fetch_state_{}".format(self.scope or "gru")): 26 | g = linear(x, self.d * 2, 27 | bias=False, ln=self.ln, scope="gate_x") 28 | h = linear(x, self.d, 29 | bias=False, ln=self.ln, scope="hide_x") 30 | return g, h 31 | 32 | def __call__(self, h_, x): 33 | # h_: the previous hidden state 34 | # x_g/x: the current input state for gate 35 | # x_h/x: the current input state for hidden 36 | """ 37 | z = sigmoid(h_, x) 38 | r = sigmoid(h_, x) 39 | h' = tanh(x, r * h_) 40 | h = z * h_ + (1. - z) * h' 41 | """ 42 | with tf.variable_scope( 43 | "cell_{}".format(self.scope or "gru")): 44 | x_g, x_h = x 45 | 46 | h_g = linear(h_, self.d * 2, 47 | ln=self.ln, scope="gate_h") 48 | z, r = tf.split( 49 | tf.sigmoid(x_g + h_g), 2, -1) 50 | 51 | h_h = linear(h_ * r, self.d, 52 | ln=self.ln, scope="hide_h") 53 | h = tf.tanh(x_h + h_h) 54 | 55 | h = z * h_ + (1. - z) * h 56 | 57 | return h 58 | -------------------------------------------------------------------------------- /rnns/lrn.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import tensorflow as tf 8 | 9 | from func import linear 10 | from rnns import cell as cell 11 | 12 | 13 | class lrn(cell.Cell): 14 | """The Recurrence-Free Addition-Subtraction Twin-Gated Recurrent Unit. 
15 | Or, lightweight recurrent network 16 | """ 17 | 18 | def __init__(self, d, ln=False, scope='lrn'): 19 | super(lrn, self).__init__(d, ln=ln, scope=scope) 20 | 21 | def get_init_state(self, shape=None, x=None, scope=None): 22 | return self._get_init_state( 23 | self.d, shape=shape, x=x, scope=scope) 24 | 25 | def fetch_states(self, x): 26 | with tf.variable_scope( 27 | "fetch_state_{}".format(self.scope or "lrn")): 28 | h = linear(x, self.d * 3, 29 | bias=False, ln=self.ln, scope="hide_x") 30 | return (h, ) 31 | 32 | def __call__(self, h_, x): 33 | # h_: the previous hidden state 34 | # p,q,r/x: the current input state 35 | """ 36 | p, q, r = W x 37 | i = sigmoid(p + h_) 38 | f = sigmoid(q - h_) 39 | h = i * r + f * h_ 40 | """ 41 | if isinstance(x, (list, tuple)): 42 | x = x[0] 43 | 44 | with tf.variable_scope( 45 | "cell_{}".format(self.scope or "atr")): 46 | p, q, r = tf.split(x, 3, -1) 47 | 48 | i = tf.sigmoid(p + h_) 49 | f = tf.sigmoid(q - h_) 50 | 51 | h = i * r + f * h_ 52 | 53 | return h 54 | -------------------------------------------------------------------------------- /rnns/lstm.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import tensorflow as tf 8 | 9 | from func import linear 10 | from rnns import cell as cell 11 | 12 | 13 | class lstm(cell.Cell): 14 | """The Long-Short Term Memory Unit.""" 15 | 16 | def __init__(self, d, ln=False, scope='lstm'): 17 | super(lstm, self).__init__(d, ln=ln, scope=scope) 18 | 19 | def get_init_state(self, shape=None, x=None, scope=None): 20 | return self._get_init_state( 21 | self.d * 2, shape=shape, x=x, scope=scope) 22 | 23 | def get_hidden(self, x): 24 | return tf.split(x, 2, -1)[0] 25 | 26 | def fetch_states(self, x): 27 | with tf.variable_scope( 28 | "fetch_state_{}".format(self.scope or "lstm")): 29 | g = linear(x, self.d * 3, 30 | bias=False, ln=self.ln, scope="gate_x") 31 | c = linear(x, self.d, 32 | bias=False, ln=self.ln, scope="hide_x") 33 | return g, c 34 | 35 | def __call__(self, h_, x): 36 | # h_: the concatenation of previous hidden state 37 | # and memory cell state 38 | # x_i/x: the current input state for input gate 39 | # x_f/x: the current input state for forget gate 40 | # x_o/x: the current input state for output gate 41 | # x_c/x: the current input state for candidate cell 42 | """ 43 | f = sigmoid(h_, x) 44 | i = sigmoid(h_, x) 45 | o = sigmoid(h_, x) 46 | c' = tanh(h_, x) 47 | c = f * c_ + i * c' 48 | h = o * tanh(c) 49 | """ 50 | with tf.variable_scope( 51 | "cell_{}".format(self.scope or "lstm")): 52 | x_g, x_c = x 53 | h_, c_ = tf.split(h_, 2, -1) 54 | 55 | h_g = linear(h_, self.d * 3, 56 | ln=self.ln, scope="gate_h") 57 | i, f, o = tf.split( 58 | tf.sigmoid(x_g + h_g), 3, -1) 59 | 60 | h_c = linear(h_, self.d, 61 | ln=self.ln, scope="hide_h") 62 | h_c = tf.tanh(x_c + h_c) 63 | 64 | c = i * h_c + f * c_ 65 | 66 | h = o * tf.tanh(c) 67 | 68 | return tf.concat([h, c], -1) 69 | -------------------------------------------------------------------------------- /rnns/olrn.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import tensorflow as tf 8 | 9 | from func import linear 10 | from rnns import cell as cell 11 | 12 | 13 | class olrn(cell.Cell): 
14 | """The Recurrence-Free Addition-Subtraction Twin-Gated Recurrent Unit. 15 | Or, output-gated lightweight recurrent network 16 | """ 17 | 18 | def __init__(self, d, ln=False, scope='olrn'): 19 | super(olrn, self).__init__(d, ln=ln, scope=scope) 20 | 21 | def get_init_state(self, shape=None, x=None, scope=None): 22 | return self._get_init_state( 23 | self.d, shape=shape, x=x, scope=scope) 24 | 25 | def fetch_states(self, x): 26 | with tf.variable_scope( 27 | "fetch_state_{}".format(self.scope or "olrn")): 28 | h = linear(x, self.d * 4, 29 | bias=False, ln=self.ln, scope="hide_x") 30 | return (h, ) 31 | 32 | def __call__(self, h_, x): 33 | # h_: the previous hidden state 34 | # p,q,r,s/x: the current input state 35 | """ 36 | p, q, r, s = W x 37 | i = sigmoid(p + h_) 38 | f = sigmoid(q - h_) 39 | h = i * r + f * h_ 40 | o = simoid(s - h) 41 | h = o * h 42 | """ 43 | if isinstance(x, (list, tuple)): 44 | x = x[0] 45 | 46 | with tf.variable_scope( 47 | "cell_{}".format(self.scope or "atr")): 48 | p, q, r, s = tf.split(x, 4, -1) 49 | 50 | i = tf.sigmoid(p + h_) 51 | f = tf.sigmoid(q - h_) 52 | 53 | h = i * r + f * h_ 54 | 55 | o = tf.nn.sigmoid(s - h) 56 | h = o * h 57 | 58 | return h 59 | -------------------------------------------------------------------------------- /rnns/rnn.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import tensorflow as tf 8 | 9 | from utils import util, dtype 10 | from rnns import get_cell 11 | from func import linear, additive_attention 12 | 13 | 14 | def rnn(cell_name, x, d, mask=None, ln=False, init_state=None, sm=True): 15 | """Self implemented RNN procedure, supporting mask trick""" 16 | # cell_name: gru, lstm or atr 17 | # x: input sequence embedding matrix, [batch, seq_len, dim] 18 | # d: hidden dimension for rnn 19 | # mask: mask matrix, [batch, seq_len] 20 | # ln: whether use layer normalization 21 | # init_state: the initial hidden states, for cache purpose 22 | # sm: whether apply swap memory during rnn scan 23 | # dp: variational dropout 24 | 25 | in_shape = util.shape_list(x) 26 | batch_size, time_steps = in_shape[:2] 27 | 28 | cell = get_cell(cell_name, d, ln=ln) 29 | 30 | if init_state is None: 31 | init_state = cell.get_init_state(shape=[batch_size]) 32 | if mask is None: 33 | mask = dtype.tf_to_float(tf.ones([batch_size, time_steps])) 34 | 35 | # prepare projected input 36 | cache_inputs = cell.fetch_states(x) 37 | cache_inputs = [tf.transpose(v, [1, 0, 2]) 38 | for v in list(cache_inputs)] 39 | mask_ta = tf.transpose(tf.expand_dims(mask, -1), [1, 0, 2]) 40 | 41 | def _step_fn(prev, x): 42 | t, h_ = prev 43 | m = x[-1] 44 | v = x[:-1] 45 | 46 | h = cell(h_, v) 47 | h = m * h + (1. 
- m) * h_ 48 | 49 | return t + 1, h 50 | 51 | time = tf.constant(0, dtype=tf.int32, name="time") 52 | step_states = (time, init_state) 53 | step_vars = cache_inputs + [mask_ta] 54 | 55 | outputs = tf.scan(_step_fn, 56 | step_vars, 57 | initializer=step_states, 58 | parallel_iterations=32, 59 | swap_memory=sm) 60 | 61 | output_ta = outputs[1] 62 | output_state = outputs[1][-1] 63 | 64 | outputs = tf.transpose(output_ta, [1, 0, 2]) 65 | 66 | return (outputs, output_state), \ 67 | (cell.get_hidden(outputs), cell.get_hidden(output_state)) 68 | 69 | 70 | def cond_rnn(cell_name, x, memory, d, init_state=None, 71 | mask=None, mem_mask=None, ln=False, sm=True, 72 | one2one=False, num_heads=1): 73 | """Self implemented conditional-RNN procedure, supporting mask trick""" 74 | # cell_name: gru, lstm or atr 75 | # x: input sequence embedding matrix, [batch, seq_len, dim] 76 | # memory: the conditional part 77 | # d: hidden dimension for rnn 78 | # mask: mask matrix, [batch, seq_len] 79 | # mem_mask: memory mask matrix, [batch, mem_seq_len] 80 | # ln: whether use layer normalization 81 | # init_state: the initial hidden states, for cache purpose 82 | # sm: whether apply swap memory during rnn scan 83 | # one2one: whether the memory is one-to-one mapping for x 84 | # num_heads: number of attention heads, multi-head attention 85 | # dp: variational dropout 86 | 87 | in_shape = util.shape_list(x) 88 | batch_size, time_steps = in_shape[:2] 89 | mem_shape = util.shape_list(memory) 90 | 91 | cell_lower = get_cell(cell_name, d, ln=ln, 92 | scope="{}_lower".format(cell_name)) 93 | cell_higher = get_cell(cell_name, d, ln=ln, 94 | scope="{}_higher".format(cell_name)) 95 | 96 | if init_state is None: 97 | init_state = cell_lower.get_init_state(shape=[batch_size]) 98 | if mask is None: 99 | mask = dtype.tf_to_float(tf.ones([batch_size, time_steps])) 100 | if mem_mask is None: 101 | mem_mask = dtype.tf_to_float(tf.ones([batch_size, mem_shape[1]])) 102 | 103 | # prepare projected encodes and inputs 104 | cache_inputs = cell_lower.fetch_states(x) 105 | cache_inputs = [tf.transpose(v, [1, 0, 2]) 106 | for v in list(cache_inputs)] 107 | if not one2one: 108 | proj_memories = linear(memory, mem_shape[-1], bias=False, 109 | ln=ln, scope="context_att") 110 | else: 111 | cache_memories = cell_higher.fetch_states(memory) 112 | cache_memories = [tf.transpose(v, [1, 0, 2]) 113 | for v in list(cache_memories)] 114 | mask_ta = tf.transpose(tf.expand_dims(mask, -1), [1, 0, 2]) 115 | init_context = dtype.tf_to_float(tf.zeros([batch_size, mem_shape[-1]])) 116 | init_weight = dtype.tf_to_float(tf.zeros([batch_size, num_heads, mem_shape[1]])) 117 | mask_pos = len(cache_inputs) 118 | 119 | def _step_fn(prev, x): 120 | t, h_, c_, a_ = prev 121 | 122 | if not one2one: 123 | m, v = x[mask_pos], x[:mask_pos] 124 | else: 125 | c, c_c, m, v = x[-1], x[mask_pos+1:-1], x[mask_pos], x[:mask_pos] 126 | 127 | s = cell_lower(h_, v) 128 | s = m * s + (1. - m) * h_ 129 | 130 | if not one2one: 131 | vle = additive_attention( 132 | cell_lower.get_hidden(s), memory, mem_mask, 133 | mem_shape[-1], ln=ln, num_heads=num_heads, 134 | proj_memory=proj_memories, scope="attention") 135 | a, c = vle['weights'], vle['output'] 136 | c_c = cell_higher.fetch_states(c) 137 | else: 138 | a = tf.tile(tf.expand_dims(tf.range(time_steps), 0), [batch_size, 1]) 139 | a = dtype.tf_to_float(tf.equal(a, t)) 140 | a = tf.tile(tf.expand_dims(a, 1), [1, num_heads, 1]) 141 | a = tf.reshape(a, tf.shape(init_weight)) 142 | 143 | h = cell_higher(s, c_c) 144 | h = m * h + (1. 
- m) * s 145 | 146 | return t + 1, h, c, a 147 | 148 | time = tf.constant(0, dtype=tf.int32, name="time") 149 | step_states = (time, init_state, init_context, init_weight) 150 | step_vars = cache_inputs + [mask_ta] 151 | if one2one: 152 | step_vars += cache_memories + [memory] 153 | 154 | outputs = tf.scan(_step_fn, 155 | step_vars, 156 | initializer=step_states, 157 | parallel_iterations=32, 158 | swap_memory=sm) 159 | 160 | output_ta = outputs[1] 161 | context_ta = outputs[2] 162 | attention_ta = outputs[3] 163 | 164 | outputs = tf.transpose(output_ta, [1, 0, 2]) 165 | output_states = outputs[:, -1] 166 | contexts = tf.transpose(context_ta, [1, 0, 2]) 167 | attentions = tf.transpose(attention_ta, [1, 2, 0, 3]) 168 | 169 | return (outputs, output_states), \ 170 | (cell_higher.get_hidden(outputs), cell_higher.get_hidden(output_states)), \ 171 | contexts, attentions 172 | -------------------------------------------------------------------------------- /rnns/sru.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import tensorflow as tf 8 | 9 | from func import linear 10 | from rnns import cell as cell 11 | 12 | 13 | class sru(cell.Cell): 14 | """The Simple Recurrent Unit.""" 15 | 16 | def __init__(self, d, ln=False, scope='sru'): 17 | super(sru, self).__init__(d, ln=ln, scope=scope) 18 | 19 | def get_init_state(self, shape=None, x=None, scope=None): 20 | return self._get_init_state( 21 | self.d * 2, shape=shape, x=x, scope=scope) 22 | 23 | def get_hidden(self, x): 24 | return tf.split(x, 2, -1)[0] 25 | 26 | def fetch_states(self, x): 27 | with tf.variable_scope( 28 | "fetch_state_{}".format(self.scope or "sru")): 29 | h = linear(x, self.d * 4, 30 | bias=False, ln=self.ln, scope="hide_x") 31 | return (h, ) 32 | 33 | def __call__(self, h_, x): 34 | # h_: the concatenation of previous hidden state 35 | # and memory cell state 36 | # x_r/x: the current input state for r gate 37 | # x_f/x: the current input state for f gate 38 | # x_c/x: the current input state for candidate cell 39 | # x_h/x: the current input state for hidden output 40 | # we increase this because we do not assume that 41 | # the input dimension equals the output dimension 42 | """ 43 | f = sigmoid(Wx, vf * c_) 44 | c = f * c_ + (1 - f) * Wx 45 | r = sigmoid(Wx, vr * c_) 46 | h = r * c + (1 - r) * Ux 47 | """ 48 | if isinstance(x, (list, tuple)): 49 | x = x[0] 50 | 51 | with tf.variable_scope( 52 | "cell_{}".format(self.scope or "sru")): 53 | x_r, x_f, x_c, x_h = tf.split(x, 4, -1) 54 | h_, c_ = tf.split(h_, 2, -1) 55 | 56 | v_f = tf.get_variable("v_f", [1, self.d]) 57 | v_r = tf.get_variable("v_r", [1, self.d]) 58 | 59 | f = tf.sigmoid(x_f + v_f * c_) 60 | c = f * c_ + (1. - f) * x_c 61 | r = tf.sigmoid(x_r + v_r * c_) 62 | h = r * c + (1. 
- r) * x_h 63 | 64 | return tf.concat([h, c], -1) 65 | -------------------------------------------------------------------------------- /scripts/bleu_over_length.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import math 8 | import argparse 9 | 10 | from collections import Counter 11 | 12 | 13 | def closest_length(candidate, references): 14 | clen = len(candidate) 15 | closest_diff = 9999 16 | closest_len = 9999 17 | 18 | for reference in references: 19 | rlen = len(reference) 20 | diff = abs(rlen - clen) 21 | 22 | if diff < closest_diff: 23 | closest_diff = diff 24 | closest_len = rlen 25 | elif diff == closest_diff: 26 | closest_len = rlen if rlen < closest_len else closest_len 27 | 28 | return closest_len 29 | 30 | 31 | def shortest_length(references): 32 | return min([len(ref) for ref in references]) 33 | 34 | 35 | def modified_precision(candidate, references, n): 36 | tngrams = len(candidate) + 1 - n 37 | counts = Counter([tuple(candidate[i:i+n]) for i in range(tngrams)]) 38 | 39 | if len(counts) == 0: 40 | return 0, 0 41 | 42 | max_counts = {} 43 | for reference in references: 44 | rngrams = len(reference) + 1 - n 45 | ngrams = [tuple(reference[i:i+n]) for i in range(rngrams)] 46 | ref_counts = Counter(ngrams) 47 | for ngram in counts: 48 | mcount = 0 if ngram not in max_counts else max_counts[ngram] 49 | rcount = 0 if ngram not in ref_counts else ref_counts[ngram] 50 | max_counts[ngram] = max(mcount, rcount) 51 | 52 | clipped_counts = {} 53 | 54 | for ngram, count in counts.items(): 55 | clipped_counts[ngram] = min(count, max_counts[ngram]) 56 | 57 | return float(sum(clipped_counts.values())), float(sum(counts.values())) 58 | 59 | 60 | def brevity_penalty(trans, refs, mode="closest"): 61 | bp_c = 0.0 62 | bp_r = 0.0 63 | 64 | for candidate, references in zip(trans, refs): 65 | bp_c += len(candidate) 66 | 67 | if mode == "shortest": 68 | bp_r += shortest_length(references) 69 | else: 70 | bp_r += closest_length(candidate, references) 71 | 72 | # Prevent zero divide 73 | bp_c = bp_c or 1.0 74 | 75 | return math.exp(min(0, 1.0 - bp_r / bp_c)) 76 | 77 | 78 | def bleu(trans, refs, bp="closest", smooth=False, n=4, weights=None): 79 | p_norm = [0 for _ in range(n)] 80 | p_denorm = [0 for _ in range(n)] 81 | 82 | for candidate, references in zip(trans, refs): 83 | for i in range(n): 84 | ccount, tcount = modified_precision(candidate, references, i + 1) 85 | p_norm[i] += ccount 86 | p_denorm[i] += tcount 87 | 88 | bleu_n = [0 for _ in range(n)] 89 | 90 | for i in range(n): 91 | # add one smoothing 92 | if smooth and i > 0: 93 | p_norm[i] += 1 94 | p_denorm[i] += 1 95 | 96 | if p_norm[i] == 0 or p_denorm[i] == 0: 97 | bleu_n[i] = -9999 98 | else: 99 | bleu_n[i] = math.log(float(p_norm[i]) / float(p_denorm[i])) 100 | 101 | if weights: 102 | if len(weights) != n: 103 | raise ValueError("len(weights) != n: invalid weight number") 104 | log_precision = sum([bleu_n[i] * weights[i] for i in range(n)]) 105 | else: 106 | log_precision = sum(bleu_n) / float(n) 107 | 108 | bp = brevity_penalty(trans, refs, bp) 109 | 110 | score = bp * math.exp(log_precision) 111 | 112 | return score 113 | 114 | 115 | def read(f, lc=False): 116 | with open(f, 'rU') as reader: 117 | return [line.strip().split() if not lc else line.strip().lower().split() 118 | for line in reader.readlines()] 119 | 120 | 121 | if __name__ == 
"__main__": 122 | parser = argparse.ArgumentParser( 123 | description='BLEU score over source sentence length') 124 | parser.add_argument('-lc', help='Lowercase, i.e case-insensitive setting', action='store_true') 125 | parser.add_argument('-bp', help='Length penalty', default='closest', choices=['shortest', 'closest']) 126 | parser.add_argument('-n', type=int, default=4, help="ngram-based BLEU") 127 | parser.add_argument('-g', type=int, default=1, help="sentence groups for evaluation") 128 | parser.add_argument('-source', type=str, required=True, help='The source file') 129 | parser.add_argument('-candidate', type=str, required=True, help='The candidate translation generated by MT system') 130 | parser.add_argument('-reference', type=str, nargs='+', required=True, 131 | help='The references like reference or reference0, reference1, ...') 132 | 133 | args = parser.parse_args() 134 | 135 | cand = args.candidate 136 | refs = args.reference 137 | src = args.source 138 | 139 | src_sentences = read(src, args.lc) 140 | cand_sentences = read(cand, args.lc) 141 | refs_sentences = [read(ref, args.lc) for ref in refs] 142 | 143 | assert len(cand_sentences) == len(refs_sentences[0]), \ 144 | 'ERROR: the length of candidate and reference must be the same.' 145 | 146 | refs_sentences = list(zip(*refs_sentences)) 147 | 148 | sorted_candidate_sentences = sorted(zip(src_sentences, cand_sentences), key=lambda x: len(x[0])) 149 | sorted_reference_sentences = sorted(zip(src_sentences, refs_sentences), key=lambda x: len(x[0])) 150 | 151 | sorted_source_sentences = [v[0] for v in sorted_candidate_sentences] 152 | sorted_candidate_sentences = [v[1] for v in sorted_candidate_sentences] 153 | sorted_reference_sentences = [v[1] for v in sorted_reference_sentences] 154 | 155 | groups = args.g 156 | elements_per_group = len(sorted_source_sentences) // groups 157 | 158 | scores = [] 159 | for gidx in range(groups): 160 | group_candidate = sorted_candidate_sentences[gidx * elements_per_group: (gidx + 1) * elements_per_group] 161 | group_reference = sorted_reference_sentences[gidx * elements_per_group: (gidx + 1) * elements_per_group] 162 | group_source = sorted_source_sentences[gidx * elements_per_group: (gidx + 1) * elements_per_group] 163 | 164 | group_average_source = float(sum([len(v) for v in group_source])) / float(len(group_source)) 165 | bleu_score = bleu(group_candidate, group_reference, bp=args.bp, n=args.n) 166 | 167 | print("Group Idx {} Avg Source Lenngth {} BLEU Score {}".format(gidx, group_average_source, bleu_score)) 168 | 169 | scores.append((group_average_source, bleu_score)) 170 | 171 | print('AvgLength: [{}]'.format(','.join([str(s[0]) for s in scores]))) 172 | print('BLEU Score: [{}]'.format(','.join([str(s[1]) for s in scores]))) 173 | 174 | -------------------------------------------------------------------------------- /scripts/checkpoint_averaging.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import argparse 8 | import operator 9 | import os 10 | 11 | import numpy as np 12 | import tensorflow as tf 13 | 14 | 15 | def parseargs(): 16 | msg = "Average checkpoints" 17 | usage = "average.py [] [-h | --help]" 18 | parser = argparse.ArgumentParser(description=msg, usage=usage) 19 | 20 | parser.add_argument("--path", type=str, required=True, 21 | help="checkpoint dir") 22 | parser.add_argument("--checkpoints", 
type=int, required=True, 23 | help="number of checkpoints to use") 24 | parser.add_argument("--output", type=str, help="output path") 25 | parser.add_argument("--gpu", type=int, default=0, 26 | help="the default gpu device index") 27 | 28 | return parser.parse_args() 29 | 30 | 31 | def get_checkpoints(path): 32 | if not tf.gfile.Exists(os.path.join(path, "checkpoint")): 33 | raise ValueError("Cannot find checkpoints in %s" % path) 34 | 35 | checkpoint_names = [] 36 | 37 | with tf.gfile.GFile(os.path.join(path, "checkpoint")) as fd: 38 | # Skip the first line 39 | fd.readline() 40 | for line in fd: 41 | name = line.strip().split(":")[-1].strip()[1:-1] 42 | key = int(name.split("-")[-1]) 43 | checkpoint_names.append((key, os.path.join(path, name))) 44 | 45 | sorted_names = sorted(checkpoint_names, key=operator.itemgetter(0), 46 | reverse=True) 47 | 48 | return [item[-1] for item in sorted_names] 49 | 50 | 51 | def checkpoint_exists(path): 52 | return (tf.gfile.Exists(path) or tf.gfile.Exists(path + ".meta") or 53 | tf.gfile.Exists(path + ".index")) 54 | 55 | 56 | def main(_): 57 | tf.logging.set_verbosity(tf.logging.INFO) 58 | checkpoints = get_checkpoints(FLAGS.path) 59 | checkpoints = checkpoints[:FLAGS.checkpoints] 60 | 61 | if not checkpoints: 62 | raise ValueError("No checkpoints provided for averaging.") 63 | 64 | checkpoints = [c for c in checkpoints if checkpoint_exists(c)] 65 | 66 | if not checkpoints: 67 | raise ValueError( 68 | "None of the provided checkpoints exist. %s" % FLAGS.checkpoints 69 | ) 70 | 71 | var_list = tf.contrib.framework.list_variables(checkpoints[0]) 72 | var_values, var_dtypes = {}, {} 73 | 74 | for (name, shape) in var_list: 75 | if not name.startswith("global_step"): 76 | var_values[name] = np.zeros(shape) 77 | 78 | for checkpoint in checkpoints: 79 | reader = tf.contrib.framework.load_checkpoint(checkpoint) 80 | for name in var_values: 81 | tensor = reader.get_tensor(name) 82 | var_dtypes[name] = tensor.dtype 83 | var_values[name] += tensor 84 | tf.logging.info("Read from checkpoint %s", checkpoint) 85 | 86 | # Average checkpoints 87 | for name in var_values: 88 | var_values[name] /= len(checkpoints) 89 | 90 | tf_vars = [ 91 | tf.get_variable(name, shape=var_values[name].shape, 92 | dtype=var_dtypes[name]) for name in var_values 93 | ] 94 | placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars] 95 | assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)] 96 | global_step = tf.Variable(0, name="global_step", trainable=False, 97 | dtype=tf.int64) 98 | saver = tf.train.Saver(tf.global_variables()) 99 | 100 | sess_config = tf.ConfigProto(allow_soft_placement=True) 101 | sess_config.gpu_options.allow_growth = True 102 | sess_config.gpu_options.visible_device_list = "%s" % FLAGS.gpu 103 | 104 | with tf.Session(config=sess_config) as sess: 105 | sess.run(tf.global_variables_initializer()) 106 | for p, assign_op, (name, value) in zip(placeholders, assign_ops, 107 | var_values.iteritems()): 108 | sess.run(assign_op, {p: value}) 109 | saved_name = os.path.join(FLAGS.output, "average") 110 | saver.save(sess, saved_name, global_step=global_step) 111 | 112 | tf.logging.info("Averaged checkpoints saved in %s", saved_name) 113 | 114 | params_pattern = os.path.join(FLAGS.path, "*.json") 115 | params_files = tf.gfile.Glob(params_pattern) 116 | 117 | for name in params_files: 118 | new_name = name.replace(FLAGS.path.rstrip("/"), 119 | FLAGS.output.rstrip("/")) 120 | tf.gfile.Copy(name, new_name, overwrite=True) 121 | 122 | 123 | if __name__ 
== "__main__": 124 | FLAGS = parseargs() 125 | tf.app.run() 126 | -------------------------------------------------------------------------------- /scripts/chrF.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Author: Rico Sennrich 4 | 5 | """Compute chrF3 for machine translation evaluation 6 | 7 | Reference: 8 | Maja Popović (2015). chrF: character n-gram F-score for automatic MT evaluation. In Proceedings of the Tenth Workshop on Statistical Machine Translationn, pages 392–395, Lisbon, Portugal. 9 | """ 10 | 11 | from __future__ import print_function, unicode_literals, division 12 | 13 | import sys 14 | import codecs 15 | import io 16 | import argparse 17 | 18 | from collections import defaultdict 19 | 20 | # hack for python2/3 compatibility 21 | from io import open 22 | argparse.open = open 23 | 24 | def create_parser(): 25 | parser = argparse.ArgumentParser( 26 | formatter_class=argparse.RawDescriptionHelpFormatter, 27 | description="learn BPE-based word segmentation") 28 | 29 | parser.add_argument( 30 | '--ref', '-r', type=argparse.FileType('r'), required=True, 31 | metavar='PATH', 32 | help="Reference file") 33 | parser.add_argument( 34 | '--hyp', type=argparse.FileType('r'), metavar='PATH', 35 | default=sys.stdin, 36 | help="Hypothesis file (default: stdin).") 37 | parser.add_argument( 38 | '--beta', '-b', type=float, default=3, 39 | metavar='FLOAT', 40 | help="beta parameter (default: '%(default)s')") 41 | parser.add_argument( 42 | '--ngram', '-n', type=int, default=6, 43 | metavar='INT', 44 | help="ngram order (default: '%(default)s')") 45 | parser.add_argument( 46 | '--space', '-s', action='store_true', 47 | help="take spaces into account (default: '%(default)s')") 48 | parser.add_argument( 49 | '--precision', action='store_true', 50 | help="report precision (default: '%(default)s')") 51 | parser.add_argument( 52 | '--recall', action='store_true', 53 | help="report recall (default: '%(default)s')") 54 | 55 | return parser 56 | 57 | def extract_ngrams(words, max_length=4, spaces=False): 58 | 59 | if not spaces: 60 | words = ''.join(words.split()) 61 | else: 62 | words = words.strip() 63 | 64 | results = defaultdict(lambda: defaultdict(int)) 65 | for length in range(max_length): 66 | for start_pos in range(len(words)): 67 | end_pos = start_pos + length + 1 68 | if end_pos <= len(words): 69 | results[length][tuple(words[start_pos: end_pos])] += 1 70 | return results 71 | 72 | 73 | def get_correct(ngrams_ref, ngrams_test, correct, total): 74 | 75 | for rank in ngrams_test: 76 | for chain in ngrams_test[rank]: 77 | total[rank] += ngrams_test[rank][chain] 78 | if chain in ngrams_ref[rank]: 79 | correct[rank] += min(ngrams_test[rank][chain], ngrams_ref[rank][chain]) 80 | 81 | return correct, total 82 | 83 | 84 | def f1(correct, total_hyp, total_ref, max_length, beta=3, smooth=0): 85 | 86 | precision = 0 87 | recall = 0 88 | 89 | for i in range(max_length): 90 | if total_hyp[i] + smooth and total_ref[i] + smooth: 91 | precision += (correct[i] + smooth) / (total_hyp[i] + smooth) 92 | recall += (correct[i] + smooth) / (total_ref[i] + smooth) 93 | 94 | precision /= max_length 95 | recall /= max_length 96 | 97 | return (1 + beta**2) * (precision*recall) / ((beta**2 * precision) + recall), precision, recall 98 | 99 | def main(args): 100 | 101 | correct = [0]*args.ngram 102 | total = [0]*args.ngram 103 | total_ref = [0]*args.ngram 104 | for line in args.ref: 105 | line2 = args.hyp.readline() 106 | 
107 | ngrams_ref = extract_ngrams(line, max_length=args.ngram, spaces=args.space) 108 | ngrams_test = extract_ngrams(line2, max_length=args.ngram, spaces=args.space) 109 | 110 | get_correct(ngrams_ref, ngrams_test, correct, total) 111 | 112 | for rank in ngrams_ref: 113 | for chain in ngrams_ref[rank]: 114 | total_ref[rank] += ngrams_ref[rank][chain] 115 | 116 | chrf, precision, recall = f1(correct, total, total_ref, args.ngram, args.beta) 117 | 118 | print('chrF3: {0:.4f}'.format(chrf)) 119 | if args.precision: 120 | print('chrPrec: {0:.4f}'.format(precision)) 121 | if args.recall: 122 | print('chrRec: {0:.4f}'.format(recall)) 123 | 124 | if __name__ == '__main__': 125 | 126 | # python 2/3 compatibility 127 | if sys.version_info < (3, 0): 128 | sys.stderr = codecs.getwriter('UTF-8')(sys.stderr) 129 | sys.stdout = codecs.getwriter('UTF-8')(sys.stdout) 130 | sys.stdin = codecs.getreader('UTF-8')(sys.stdin) 131 | else: 132 | sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8') 133 | sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8') 134 | sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', write_through=True, line_buffering=True) 135 | 136 | parser = create_parser() 137 | args = parser.parse_args() 138 | 139 | main(args) 140 | -------------------------------------------------------------------------------- /scripts/evaluate_pos_translation_rate.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import argparse 8 | 9 | import sys 10 | from collections import Counter 11 | 12 | 13 | def parseargs(): 14 | msg = "Evlauate P/R/F score for particular POS Tagged Tokens" 15 | usage = "{} [] [-h | --help]".format(sys.argv[0]) 16 | parser = argparse.ArgumentParser(description=msg, usage=usage) 17 | 18 | parser.add_argument("--trans", type=str, required=True, 19 | help="model translation") 20 | parser.add_argument("--refs", type=str, required=True, nargs="+", 21 | help="gold reference, one or more") 22 | parser.add_argument("--ngram", type=int, default=4, 23 | help="the maximum n for n-gram") 24 | 25 | parser.add_argument_group("POS setting") 26 | parser.add_argument("--noun", type=str, default="NN", 27 | help="the pos label for noun") 28 | parser.add_argument("--verb", type=str, default="VB", 29 | help="the pos label for verb") 30 | parser.add_argument("--adj", type=str, default="JJ", 31 | help="the pos label for adjective") 32 | parser.add_argument("--adv", type=str, default="RB", 33 | help="the pos label for adverb") 34 | 35 | parser.add_argument("--spliter", type=str, default="_", 36 | help="the spliter between word and pos label") 37 | 38 | return parser.parse_args() 39 | 40 | 41 | # POS conversion module 42 | def prepare_ngram(txt, pos, ngram): 43 | tokens = txt.strip().split() 44 | 45 | words = [] 46 | for token in tokens: 47 | if type(pos) is not list and pos in token: 48 | segs = token.strip().split('_') 49 | word = '_'.join(segs[:-1]) 50 | words.append(word) 51 | elif type(pos) is list: 52 | cvt = False 53 | for p in pos: 54 | if p in token: 55 | cvt = True 56 | break 57 | if cvt: 58 | segs = token.strip().split('_') 59 | word = '_'.join(segs[:-1]) 60 | words.append(word) 61 | else: 62 | words.append('') 63 | 64 | _ngram_list = [] 65 | for ngidx in range(ngram, len(words)): 66 | _ngram_list.append(' '.join(words[ngidx - ngram:ngidx])) 67 | ngram_list = [ng for ng in 
_ngram_list if '' not in ng] 68 | 69 | return Counter(ngram_list) 70 | 71 | 72 | def convert_corpus(dataset, pos, ngram): 73 | return [prepare_ngram(data, pos, ngram) for data in dataset] 74 | 75 | 76 | def score(trans, refs): 77 | 78 | def _precision_recall_fvalue(_trans, _ref): 79 | t_cngrams = 0. 80 | t_rngrams = 0. 81 | m_ngrams = 0. 82 | 83 | for cngrams, rngrams in zip(_trans, _ref): 84 | 85 | t_cngrams += sum(cngrams.values()) 86 | t_rngrams += sum(rngrams.values()) 87 | 88 | for ngram in cngrams: 89 | if ngram in rngrams: 90 | m_ngrams += min(cngrams[ngram], rngrams[ngram]) 91 | 92 | precision = m_ngrams / t_cngrams if t_cngrams > 0 else 0. 93 | recall = m_ngrams / t_rngrams if t_rngrams > 0 else 0. 94 | fvalue = 2 * (recall * precision) / (recall + precision + 1e-8) 95 | 96 | return precision, recall, fvalue 97 | 98 | eval_scores = [_precision_recall_fvalue(trans, ref) for ref in refs] 99 | eval_scores = list(zip(*eval_scores)) 100 | return [sum(v) / len(v) for v in eval_scores] 101 | 102 | 103 | def evaluate_the_rate_of_specific_gram(ref, trs, pos, ngram): 104 | # ref: reference corpus 105 | # trs: translation corpus 106 | # pos: part-of-speech tag 107 | # ngram: n-gram number 108 | 109 | references = [convert_corpus(r, pos, ngram) for r in ref] 110 | candidate = convert_corpus(trs, pos, ngram) 111 | 112 | result = score(candidate, references) 113 | 114 | return pos, ngram, result 115 | 116 | 117 | if __name__ == "__main__": 118 | params = parseargs() 119 | 120 | # loading the reference corpus 121 | corpus = [] 122 | for trans_txt in params.refs: 123 | with open(trans_txt, 'rU') as reader: 124 | corpus.append(reader.readlines()) 125 | if len(corpus) > 1: 126 | for cidx in range(1, len(corpus)): 127 | assert len(corpus[cidx]) == len(corpus[cidx - 1]), 'the length of each reference text must be the same' 128 | 129 | # the focused translation corpus 130 | with open(params.trans, 'rU') as reader: 131 | test = reader.readlines() 132 | assert len(test) == len(corpus[0]), \ 133 | 'the length of translation text should be the same as that of reference text' 134 | 135 | poses = [params.noun, 136 | params.verb, 137 | params.adj, 138 | params.adv, 139 | [params.noun, params.verb], 140 | [params.noun, params.verb, params.adj]] 141 | ngrams = range(params.ngram) 142 | for pos in poses: 143 | for ngram in ngrams: 144 | pos, ngram, evals = evaluate_the_rate_of_specific_gram(corpus, test, pos, ngram) 145 | print('Pos: %s, Ngram: %s, Score %s' % (pos, ngram + 1, str(evals))) 146 | -------------------------------------------------------------------------------- /scripts/multi-bleu-detok.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 5 | 6 | # This file uses the internal tokenization of mteval-v13a.pl, 7 | # giving the exact same (case-sensitive) results on untokenized text. 8 | # Using this script with detokenized output and untokenized references is 9 | # preferrable over multi-bleu.perl, since scores aren't affected by tokenization differences. 10 | # 11 | # like multi-bleu.perl , it supports plain text input and multiple references. 
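# Usage sketch (added for clarity; the file names are placeholders, not from this repository):
#   perl scripts/multi-bleu-detok.perl ref.detok < hyp.detok
#   perl scripts/multi-bleu-detok.perl -lc ref.detok < hyp.detok   # case-insensitive scoring
# References are read from ref.detok or from ref.detok0, ref.detok1, ..., and a single
# "BLEU = ..." summary line is printed to stdout.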
12 | 13 | # $Id$ 14 | use warnings; 15 | use strict; 16 | 17 | binmode(STDIN, ":utf8"); 18 | use open ':encoding(UTF-8)'; 19 | 20 | my $lowercase = 0; 21 | if ($ARGV[0] eq "-lc") { 22 | $lowercase = 1; 23 | shift; 24 | } 25 | 26 | my $stem = $ARGV[0]; 27 | if (!defined $stem) { 28 | print STDERR "usage: multi-bleu-detok.pl [-lc] reference < hypothesis\n"; 29 | print STDERR "Reads the references from reference or reference0, reference1, ...\n"; 30 | exit(1); 31 | } 32 | 33 | $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0"; 34 | 35 | my @REF; 36 | my $ref=0; 37 | while(-e "$stem$ref") { 38 | &add_to_ref("$stem$ref",\@REF); 39 | $ref++; 40 | } 41 | &add_to_ref($stem,\@REF) if -e $stem; 42 | die("ERROR: could not find reference file $stem") unless scalar @REF; 43 | 44 | # add additional references explicitly specified on the command line 45 | shift; 46 | foreach my $stem (@ARGV) { 47 | &add_to_ref($stem,\@REF) if -e $stem; 48 | } 49 | 50 | 51 | 52 | sub add_to_ref { 53 | my ($file,$REF) = @_; 54 | my $s=0; 55 | if ($file =~ /.gz$/) { 56 | open(REF,"gzip -dc $file|") or die "Can't read $file"; 57 | } else { 58 | open(REF,$file) or die "Can't read $file"; 59 | } 60 | while(<REF>) { 61 | chop; 62 | $_ = tokenization($_); 63 | push @{$$REF[$s++]}, $_; 64 | } 65 | close(REF); 66 | } 67 | 68 | my(@CORRECT,@TOTAL,$length_translation,$length_reference); 69 | my $s=0; 70 | while(<STDIN>) { 71 | chop; 72 | $_ = lc if $lowercase; 73 | $_ = tokenization($_); 74 | my @WORD = split; 75 | my %REF_NGRAM = (); 76 | my $length_translation_this_sentence = scalar(@WORD); 77 | my ($closest_diff,$closest_length) = (9999,9999); 78 | foreach my $reference (@{$REF[$s]}) { 79 | # print "$s $_ <=> $reference\n"; 80 | $reference = lc($reference) if $lowercase; 81 | my @WORD = split(' ',$reference); 82 | my $length = scalar(@WORD); 83 | my $diff = abs($length_translation_this_sentence-$length); 84 | if ($diff < $closest_diff) { 85 | $closest_diff = $diff; 86 | $closest_length = $length; 87 | # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n"; 88 | } elsif ($diff == $closest_diff) { 89 | $closest_length = $length if $length < $closest_length; 90 | # from two references with the same closeness to me 91 | # take the *shorter* into account, not the "first" one. 92 | } 93 | for(my $n=1;$n<=4;$n++) { 94 | my %REF_NGRAM_N = (); 95 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 96 | my $ngram = "$n"; 97 | for(my $w=0;$w<$n;$w++) { 98 | $ngram .= " ".$WORD[$start+$w]; 99 | } 100 | $REF_NGRAM_N{$ngram}++; 101 | } 102 | foreach my $ngram (keys %REF_NGRAM_N) { 103 | if (!defined($REF_NGRAM{$ngram}) || 104 | $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) { 105 | $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram}; 106 | # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}
\n"; 107 | } 108 | } 109 | } 110 | } 111 | $length_translation += $length_translation_this_sentence; 112 | $length_reference += $closest_length; 113 | for(my $n=1;$n<=4;$n++) { 114 | my %T_NGRAM = (); 115 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 116 | my $ngram = "$n"; 117 | for(my $w=0;$w<$n;$w++) { 118 | $ngram .= " ".$WORD[$start+$w]; 119 | } 120 | $T_NGRAM{$ngram}++; 121 | } 122 | foreach my $ngram (keys %T_NGRAM) { 123 | $ngram =~ /^(\d+) /; 124 | my $n = $1; 125 | # my $corr = 0; 126 | # print "$i e $ngram $T_NGRAM{$ngram}
\n"; 127 | $TOTAL[$n] += $T_NGRAM{$ngram}; 128 | if (defined($REF_NGRAM{$ngram})) { 129 | if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) { 130 | $CORRECT[$n] += $T_NGRAM{$ngram}; 131 | # $corr = $T_NGRAM{$ngram}; 132 | # print "$i e correct1 $T_NGRAM{$ngram}
\n"; 133 | } 134 | else { 135 | $CORRECT[$n] += $REF_NGRAM{$ngram}; 136 | # $corr = $REF_NGRAM{$ngram}; 137 | # print "$i e correct2 $REF_NGRAM{$ngram}
\n"; 138 | } 139 | } 140 | # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram}; 141 | # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n" 142 | } 143 | } 144 | $s++; 145 | } 146 | my $brevity_penalty = 1; 147 | my $bleu = 0; 148 | 149 | my @bleu=(); 150 | 151 | for(my $n=1;$n<=4;$n++) { 152 | if (defined ($TOTAL[$n])){ 153 | $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0; 154 | # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n"; 155 | }else{ 156 | $bleu[$n]=0; 157 | } 158 | } 159 | 160 | if ($length_reference==0){ 161 | printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n"; 162 | exit(1); 163 | } 164 | 165 | if ($length_translation<$length_reference) { 166 | $brevity_penalty = exp(1-$length_reference/$length_translation); 167 | } 168 | $bleu = $brevity_penalty * exp((my_log( $bleu[1] ) + 169 | my_log( $bleu[2] ) + 170 | my_log( $bleu[3] ) + 171 | my_log( $bleu[4] ) ) / 4) ; 172 | printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n", 173 | 100*$bleu, 174 | 100*$bleu[1], 175 | 100*$bleu[2], 176 | 100*$bleu[3], 177 | 100*$bleu[4], 178 | $brevity_penalty, 179 | $length_translation / $length_reference, 180 | $length_translation, 181 | $length_reference; 182 | 183 | sub my_log { 184 | return -9999999999 unless $_[0]; 185 | return log($_[0]); 186 | } 187 | 188 | 189 | 190 | sub tokenization 191 | { 192 | my ($norm_text) = @_; 193 | 194 | # language-independent part: 195 | $norm_text =~ s///g; # strip "skipped" tags 196 | $norm_text =~ s/-\n//g; # strip end-of-line hyphenation and join lines 197 | $norm_text =~ s/\n/ /g; # join lines 198 | $norm_text =~ s/"/"/g; # convert SGML tag for quote to " 199 | $norm_text =~ s/&/&/g; # convert SGML tag for ampersand to & 200 | $norm_text =~ s/</ 201 | $norm_text =~ s/>/>/g; # convert SGML tag for greater-than to < 202 | 203 | # language-dependent part (assuming Western languages): 204 | $norm_text = " $norm_text "; 205 | $norm_text =~ s/([\{-\~\[-\` -\&\(-\+\:-\@\/])/ $1 /g; # tokenize punctuation 206 | $norm_text =~ s/([^0-9])([\.,])/$1 $2 /g; # tokenize period and comma unless preceded by a digit 207 | $norm_text =~ s/([\.,])([^0-9])/ $1 $2/g; # tokenize period and comma unless followed by a digit 208 | $norm_text =~ s/([0-9])(-)/$1 $2 /g; # tokenize dash when preceded by a digit 209 | $norm_text =~ s/\s+/ /g; # one space only between words 210 | $norm_text =~ s/^\s+//; # no leading space 211 | $norm_text =~ s/\s+$//; # no trailing space 212 | 213 | return $norm_text; 214 | } 215 | -------------------------------------------------------------------------------- /scripts/multi-bleu.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 
5 | 6 | # $Id$ 7 | use warnings; 8 | use strict; 9 | 10 | my $lowercase = 0; 11 | if ($ARGV[0] eq "-lc") { 12 | $lowercase = 1; 13 | shift; 14 | } 15 | 16 | my $stem = $ARGV[0]; 17 | if (!defined $stem) { 18 | print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n"; 19 | print STDERR "Reads the references from reference or reference0, reference1, ...\n"; 20 | exit(1); 21 | } 22 | 23 | $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0"; 24 | 25 | my @REF; 26 | my $ref=0; 27 | while(-e "$stem$ref") { 28 | &add_to_ref("$stem$ref",\@REF); 29 | $ref++; 30 | } 31 | &add_to_ref($stem,\@REF) if -e $stem; 32 | die("ERROR: could not find reference file $stem") unless scalar @REF; 33 | 34 | # add additional references explicitly specified on the command line 35 | shift; 36 | foreach my $stem (@ARGV) { 37 | &add_to_ref($stem,\@REF) if -e $stem; 38 | } 39 | 40 | 41 | 42 | sub add_to_ref { 43 | my ($file,$REF) = @_; 44 | my $s=0; 45 | if ($file =~ /.gz$/) { 46 | open(REF,"gzip -dc $file|") or die "Can't read $file"; 47 | } else { 48 | open(REF,$file) or die "Can't read $file"; 49 | } 50 | while() { 51 | chomp; 52 | push @{$$REF[$s++]}, $_; 53 | } 54 | close(REF); 55 | } 56 | 57 | my(@CORRECT,@TOTAL,$length_translation,$length_reference); 58 | my $s=0; 59 | while() { 60 | chomp; 61 | $_ = lc if $lowercase; 62 | my @WORD = split; 63 | my %REF_NGRAM = (); 64 | my $length_translation_this_sentence = scalar(@WORD); 65 | my ($closest_diff,$closest_length) = (9999,9999); 66 | foreach my $reference (@{$REF[$s]}) { 67 | # print "$s $_ <=> $reference\n"; 68 | $reference = lc($reference) if $lowercase; 69 | my @WORD = split(' ',$reference); 70 | my $length = scalar(@WORD); 71 | my $diff = abs($length_translation_this_sentence-$length); 72 | if ($diff < $closest_diff) { 73 | $closest_diff = $diff; 74 | $closest_length = $length; 75 | # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n"; 76 | } elsif ($diff == $closest_diff) { 77 | $closest_length = $length if $length < $closest_length; 78 | # from two references with the same closeness to me 79 | # take the *shorter* into account, not the "first" one. 80 | } 81 | for(my $n=1;$n<=4;$n++) { 82 | my %REF_NGRAM_N = (); 83 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 84 | my $ngram = "$n"; 85 | for(my $w=0;$w<$n;$w++) { 86 | $ngram .= " ".$WORD[$start+$w]; 87 | } 88 | $REF_NGRAM_N{$ngram}++; 89 | } 90 | foreach my $ngram (keys %REF_NGRAM_N) { 91 | if (!defined($REF_NGRAM{$ngram}) || 92 | $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) { 93 | $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram}; 94 | # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}
\n"; 95 | } 96 | } 97 | } 98 | } 99 | $length_translation += $length_translation_this_sentence; 100 | $length_reference += $closest_length; 101 | for(my $n=1;$n<=4;$n++) { 102 | my %T_NGRAM = (); 103 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 104 | my $ngram = "$n"; 105 | for(my $w=0;$w<$n;$w++) { 106 | $ngram .= " ".$WORD[$start+$w]; 107 | } 108 | $T_NGRAM{$ngram}++; 109 | } 110 | foreach my $ngram (keys %T_NGRAM) { 111 | $ngram =~ /^(\d+) /; 112 | my $n = $1; 113 | # my $corr = 0; 114 | # print "$i e $ngram $T_NGRAM{$ngram}
\n"; 115 | $TOTAL[$n] += $T_NGRAM{$ngram}; 116 | if (defined($REF_NGRAM{$ngram})) { 117 | if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) { 118 | $CORRECT[$n] += $T_NGRAM{$ngram}; 119 | # $corr = $T_NGRAM{$ngram}; 120 | # print "$i e correct1 $T_NGRAM{$ngram}
\n"; 121 | } 122 | else { 123 | $CORRECT[$n] += $REF_NGRAM{$ngram}; 124 | # $corr = $REF_NGRAM{$ngram}; 125 | # print "$i e correct2 $REF_NGRAM{$ngram}
\n"; 126 | } 127 | } 128 | # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram}; 129 | # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n" 130 | } 131 | } 132 | $s++; 133 | } 134 | my $brevity_penalty = 1; 135 | my $bleu = 0; 136 | 137 | my @bleu=(); 138 | 139 | for(my $n=1;$n<=4;$n++) { 140 | if (defined ($TOTAL[$n])){ 141 | $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0; 142 | # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n"; 143 | }else{ 144 | $bleu[$n]=0; 145 | } 146 | } 147 | 148 | if ($length_reference==0){ 149 | printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n"; 150 | exit(1); 151 | } 152 | 153 | if ($length_translation<$length_reference) { 154 | $brevity_penalty = exp(1-$length_reference/$length_translation); 155 | } 156 | $bleu = $brevity_penalty * exp((my_log( $bleu[1] ) + 157 | my_log( $bleu[2] ) + 158 | my_log( $bleu[3] ) + 159 | my_log( $bleu[4] ) ) / 4) ; 160 | printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n", 161 | 100*$bleu, 162 | 100*$bleu[1], 163 | 100*$bleu[2], 164 | 100*$bleu[3], 165 | 100*$bleu[4], 166 | $brevity_penalty, 167 | $length_translation / $length_reference, 168 | $length_translation, 169 | $length_reference; 170 | 171 | 172 | print STDERR "It is not advisable to publish scores from multi-bleu.perl. The scores depend on your tokenizer, which is unlikely to be reproducible from your paper or consistent across research groups. Instead you should detokenize then use mteval-v14.pl, which has a standard tokenization. Scores from multi-bleu.perl can still be used for internal purposes when you have a consistent tokenizer.\n"; 173 | 174 | sub my_log { 175 | return -9999999999 unless $_[0]; 176 | return log($_[0]); 177 | } 178 | -------------------------------------------------------------------------------- /scripts/shuffle_corpus.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import argparse 8 | import numpy 9 | 10 | """Copyright 2018 The THUMT Authors""" 11 | 12 | 13 | def parseargs(): 14 | parser = argparse.ArgumentParser(description="Shuffle corpus") 15 | 16 | parser.add_argument("--corpus", nargs="+", required=True, 17 | help="input corpora") 18 | parser.add_argument("--suffix", type=str, default="shuf", 19 | help="Suffix of output files") 20 | parser.add_argument("--seed", type=int, help="Random seed") 21 | 22 | return parser.parse_args() 23 | 24 | 25 | def main(args): 26 | name = args.corpus 27 | suffix = "." 
+ args.suffix 28 | stream = [open(item, "r") for item in name] 29 | data = [fd.readlines() for fd in stream] 30 | minlen = min([len(lines) for lines in data]) 31 | 32 | if args.seed: 33 | numpy.random.seed(args.seed) 34 | 35 | indices = numpy.arange(minlen) 36 | numpy.random.shuffle(indices) 37 | 38 | newstream = [open(item + suffix, "w") for item in name] 39 | 40 | for idx in indices.tolist(): 41 | lines = [item[idx] for item in data] 42 | 43 | for line, fd in zip(lines, newstream): 44 | fd.write(line) 45 | 46 | for fdr, fdw in zip(stream, newstream): 47 | fdr.close() 48 | fdw.close() 49 | 50 | 51 | if __name__ == "__main__": 52 | parsed_args = parseargs() 53 | main(parsed_args) 54 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | -------------------------------------------------------------------------------- /utils/cycle.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import tensorflow as tf 8 | from utils import dtype 9 | 10 | 11 | def _zero_variables(variables, name=None): 12 | ops = [] 13 | 14 | for var in variables: 15 | with tf.device(var.device): 16 | op = var.assign(tf.zeros_like(var)) 17 | ops.append(op) 18 | 19 | return tf.group(*ops, name=name or "zero_variables") 20 | 21 | 22 | def _replicate_variables(variables, device=None, suffix="Replica"): 23 | new_vars = [] 24 | 25 | for var in variables: 26 | device = device or var.device 27 | with tf.device(device): 28 | name = var.op.name + "/{}".format(suffix) 29 | new_vars.append(tf.Variable(tf.zeros_like(var), 30 | name=name, trainable=False)) 31 | 32 | return new_vars 33 | 34 | 35 | def _collect_gradients(gradients, variables): 36 | ops = [] 37 | 38 | for grad, var in zip(gradients, variables): 39 | if isinstance(grad, tf.Tensor): 40 | ops.append(tf.assign_add(var, grad)) 41 | else: 42 | ops.append(tf.scatter_add(var, grad.indices, grad.values)) 43 | 44 | return tf.group(*ops, name="collect_gradients") 45 | 46 | 47 | def create_train_op(named_scalars, grads_and_vars, optimizer, global_step, params): 48 | tf.get_variable_scope().set_dtype(tf.as_dtype(dtype.floatx())) 49 | 50 | gradients = [item[0] for item in grads_and_vars] 51 | variables = [item[1] for item in grads_and_vars] 52 | 53 | if params.update_cycle == 1: 54 | zero_variables_op = tf.no_op("zero_variables") 55 | collect_op = tf.no_op("collect_op") 56 | else: 57 | named_vars = {} 58 | for name in named_scalars: 59 | named_var = tf.Variable(tf.zeros([], dtype=tf.float32), 60 | name="{}/CTrainOpReplica".format(name), 61 | trainable=False) 62 | named_vars[name] = named_var 63 | count_var = tf.Variable(tf.zeros([], dtype=tf.as_dtype(dtype.floatx())), 64 | name="count/CTrainOpReplica", 65 | trainable=False) 66 | slot_variables = _replicate_variables(variables, suffix='CTrainOpReplica') 67 | zero_variables_op = _zero_variables( 68 | slot_variables + [count_var] + list(named_vars.values())) 69 | 70 | collect_ops = [] 71 | # collect gradients 72 | collect_grads_op = _collect_gradients(gradients, slot_variables) 73 | collect_ops.append(collect_grads_op) 74 | 75 | # collect other scalars 76 | for name in named_scalars: 77 | scalar = named_scalars[name] 78 | named_var = named_vars[name] 79 | collect_op = tf.assign_add(named_var, scalar) 80 | 
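            # Editor's note (added; behaviour inferred from this function): with
            # params.update_cycle > 1, each micro-batch adds its gradients and these
            # named scalars into replica variables, and the later
            # 1.0 / (count + 1.0) rescaling averages them before apply_gradients
            # runs, i.e. gradient accumulation across update_cycle micro-batches.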
collect_ops.append(collect_op) 81 | # collect counting variable 82 | collect_count_op = tf.assign_add(count_var, 1.0) 83 | collect_ops.append(collect_count_op) 84 | 85 | collect_op = tf.group(*collect_ops, name="collect_op") 86 | scale = 1.0 / (tf.cast(count_var, tf.float32) + 1.0) 87 | gradients = [scale * (g + s) 88 | for (g, s) in zip(gradients, slot_variables)] 89 | 90 | for name in named_scalars: 91 | named_scalars[name] = scale * ( 92 | named_scalars[name] + named_vars[name]) 93 | 94 | grand_norm = tf.global_norm(gradients) 95 | param_norm = tf.global_norm(variables) 96 | 97 | # Gradient clipping 98 | if isinstance(params.clip_grad_norm or None, float): 99 | gradients, _ = tf.clip_by_global_norm(gradients, 100 | params.clip_grad_norm, 101 | use_norm=grand_norm) 102 | 103 | # Update variables 104 | grads_and_vars = list(zip(gradients, variables)) 105 | train_op = optimizer.apply_gradients(grads_and_vars, global_step) 106 | 107 | ops = { 108 | "zero_op": zero_variables_op, 109 | "collect_op": collect_op, 110 | "train_op": train_op 111 | } 112 | 113 | # apply ema 114 | if params.ema_decay > 0.: 115 | tf.logging.info('Using Exp Moving Average to train the model with decay {}.'.format(params.ema_decay)) 116 | ema = tf.train.ExponentialMovingAverage(decay=params.ema_decay, num_updates=global_step) 117 | ema_op = ema.apply(variables) 118 | with tf.control_dependencies([ops['train_op']]): 119 | ops['train_op'] = tf.group(ema_op) 120 | bck_vars = _replicate_variables(variables, suffix="CTrainOpBackUpReplica") 121 | 122 | ops['ema_backup_op'] = tf.group(*(tf.assign(bck, var.read_value()) 123 | for bck, var in zip(bck_vars, variables))) 124 | ops['ema_restore_op'] = tf.group(*(tf.assign(var, bck.read_value()) 125 | for bck, var in zip(bck_vars, variables))) 126 | ops['ema_assign_op'] = tf.group(*(tf.assign(var, ema.average(var).read_value()) 127 | for var in variables)) 128 | 129 | ret = named_scalars 130 | ret.update({ 131 | "gradient_norm": grand_norm, 132 | "parameter_norm": param_norm, 133 | }) 134 | 135 | return ret, ops 136 | -------------------------------------------------------------------------------- /utils/dtype.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | 10 | # Copied from Keras 11 | 12 | # the type of float to use throughout the session. 
13 | _FLOATX = 'float32' 14 | _EPSILON = 1e-8 15 | _INF = 1e8 16 | 17 | 18 | def epsilon(): 19 | return _EPSILON 20 | 21 | 22 | def set_epsilon(e): 23 | global _EPSILON 24 | _EPSILON = e 25 | 26 | 27 | def inf(): 28 | return _INF 29 | 30 | 31 | def set_inf(e): 32 | global _INF 33 | _INF = e 34 | 35 | 36 | def floatx(): 37 | return _FLOATX 38 | 39 | 40 | def set_floatx(floatx): 41 | global _FLOATX 42 | if floatx not in {'float16', 'float32', 'float64'}: 43 | raise ValueError('Unknown floatx type: ' + str(floatx)) 44 | _FLOATX = str(floatx) 45 | 46 | 47 | def np_to_float(x): 48 | return np.asarray(x, dtype=_FLOATX) 49 | 50 | 51 | def tf_to_float(x): 52 | return tf.cast(x, tf.as_dtype(floatx())) 53 | 54 | 55 | def float32_variable_storage_getter(getter, name, shape=None, dtype=None, 56 | initializer=None, regularizer=None, 57 | trainable=True, 58 | *args, **kwargs): 59 | """Custom variable getter that forces trainable variables to be stored in 60 | float32 precision and then casts them to the training precision. 61 | """ 62 | storage_dtype = tf.float32 if trainable else dtype 63 | variable = getter(name, shape, dtype=storage_dtype, 64 | initializer=initializer, regularizer=regularizer, 65 | trainable=trainable, 66 | *args, **kwargs) 67 | if trainable and dtype != tf.float32: 68 | variable = tf.cast(variable, dtype) 69 | return variable 70 | -------------------------------------------------------------------------------- /utils/parallel.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import six 8 | import tensorflow as tf 9 | import tensorflow.contrib as tc 10 | 11 | from tensorflow.python.training import device_setter 12 | from tensorflow.python.framework import device as pydev 13 | from tensorflow.core.framework import node_def_pb2 14 | 15 | from utils import util, dtype 16 | 17 | 18 | def local_device_setter(num_devices=1, 19 | ps_device_type='cpu', 20 | worker_device='/cpu:0', 21 | ps_ops=None, 22 | ps_strategy=None): 23 | if ps_ops is None: 24 | ps_ops = ['Variable', 'VariableV2', 'VarHandleOp'] 25 | 26 | if ps_strategy is None: 27 | ps_strategy = device_setter._RoundRobinStrategy(num_devices) 28 | if not six.callable(ps_strategy): 29 | raise TypeError("ps_strategy must be callable") 30 | 31 | def _local_device_chooser(op): 32 | current_device = pydev.DeviceSpec.from_string(op.device or "") 33 | 34 | node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def 35 | if node_def.op in ps_ops: 36 | ps_device_spec = pydev.DeviceSpec.from_string( 37 | '/{}:{}'.format(ps_device_type, ps_strategy(op))) 38 | 39 | ps_device_spec.merge_from(current_device) 40 | return ps_device_spec.to_string() 41 | else: 42 | worker_device_spec = pydev.DeviceSpec.from_string(worker_device or "") 43 | worker_device_spec.merge_from(current_device) 44 | return worker_device_spec.to_string() 45 | 46 | return _local_device_chooser 47 | 48 | 49 | def _maybe_repeat(x, n): 50 | if isinstance(x, list): 51 | assert len(x) == n 52 | return x 53 | else: 54 | return [x] * n 55 | 56 | 57 | def _reshape_output(outputs): 58 | # assumption: or outputs[0] are all tensor lists/tuples, 59 | # or outputs[0] are dictionaries 60 | if isinstance(outputs[0], (tuple, list)): 61 | outputs = list(zip(*outputs)) 62 | outputs = tuple([list(o) for o in outputs]) 63 | else: 64 | if not isinstance(outputs[0], dict): 65 | return outputs 66 | 67 | 
assert isinstance(outputs[0], dict), \ 68 | 'invalid data type %s' % type(outputs[0]) 69 | 70 | combine_outputs = {} 71 | for key in outputs[0]: 72 | combine_outputs[key] = [o[key] for o in outputs] 73 | outputs = combine_outputs 74 | 75 | return outputs 76 | 77 | 78 | # Data-level parallelism 79 | def data_parallelism(device_type, num_devices, fn, *args, **kwargs): 80 | # Replicate args and kwargs 81 | if args: 82 | new_args = [_maybe_repeat(arg, num_devices) for arg in args] 83 | # Transpose 84 | new_args = [list(x) for x in zip(*new_args)] 85 | else: 86 | new_args = [[] for _ in range(num_devices)] 87 | 88 | new_kwargs = [{} for _ in range(num_devices)] 89 | 90 | for k, v in kwargs.items(): 91 | vals = _maybe_repeat(v, num_devices) 92 | 93 | for i in range(num_devices): 94 | new_kwargs[i][k] = vals[i] 95 | 96 | fns = _maybe_repeat(fn, num_devices) 97 | 98 | # Now make the parallel call. 99 | outputs = [] 100 | for i in range(num_devices): 101 | worker = "/{}:{}".format(device_type, i) 102 | if device_type == 'cpu': 103 | _device_setter = local_device_setter(worker_device=worker) 104 | else: 105 | _device_setter = local_device_setter( 106 | ps_device_type='gpu', 107 | worker_device=worker, 108 | ps_strategy=tc.training.GreedyLoadBalancingStrategy( 109 | num_devices, tc.training.byte_size_load_fn) 110 | ) 111 | 112 | with tf.variable_scope(tf.get_variable_scope(), reuse=bool(i != 0), 113 | dtype=tf.as_dtype(dtype.floatx())): 114 | with tf.name_scope("tower_%d" % i): 115 | with tf.device(_device_setter): 116 | outputs.append(fns[i](*new_args[i], **new_kwargs[i])) 117 | 118 | return _reshape_output(outputs) 119 | 120 | 121 | def parallel_model(model_fn, features, devices, use_cpu=False): 122 | device_type = 'gpu' 123 | num_devices = len(devices) 124 | 125 | if use_cpu: 126 | device_type = 'cpu' 127 | num_devices = 1 128 | 129 | outputs = data_parallelism(device_type, num_devices, model_fn, features) 130 | 131 | return outputs 132 | 133 | 134 | def average_gradients(tower_grads, mask=None): 135 | """Modified from Bilm""" 136 | 137 | # optimizer for single device 138 | if len(tower_grads) == 1: 139 | return tower_grads[0] 140 | 141 | # calculate average gradient for each shared variable across all GPUs 142 | def _deduplicate_indexed_slices(values, indices): 143 | """Sums `values` associated with any non-unique `indices`.""" 144 | unique_indices, new_index_positions = tf.unique(indices) 145 | summed_values = tf.unsorted_segment_sum( 146 | values, new_index_positions, 147 | tf.shape(unique_indices)[0]) 148 | return summed_values, unique_indices 149 | 150 | average_grads = [] 151 | for grad_and_vars in zip(*tower_grads): 152 | # Note that each grad_and_vars looks like the following: 153 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) 154 | # We need to average the gradients across each GPU. 155 | 156 | g0, v0 = grad_and_vars[0] 157 | 158 | if g0 is None: 159 | # no gradient for this variable, skip it 160 | tf.logging.warn("{} has no gradient".format(v0.name)) 161 | average_grads.append((g0, v0)) 162 | continue 163 | 164 | if isinstance(g0, tf.IndexedSlices): 165 | # If the gradient is type IndexedSlices then this is a sparse 166 | # gradient with attributes indices and values. 167 | # To average, need to concat them individually then create 168 | # a new IndexedSlices object. 
169 | indices = [] 170 | values = [] 171 | for g, v in grad_and_vars: 172 | indices.append(g.indices) 173 | values.append(g.values) 174 | all_indices = tf.concat(indices, 0) 175 | if mask is None: 176 | avg_values = tf.concat(values, 0) / len(grad_and_vars) 177 | else: 178 | avg_values = tf.concat(values, 0) / tf.reduce_sum(mask) 179 | # deduplicate across indices 180 | av, ai = _deduplicate_indexed_slices(avg_values, all_indices) 181 | grad = tf.IndexedSlices(av, ai, dense_shape=g0.dense_shape) 182 | else: 183 | # a normal tensor can just do a simple average 184 | grads = [] 185 | for g, v in grad_and_vars: 186 | # Add 0 dimension to the gradients to represent the tower. 187 | expanded_g = tf.expand_dims(g, 0) 188 | # Append on a 'tower' dimension which we will average over 189 | grads.append(expanded_g) 190 | 191 | # Average over the 'tower' dimension. 192 | grad = tf.concat(grads, 0) 193 | if mask is not None: 194 | grad = tf.boolean_mask( 195 | grad, tf.cast(mask, tf.bool), axis=0) 196 | grad = tf.reduce_mean(grad, 0) 197 | 198 | # the Variables are redundant because they are shared 199 | # across towers. So.. just return the first tower's pointer to 200 | # the Variable. 201 | v = grad_and_vars[0][1] 202 | grad_and_var = (grad, v) 203 | 204 | average_grads.append(grad_and_var) 205 | 206 | assert len(average_grads) == len(list(zip(*tower_grads))) 207 | 208 | return average_grads 209 | -------------------------------------------------------------------------------- /utils/queuer.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | The Queuer module mainly deals with reading and preparing datasets in a multi-processing manner. 5 | We did not use the built-in tensorflow Dataset function because it lacks flexibility. 6 | The functions defined below are mainly inspired by https://github.com/ixlan/machine-learning-data-pipeline.
7 | """ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import division 11 | from __future__ import print_function 12 | 13 | from multiprocessing import Process, Queue 14 | 15 | TERMINATION_TOKEN = "" 16 | 17 | 18 | def create_iter_from_queue(queue, term_token): 19 | 20 | while True: 21 | input_data_chunk = queue.get() 22 | if input_data_chunk == term_token: 23 | # put it back to the queue to let other processes that feed 24 | # from the same one to know that they should also break 25 | queue.put(term_token) 26 | break 27 | else: 28 | yield input_data_chunk 29 | 30 | 31 | def combine_reader_to_processor(reader, preprocessor): 32 | for data_chunk in reader: 33 | yield preprocessor(data_chunk) 34 | 35 | 36 | class EnQueuer(object): 37 | def __init__(self, 38 | reader, 39 | preprocessor, 40 | worker_processes_num=1, 41 | input_queue_size=5, 42 | output_queue_size=5 43 | ): 44 | if worker_processes_num < 0: 45 | raise ValueError("worker_processes_num must be a " 46 | "non-negative integer.") 47 | 48 | self.worker_processes_number = worker_processes_num 49 | self.preprocessor = preprocessor 50 | self.input_queue_size = input_queue_size 51 | self.output_queue_size = output_queue_size 52 | self.reader = reader 53 | 54 | # make the queue iterable 55 | def __iter__(self): 56 | return self._create_processed_data_chunks_gen(self.reader) 57 | 58 | def _create_processed_data_chunks_gen(self, reader_gen): 59 | if self.worker_processes_number == 0: 60 | itr = self._create_single_process_gen(reader_gen) 61 | else: 62 | itr = self._create_multi_process_gen(reader_gen) 63 | return itr 64 | 65 | def _create_single_process_gen(self, data_producer): 66 | return combine_reader_to_processor(data_producer, self.preprocessor) 67 | 68 | def _create_multi_process_gen(self, reader_gen): 69 | term_tokens_received = 0 70 | output_queue = Queue(self.output_queue_size) 71 | workers = [] 72 | 73 | if self.worker_processes_number > 1: 74 | term_tokens_expected = self.worker_processes_number - 1 75 | input_queue = Queue(self.input_queue_size) 76 | reader_worker = _ParallelWorker(reader_gen, input_queue) 77 | workers.append(reader_worker) 78 | 79 | # adding workers that will process the data 80 | for _ in range(self.worker_processes_number - 1): 81 | # since data-chunks will appear in the queue, making an iterable 82 | # object over it 83 | queue_iter = create_iter_from_queue(input_queue, 84 | TERMINATION_TOKEN) 85 | 86 | data_itr = combine_reader_to_processor(queue_iter, self.preprocessor) 87 | proc_worker = _ParallelWorker(data_chunk_iter=data_itr, 88 | queue=output_queue) 89 | workers.append(proc_worker) 90 | else: 91 | term_tokens_expected = 1 92 | 93 | data_itr = combine_reader_to_processor(reader_gen, self.preprocessor) 94 | proc_worker = _ParallelWorker(data_chunk_iter=data_itr, 95 | queue=output_queue) 96 | workers.append(proc_worker) 97 | 98 | for pr in workers: 99 | pr.daemon = True 100 | pr.start() 101 | 102 | while True: 103 | data_chunk = output_queue.get() 104 | if data_chunk == TERMINATION_TOKEN: 105 | term_tokens_received += 1 106 | # need to received all tokens in order to be sure that 107 | # all data has been processed 108 | if term_tokens_received == term_tokens_expected: 109 | for pr in workers: 110 | pr.join() 111 | break 112 | continue 113 | yield data_chunk 114 | 115 | 116 | class _ParallelWorker(Process): 117 | """Worker to execute data reading or processing on a separate process.""" 118 | 119 | def __init__(self, data_chunk_iter, queue): 120 | super(_ParallelWorker, 
self).__init__() 121 | self._data_chunk_iterable = data_chunk_iter 122 | self._queue = queue 123 | 124 | def run(self): 125 | for data_chunk in self._data_chunk_iterable: 126 | self._queue.put(data_chunk) 127 | self._queue.put(TERMINATION_TOKEN) 128 | -------------------------------------------------------------------------------- /utils/recorder.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import json 8 | import tensorflow as tf 9 | 10 | 11 | class Recorder(object): 12 | """To save training processes, inspired by Nematus""" 13 | 14 | def load_from_json(self, file_name): 15 | tf.logging.info("Loading recorder file from {}".format(file_name)) 16 | record = json.load(open(file_name, 'rb')) 17 | record = dict((key.encode("UTF-8"), value) for (key, value) in record.items()) 18 | self.__dict__.update(record) 19 | 20 | def save_to_json(self, file_name): 21 | tf.logging.info("Saving recorder file into {}".format(file_name)) 22 | with open(file_name, 'wb') as writer: 23 | writer.write(json.dumps(self.__dict__, indent=2).encode("utf-8")) 24 | writer.close() 25 | -------------------------------------------------------------------------------- /utils/saver.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import os 8 | import tensorflow as tf 9 | 10 | 11 | class Saver(object): 12 | def __init__(self, 13 | checkpoints=5, # number of latest checkpoints to keep 14 | output_dir=None, # the output directory 15 | best_score=-1, # the best bleu score so far 16 | best_checkpoints=1, # number of best checkpoints kept in the best-checkpoint directory 17 | ): 18 | if output_dir is None: 19 | output_dir = "./output" 20 | self.output_dir = output_dir 21 | self.output_best_dir = os.path.join(output_dir, "best") 22 | 23 | self.saver = tf.train.Saver( 24 | max_to_keep=checkpoints 25 | ) 26 | # handle disrupted checkpoints 27 | if tf.gfile.Exists(self.output_dir): 28 | ckpt = tf.train.get_checkpoint_state(self.output_dir) 29 | if ckpt and ckpt.all_model_checkpoint_paths: 30 | self.saver.recover_last_checkpoints(list(ckpt.all_model_checkpoint_paths)) 31 | 32 | self.best_saver = tf.train.Saver( 33 | max_to_keep=best_checkpoints, 34 | ) 35 | # handle disrupted checkpoints 36 | if tf.gfile.Exists(self.output_best_dir): 37 | ckpt = tf.train.get_checkpoint_state(self.output_best_dir) 38 | if ckpt and ckpt.all_model_checkpoint_paths: 39 | self.best_saver.recover_last_checkpoints(list(ckpt.all_model_checkpoint_paths)) 40 | 41 | self.best_score = best_score 42 | # check the best bleu result so far 43 | metric_dir = os.path.join(self.output_best_dir, "metric.log") 44 | if tf.gfile.Exists(metric_dir): 45 | metric_lines = open(metric_dir).readlines() 46 | if len(metric_lines) > 0: 47 | best_score_line = metric_lines[-1] 48 | self.best_score = float(best_score_line.strip().split()[-1]) 49 | 50 | # check the top_k_best list and results 51 | self.topk_scores = [] 52 | topk_dir = os.path.join(self.output_best_dir, "topk_checkpoint") 53 | ckpt_dir = os.path.join(self.output_best_dir, "checkpoint") 54 | # directly load the topk information from topk_checkpoint 55 | if tf.gfile.Exists(topk_dir): 56 | with tf.gfile.Open(topk_dir) as reader: 57 | for line in reader: 58 | model_name,
score = line.strip().split("\t") 59 | self.topk_scores.append((model_name, float(score))) 60 | # backup plan to normal checkpoints and best scores 61 | elif tf.gfile.Exists(ckpt_dir): 62 | latest_checkpoint = tf.gfile.Open(ckpt_dir).readline() 63 | model_name = latest_checkpoint.strip().split(":")[1].strip() 64 | model_name = model_name[1:-1] # remove "" 65 | self.topk_scores.append((model_name, self.best_score)) 66 | self.best_checkpoints = best_checkpoints 67 | 68 | self.score_record = tf.gfile.Open(metric_dir, mode="a+") 69 | 70 | def save(self, session, step, metric_score=None): 71 | if not tf.gfile.Exists(self.output_dir): 72 | tf.gfile.MkDir(self.output_dir) 73 | if not tf.gfile.Exists(self.output_best_dir): 74 | tf.gfile.MkDir(self.output_best_dir) 75 | 76 | self.saver.save(session, os.path.join(self.output_dir, "model"), global_step=step) 77 | 78 | def _move(path, new_path): 79 | if tf.gfile.Exists(path): 80 | if tf.gfile.Exists(new_path): 81 | tf.gfile.Remove(new_path) 82 | tf.gfile.Copy(path, new_path) 83 | 84 | if metric_score is not None and metric_score > self.best_score: 85 | self.best_score = metric_score 86 | 87 | _move(os.path.join(self.output_dir, "param.json"), 88 | os.path.join(self.output_best_dir, "param.json")) 89 | _move(os.path.join(self.output_dir, "record.json"), 90 | os.path.join(self.output_best_dir, "record.json")) 91 | 92 | # this recorder only record best scores 93 | self.score_record.write("Steps {}, Metric Score {}\n".format(step, metric_score)) 94 | self.score_record.flush() 95 | 96 | # either no model is saved, or current metric score is better than the minimum one 97 | if metric_score is not None and \ 98 | (len(self.topk_scores) == 0 or len(self.topk_scores) < self.best_checkpoints or 99 | metric_score > min([v[1] for v in self.topk_scores])): 100 | # manipulate the 'checkpoints', and change the orders 101 | ckpt_dir = os.path.join(self.output_best_dir, "checkpoint") 102 | if len(self.topk_scores) > 0: 103 | sorted_topk_scores = sorted(self.topk_scores, key=lambda x: x[1]) 104 | with tf.gfile.Open(ckpt_dir, mode='w') as writer: 105 | best_ckpt = sorted_topk_scores[-1] 106 | writer.write("model_checkpoint_path: \"{}\"\n".format(best_ckpt[0])) 107 | for model_name, _ in sorted_topk_scores: 108 | writer.write("all_model_checkpoint_paths: \"{}\"\n".format(model_name)) 109 | writer.flush() 110 | 111 | # update best_saver internal checkpoints status 112 | ckpt = tf.train.get_checkpoint_state(self.output_best_dir) 113 | if ckpt and ckpt.all_model_checkpoint_paths: 114 | self.best_saver.recover_last_checkpoints(list(ckpt.all_model_checkpoint_paths)) 115 | 116 | # this change mainly inspired by that sometimes for dataset, 117 | # the best performance is achieved by averaging top-k checkpoints 118 | self.best_saver.save( 119 | session, os.path.join(self.output_best_dir, "model"), global_step=step) 120 | 121 | # handle topk scores 122 | self.topk_scores.append(("model-{}".format(int(step)), float(metric_score))) 123 | sorted_topk_scores = sorted(self.topk_scores, key=lambda x: x[1]) 124 | self.topk_scores = sorted_topk_scores[-self.best_checkpoints:] 125 | topk_dir = os.path.join(self.output_best_dir, "topk_checkpoint") 126 | with tf.gfile.Open(topk_dir, mode='w') as writer: 127 | for model_name, score in self.topk_scores: 128 | writer.write("{}\t{}\n".format(model_name, score)) 129 | writer.flush() 130 | 131 | def restore(self, session, path=None): 132 | if path is not None and tf.gfile.Exists(path): 133 | check_dir = path 134 | else: 135 | check_dir = 
self.output_dir 136 | 137 | checkpoint = os.path.join(check_dir, "checkpoint") 138 | if not tf.gfile.Exists(checkpoint): 139 | tf.logging.warn("No Existing Model detected") 140 | else: 141 | latest_checkpoint = tf.gfile.Open(checkpoint).readline() 142 | model_name = latest_checkpoint.strip().split(":")[1].strip() 143 | model_name = model_name[1:-1] # remove surrounding quotes 144 | model_path = os.path.join(check_dir, model_name) 145 | model_path = os.path.abspath(model_path) 146 | if not tf.gfile.Exists(model_path+".meta"): 147 | tf.logging.error("model '{}' does not exist" 148 | .format(model_path)) 149 | else: 150 | try: 151 | self.saver.restore(session, model_path) 152 | except tf.errors.NotFoundError: 153 | # In this case, we simply assume that the cycle part 154 | # is mismatched, where the replicas are missing. 155 | # This would happen if you switch from non-cycle mode 156 | # to cycle mode. 157 | tf.logging.warn("Starting Backup Restore") 158 | ops = [] 159 | reader = tf.train.load_checkpoint(model_path) 160 | for var in tf.global_variables(): 161 | name = var.op.name 162 | 163 | if reader.has_tensor(name): 164 | tf.logging.info('{} gets initialization from {}' 165 | .format(name, name)) 166 | ops.append( 167 | tf.assign(var, reader.get_tensor(name))) 168 | else: 169 | tf.logging.warn("{} is missing".format(name)) 170 | restore_op = tf.group(*ops, name="restore_global_vars") 171 | session.run(restore_op) 172 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import os 8 | import time 9 | import pkgutil 10 | import collections 11 | import numpy as np 12 | import tensorflow as tf 13 | 14 | from utils import dtype 15 | 16 | 17 | def batch_indexer(datasize, batch_size): 18 | """Divide the data indices into batches of size batch_size""" 19 | dataindex = np.arange(datasize).tolist() 20 | 21 | batchindex = [] 22 | for i in range(datasize // batch_size): 23 | batchindex.append(dataindex[i * batch_size: (i + 1) * batch_size]) 24 | if datasize % batch_size > 0: 25 | batchindex.append(dataindex[-(datasize % batch_size):]) 26 | 27 | return batchindex 28 | 29 | 30 | def token_indexer(dataset, token_size): 31 | """Divide the dataset into token-based batches""" 32 | # assume dataset format: [(len1, len2, ..., lenN)] 33 | dataindex = np.arange(len(dataset)).tolist() 34 | 35 | batchindex = [] 36 | 37 | _batcher = [0.] * len(dataset[0]) 38 | _counter = 0 39 | i = 0 40 | while True: 41 | if i >= len(dataset): break 42 | 43 | # attempt to put this datapoint into the batch 44 | _batcher = [max(max_l, l) 45 | for max_l, l in zip(_batcher, dataset[i])] 46 | _counter += 1 47 | for l in _batcher: 48 | if _counter * l >= token_size: 49 | # when an extreme instance occurs, handle it by making a 1-size batch 50 | if _counter > 1: 51 | batchindex.append(dataindex[i-_counter+1: i]) 52 | i -= 1 53 | else: 54 | batchindex.append(dataindex[i: i+1]) 55 | 56 | _counter = 0 57 | _batcher = [0.] * len(dataset[0]) 58 | break 59 | 60 | i += 1 61 | 62 | _counter = sum([len(slice) for slice in batchindex]) 63 | if _counter != len(dataset): 64 | batchindex.append(dataindex[_counter:]) 65 | return batchindex 66 | 67 | 68 | def mask_scale(value, mask, scale=None): 69 | """Add a large negative value to masked positions, prepared for masked softmax""" 70 | if scale is None: 71 | scale = dtype.inf() 72 | return value + (1.
- mask) * (-scale) 73 | 74 | 75 | def valid_apply_dropout(x, dropout): 76 | """To check whether the dropout value is valid, apply if valid""" 77 | if dropout is not None and 0. <= dropout <= 1.: 78 | return tf.nn.dropout(x, 1. - dropout) 79 | return x 80 | 81 | 82 | def layer_dropout(dropped, no_dropped, dropout_rate): 83 | """Layer Dropout""" 84 | pred = tf.random_uniform([]) < dropout_rate 85 | return tf.cond(pred, lambda: dropped, lambda: no_dropped) 86 | 87 | 88 | def label_smooth(labels, vocab_size, factor=0.1): 89 | """Smooth the gold label distribution""" 90 | if 0. < factor < 1.: 91 | n = tf.cast(vocab_size - 1, tf.float32) 92 | p = 1. - factor 93 | q = factor / n 94 | 95 | t = tf.one_hot(tf.cast(tf.reshape(labels, [-1]), tf.int32), 96 | depth=vocab_size, on_value=p, off_value=q) 97 | normalizing = -(p * tf.log(p) + n * q * tf.log(q + 1e-20)) 98 | else: 99 | t = tf.one_hot(tf.cast(tf.reshape(labels, [-1]), tf.int32), 100 | depth=vocab_size) 101 | normalizing = 0. 102 | 103 | return t, normalizing 104 | 105 | 106 | def closing_dropout(params): 107 | """Removing all dropouts""" 108 | for k, v in params.values().items(): 109 | if 'dropout' in k: 110 | setattr(params, k, 0.0) 111 | # consider closing label smoothing 112 | if 'label_smoothing' in k: 113 | setattr(params, k, 0.0) 114 | return params 115 | 116 | 117 | def dict_update(d, u): 118 | """Recursive update dictionary""" 119 | for k, v in u.items(): 120 | if isinstance(v, collections.Mapping): 121 | d[k] = dict_update(d.get(k, {}), v) 122 | else: 123 | d[k] = v 124 | return d 125 | 126 | 127 | def shape_list(x): 128 | # Copied from Tensor2Tensor 129 | """Return list of dims, statically where possible.""" 130 | x = tf.convert_to_tensor(x) 131 | 132 | # If unknown rank, return dynamic shape 133 | if x.get_shape().dims is None: 134 | return tf.shape(x) 135 | 136 | static = x.get_shape().as_list() 137 | shape = tf.shape(x) 138 | 139 | ret = [] 140 | for i in range(len(static)): 141 | dim = static[i] 142 | if dim is None: 143 | dim = shape[i] 144 | ret.append(dim) 145 | return ret 146 | 147 | 148 | def get_shape_invariants(tensor): 149 | # Copied from Tensor2Tensor 150 | """Returns the shape of the tensor but sets middle dims to None.""" 151 | shape = tensor.shape.as_list() 152 | for i in range(1, len(shape) - 1): 153 | shape[i] = None 154 | 155 | return tf.TensorShape(shape) 156 | 157 | 158 | def merge_neighbor_dims(x, axis=0): 159 | """Merge neighbor dimension of x, start by axis""" 160 | if len(x.get_shape().as_list()) < axis + 2: 161 | return x 162 | 163 | shape = shape_list(x) 164 | shape[axis] *= shape[axis+1] 165 | shape.pop(axis+1) 166 | return tf.reshape(x, shape) 167 | 168 | 169 | def unmerge_neighbor_dims(x, depth, axis=0): 170 | """Inverse of merge_neighbor_dims, axis by depth""" 171 | if len(x.get_shape().as_list()) < axis + 1: 172 | return x 173 | 174 | shape = shape_list(x) 175 | width = shape[axis] // depth 176 | new_shape = shape[:axis] + [depth, width] + shape[axis+1:] 177 | return tf.reshape(x, new_shape) 178 | 179 | 180 | def expand_tile_dims(x, depth, axis=1): 181 | """Expand and Tile x on axis by depth""" 182 | x = tf.expand_dims(x, axis=axis) 183 | tile_dims = [1] * x.shape.ndims 184 | tile_dims[axis] = depth 185 | 186 | return tf.tile(x, tile_dims) 187 | 188 | 189 | def gumbel_noise(shape, eps=None): 190 | """Generate gumbel noise shaped by shape""" 191 | if eps is None: 192 | eps = dtype.epsilon() 193 | 194 | u = tf.random_uniform(shape, minval=0, maxval=1) 195 | return -tf.log(-tf.log(u + eps) + eps) 196 | 
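# Usage sketch (added for illustration; not part of the original file): gumbel_noise
# supports the Gumbel-max trick, i.e. adding Gumbel noise to unnormalized logits and
# taking the argmax draws a sample from the corresponding categorical distribution, e.g.
#   sample = tf.argmax(logits + gumbel_noise(shape_list(logits)), axis=-1)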
197 | 198 | def log_prob_from_logits(logits): 199 | """Log probabilities from unnormalized logits""" 200 | return logits - tf.reduce_logsumexp(logits, axis=-1, keepdims=True) 201 | 202 | 203 | def batch_coordinates(batch_size, beam_size): 204 | """Batch coordinate indices under beam_size""" 205 | batch_pos = tf.range(batch_size * beam_size) // beam_size 206 | batch_pos = tf.reshape(batch_pos, [batch_size, beam_size]) 207 | 208 | return batch_pos 209 | 210 | 211 | def variable_printer(): 212 | """Print trainable parameters""" 213 | all_weights = {v.name: v for v in tf.trainable_variables()} 214 | total_size = 0 215 | 216 | for v_name in sorted(list(all_weights)): 217 | v = all_weights[v_name] 218 | tf.logging.info("%s\tshape %s", v.name[:-2].ljust(80), 219 | str(v.shape).ljust(20)) 220 | v_size = np.prod(np.array(v.shape.as_list())).tolist() 221 | total_size += v_size 222 | tf.logging.info("Total trainable variables size: %d", total_size) 223 | 224 | 225 | def uniform_splits(total_size, num_shards): 226 | """Split total_size into num_shards nearly-uniform shard sizes""" 227 | size_per_shards = total_size // num_shards 228 | splits = [size_per_shards] * (num_shards - 1) + \ 229 | [total_size - (num_shards - 1) * size_per_shards] 230 | 231 | return splits 232 | 233 | 234 | def fetch_valid_ref_files(path): 235 | """Extract valid reference files according to MT convention""" 236 | path = os.path.abspath(path) 237 | if tf.gfile.Exists(path): 238 | return [path] 239 | 240 | if not tf.gfile.Exists(path + ".ref0"): 241 | tf.logging.warn("Invalid Reference Format {}".format(path)) 242 | return None 243 | 244 | num = 0 245 | files = [] 246 | while True: 247 | file_path = path + ".ref%s" % num 248 | if tf.gfile.Exists(file_path): 249 | files.append(file_path) 250 | else: 251 | break 252 | num += 1 253 | return files 254 | 255 | 256 | def get_session(gpus): 257 | """Configure a session with the given GPUs""" 258 | 259 | sess_config = tf.ConfigProto(allow_soft_placement=True) 260 | sess_config.gpu_options.allow_growth = True 261 | if len(gpus) > 0: 262 | device_str = ",".join([str(i) for i in gpus]) 263 | sess_config.gpu_options.visible_device_list = device_str 264 | sess = tf.Session(config=sess_config) 265 | 266 | return sess 267 | 268 | 269 | def flatten_list(values): 270 | """Flatten a nested list""" 271 | return [v for value in values for v in value] 272 | 273 | 274 | def remove_invalid_seq(sequence, mask): 275 | """Pick valid sequence elements wrt mask""" 276 | # sequence: [batch, sequence] 277 | # mask: [batch, sequence] 278 | boolean_mask = tf.reduce_sum(mask, axis=0) 279 | 280 | # make sure that there is at least one element in the mask 281 | first_one = tf.one_hot(0, tf.shape(boolean_mask)[0], 282 | dtype=tf.as_dtype(dtype.floatx())) 283 | boolean_mask = tf.cast(boolean_mask + first_one, tf.bool) 284 | 285 | filtered_seq = tf.boolean_mask(sequence, boolean_mask, axis=1) 286 | filtered_mask = tf.boolean_mask(mask, boolean_mask, axis=1) 287 | return filtered_seq, filtered_mask 288 | 289 | 290 | def time_str(t=None): 291 | """String format of the given timestamp""" 292 | if t is None: 293 | t = time.time() 294 | ts = time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime(t)) 295 | return ts 296 | 297 | 298 | def dynamic_load_module(module, prefix=None): 299 | """Load submodules inside a module, mainly used for model loading, not robust!!!""" 300 | # loading all models under directory `models` dynamically 301 | if not isinstance(module, str): 302 | module = module.__path__ 303 | for importer, modname, ispkg in pkgutil.iter_modules(module): 304
| if prefix is None: 305 | __import__(modname) 306 | else: 307 | __import__("{}.{}".format(prefix, modname)) 308 | -------------------------------------------------------------------------------- /vocab.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import argparse 8 | 9 | 10 | class Vocab(object): 11 | def __init__(self, vocab_file=None): 12 | self.word2id = {} 13 | self.id2word = {} 14 | self.word2count = {} 15 | 16 | self.pad_sym = "<pad>" 17 | self.eos_sym = "<eos>" 18 | self.unk_sym = "<unk>" 19 | 20 | self.insert(self.pad_sym) 21 | self.insert(self.unk_sym) 22 | self.insert(self.eos_sym) 23 | 24 | if vocab_file is not None: 25 | self.load_vocab(vocab_file) 26 | 27 | def insert(self, token): 28 | if token not in self.word2id: 29 | index = len(self.word2id) 30 | self.word2id[token] = index 31 | self.id2word[index] = token 32 | 33 | self.word2count[token] = 0 34 | self.word2count[token] += 1 35 | 36 | def size(self): 37 | return len(self.word2id) 38 | 39 | def load_vocab(self, vocab_file): 40 | with open(vocab_file, 'r') as reader: 41 | for token in reader: 42 | self.insert(token.strip()) 43 | 44 | def get_token(self, id): 45 | if id in self.id2word: 46 | return self.id2word[id] 47 | return self.unk_sym 48 | 49 | def get_id(self, token): 50 | if token in self.word2id: 51 | return self.word2id[token] 52 | return self.word2id[self.unk_sym] 53 | 54 | def sort_vocab(self): 55 | sorted_word2count = sorted( 56 | self.word2count.items(), key=lambda x: - x[1]) 57 | self.word2id, self.id2word = {}, {} 58 | self.insert(self.pad_sym) 59 | self.insert(self.unk_sym) 60 | self.insert(self.eos_sym) 61 | for word, _ in sorted_word2count: 62 | self.insert(word) 63 | 64 | def save_vocab(self, vocab_file, size=1e6): 65 | with open(vocab_file, 'w') as writer: 66 | for id in range(min(self.size(), int(size))): 67 | writer.write(self.id2word[id] + "\n") 68 | 69 | def to_id(self, tokens, append_eos=True): 70 | if not append_eos: 71 | return [self.get_id(token) for token in tokens] 72 | else: 73 | return [self.get_id(token) for token in tokens + [self.eos_sym]] 74 | 75 | def to_tokens(self, ids): 76 | return [self.get_token(id) for id in ids] 77 | 78 | def eos(self): 79 | return self.get_id(self.eos_sym) 80 | 81 | def pad(self): 82 | return self.get_id(self.pad_sym) 83 | 84 | 85 | if __name__ == "__main__": 86 | parser = argparse.ArgumentParser('Vocabulary Preparation') 87 | parser.add_argument('--size', type=int, default=1e6, help='maximum vocabulary size') 88 | parser.add_argument('input', type=str, help='the input file path') 89 | parser.add_argument('output', type=str, help='the output file name') 90 | 91 | args = parser.parse_args() 92 | 93 | vocab = Vocab() 94 | with open(args.input, 'r') as reader: 95 | for line in reader: 96 | for token in line.strip().split(): 97 | vocab.insert(token) 98 | 99 | vocab.sort_vocab() 100 | vocab.save_vocab(args.output, args.size) 101 | 102 | print("Loaded {} unique tokens from {}".format(vocab.size(), args.input)) 103 | --------------------------------------------------------------------------------
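A minimal usage sketch for the Vocab class above. The corpus and vocabulary file names here are hypothetical; to_id appends the <eos> id by default, and unknown tokens fall back to <unk>:

    # build a vocabulary file from a tokenized corpus (hypothetical file names):
    #   python vocab.py --size 32000 corpus.en vocab.en
    from vocab import Vocab

    vocab = Vocab("vocab.en")                       # load the saved vocabulary file
    ids = vocab.to_id("a simple example".split())   # token ids plus a trailing <eos> id
    tokens = vocab.to_tokens(ids)                   # map ids back; unknown ids become <unk>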