├── .gitignore
├── LICENSE
├── README.md
├── data.py
├── docs
│   ├── afs_speech_translation
│   │   ├── README.md
│   │   ├── afs_training.png
│   │   └── example.png
│   ├── colactc
│   │   ├── README.md
│   │   ├── colactc.png
│   │   └── mt.png
│   ├── conditional_language_specific_routing
│   │   └── README.md
│   ├── context_aware_st
│   │   ├── README.md
│   │   ├── cast.png
│   │   └── training.png
│   ├── depth_scale_init_and_merged_attention
│   │   ├── README.md
│   │   ├── dsinit.png
│   │   └── grad.png
│   ├── interleaved_bidirectional_transformer
│   │   ├── README.md
│   │   └── overview.png
│   ├── iwslt2021_uoe_submission
│   │   ├── README.md
│   │   └── overview.png
│   ├── l0drop
│   │   ├── README.md
│   │   ├── l0drop-att.png
│   │   ├── l0drop.png
│   │   └── mt_ende.png
│   ├── multilingual_laln_lalt
│   │   ├── README.md
│   │   ├── many-to-many-full-results-per-language.md
│   │   └── many-to-many.xlsx
│   ├── rela_sparse_attention
│   │   ├── README.md
│   │   ├── aer.png
│   │   ├── null.png
│   │   └── rela.png
│   └── usage
│       └── README.md
├── evalu.py
├── func.py
├── lrs
│   ├── __init__.py
│   ├── cosinelr.py
│   ├── epochlr.py
│   ├── gnmtplr.py
│   ├── lr.py
│   ├── noamlr.py
│   ├── scorelr.py
│   └── vanillalr.py
├── main.py
├── models
│   ├── __init__.py
│   ├── deepnmt.py
│   ├── model.py
│   ├── rnnsearch.py
│   ├── rnnsearch_deepatt.py
│   ├── transformer.py
│   ├── transformer_aan.py
│   ├── transformer_fixup.py
│   ├── transformer_fuse.py
│   ├── transformer_l0drop.py
│   ├── transformer_rela.py
│   └── transformer_rpr.py
├── modules
│   ├── __init__.py
│   ├── fixup.py
│   ├── initializer.py
│   ├── l0norm.py
│   ├── rela.py
│   └── rpr.py
├── rnns
│   ├── __init__.py
│   ├── atr.py
│   ├── cell.py
│   ├── gru.py
│   ├── lrn.py
│   ├── lstm.py
│   ├── olrn.py
│   ├── rnn.py
│   └── sru.py
├── run.py
├── scripts
│   ├── bleu_over_length.py
│   ├── checkpoint_averaging.py
│   ├── chrF.py
│   ├── evaluate_pos_translation_rate.py
│   ├── multi-bleu-detok.perl
│   ├── multi-bleu.perl
│   └── shuffle_corpus.py
├── search.py
├── utils
│   ├── __init__.py
│   ├── cycle.py
│   ├── dtype.py
│   ├── metric.py
│   ├── parallel.py
│   ├── queuer.py
│   ├── recorder.py
│   ├── saver.py
│   └── util.py
└── vocab.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2018, Biao Zhang
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Zero
2 | A neural machine translation system implemented in Python 2 + TensorFlow.
3 |
4 | ## Features
5 | 1. Multi-Process Data Loading/Processing (*Problems Exist*)
6 | 2. Multi-GPU Training/Decoding
7 | 3. Gradient Aggregation
8 |
9 | ## Papers
10 |
11 | We associate each paper below with a readme file link. Please click the paper you are interested in for more details.
12 |
13 | * [Efficient CTC Regularization via Coarse Labels for End-to-End Speech Translation, EACL2023](docs/colactc)
14 | * [Revisiting End-to-End Speech-to-Text Translation From Scratch, ICML2022](https://github.com/bzhangGo/st_from_scratch)
15 | * [Sparse Attention with Linear Units, EMNLP2021](docs/rela_sparse_attention)
16 | * [Edinburgh's End-to-End Multilingual Speech Translation System for IWSLT 2021, IWSLT2021 System submission](docs/iwslt2021_uoe_submission)
17 | * [Beyond Sentence-Level End-to-End Speech Translation: Context Helps, ACL2021](docs/context_aware_st)
18 | * [On Sparsifying Encoder Outputs in Sequence-to-Sequence Models, ACL2021 Findings](docs/l0drop)
19 | * [Share or Not? Learning to Schedule Language-Specific Capacity for Multilingual Translation, ICLR2021](docs/conditional_language_specific_routing)
20 | * [Fast Interleaved Bidirectional Sequence Generation, WMT2020](docs/interleaved_bidirectional_transformer)
21 | * [Adaptive Feature Selection for End-to-End Speech Translation, EMNLP2020 Findings](docs/afs_speech_translation)
22 | * [Improving Massively Multilingual Neural Machine Translation and Zero-Shot Translation, ACL2020](docs/multilingual_laln_lalt)
23 | * [Improving Deep Transformer with Depth-Scaled Initialization and Merged Attention, EMNLP2019](docs/depth_scale_init_and_merged_attention)
24 |
25 | ## Supported Models
26 | * RNNSearch: supports LSTM, GRU, SRU, [ATR, EMNLP2018](https://github.com/bzhangGo/ATR), and [LRN, ACL2019](https://github.com/bzhangGo/lrn)
27 | models.
28 | * Deep attention: [Neural Machine Translation with Deep Attention, TPAMI](https://ieeexplore.ieee.org/document/8493282)
29 | * CAEncoder: the context-aware recurrent encoder, see [the paper, TASLP](https://ieeexplore.ieee.org/document/8031316)
30 | and the original [source code](https://github.com/DeepLearnXMU/CAEncoder-NMT) (in Theano).
31 | * Transformer: [attention is all you need](https://arxiv.org/abs/1706.03762)
32 | * AAN: the [average attention model, ACL2018](https://github.com/bzhangGo/transformer-aan) that accelerates the decoding!
33 | * Fixup: [Fixup Initialization: Residual Learning Without Normalization](https://arxiv.org/abs/1901.09321)
34 | * Relative position representation: [Self-Attention with Relative Position Representations](https://arxiv.org/abs/1803.02155)
35 |
36 | ## Requirements
37 | * python2.7
38 | * tensorflow <= 1.13.2
39 |
40 | ## Usage
41 | [How to use this toolkit for machine translation?](docs/usage)
42 |
43 | ## TODO:
44 | 1. organize the parameters and interpretations in config.
45 | 2. reformat and fill in code comments
46 | 3. simplify and remove unnecessary code
47 | 4. improve rnn models
48 |
49 | ## Citation
50 |
51 | If you use the source code, please consider citing the following paper:
52 | ```
53 | @InProceedings{D18-1459,
54 | author = "Zhang, Biao
55 | and Xiong, Deyi
56 | and su, jinsong
57 | and Lin, Qian
58 | and Zhang, Huiji",
59 | title = "Simplifying Neural Machine Translation with Addition-Subtraction Twin-Gated Recurrent Networks",
60 | booktitle = "Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing",
61 | year = "2018",
62 | publisher = "Association for Computational Linguistics",
63 | pages = "4273--4283",
64 | location = "Brussels, Belgium",
65 | url = "http://aclweb.org/anthology/D18-1459"
66 | }
67 | ```
68 |
69 | If you are interested in the CAEncoder model, please consider citing our TASLP paper:
70 | ```
71 | @article{Zhang:2017:CRE:3180104.3180106,
72 | author = {Zhang, Biao and Xiong, Deyi and Su, Jinsong and Duan, Hong},
73 | title = {A Context-Aware Recurrent Encoder for Neural Machine Translation},
74 | journal = {IEEE/ACM Trans. Audio, Speech and Lang. Proc.},
75 | issue_date = {December 2017},
76 | volume = {25},
77 | number = {12},
78 | month = dec,
79 | year = {2017},
80 | issn = {2329-9290},
81 | pages = {2424--2432},
82 | numpages = {9},
83 | url = {https://doi.org/10.1109/TASLP.2017.2751420},
84 | doi = {10.1109/TASLP.2017.2751420},
85 | acmid = {3180106},
86 | publisher = {IEEE Press},
87 | address = {Piscataway, NJ, USA},
88 | }
89 | ```
90 |
91 | ## Reference
92 | When developing this repository, I referred to the following projects:
93 |
94 | * [Nematus](https://github.com/EdinburghNLP/nematus)
95 | * [THUMT](https://github.com/thumt/THUMT)
96 | * [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor)
97 | * [Keras](https://github.com/keras-team/keras)
98 |
99 | ## Contact
100 | For any questions or suggestions, please feel free to contact [Biao Zhang](mailto:biaojiaxing@gmail.com)
101 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import numpy as np
8 | from utils.util import batch_indexer, token_indexer
9 |
10 |
11 | class Dataset(object):
12 | def __init__(self, src_file, tgt_file,
13 | src_vocab, tgt_vocab, max_len=100,
14 | batch_or_token='batch',
15 | data_leak_ratio=0.5):
16 | self.source = src_file
17 | self.target = tgt_file
18 | self.src_vocab = src_vocab
19 | self.tgt_vocab = tgt_vocab
20 | self.max_len = max_len
21 | self.batch_or_token = batch_or_token
22 | self.data_leak_ratio = data_leak_ratio
23 |
24 | self.leak_buffer = []
25 |
26 | def load_data(self):
27 | with open(self.source, 'r') as src_reader, \
28 | open(self.target, 'r') as tgt_reader:
29 | while True:
30 | src_line = src_reader.readline()
31 | tgt_line = tgt_reader.readline()
32 |
33 | if src_line == "" or tgt_line == "":
34 | break
35 |
36 | src_line = src_line.strip()
37 | tgt_line = tgt_line.strip()
38 |
39 | if src_line == "" or tgt_line == "":
40 | continue
41 |
42 | yield (
43 | self.src_vocab.to_id(src_line.strip().split()[:self.max_len]),
44 | self.tgt_vocab.to_id(tgt_line.strip().split()[:self.max_len])
45 | )
46 |
47 | def to_matrix(self, batch):
48 | batch_size = len(batch)
49 |
50 | src_lens = [len(sample[1]) for sample in batch]
51 | tgt_lens = [len(sample[2]) for sample in batch]
52 |
53 | src_len = min(self.max_len, max(src_lens))
54 | tgt_len = min(self.max_len, max(tgt_lens))
55 |
56 | s = np.zeros([batch_size, src_len], dtype=np.int32)
57 | t = np.zeros([batch_size, tgt_len], dtype=np.int32)
58 | x = []
59 | for eidx, sample in enumerate(batch):
60 | x.append(sample[0])
61 | src_ids, tgt_ids = sample[1], sample[2]
62 |
63 | s[eidx, :min(src_len, len(src_ids))] = src_ids[:src_len]
64 | t[eidx, :min(tgt_len, len(tgt_ids))] = tgt_ids[:tgt_len]
65 | return x, s, t
66 |
67 | def batcher(self, size, buffer_size=1000, shuffle=True, train=True):
68 | def _handle_buffer(_buffer):
69 | sorted_buffer = sorted(
70 | _buffer, key=lambda xx: max(len(xx[1]), len(xx[2])))
71 |
72 | if self.batch_or_token == 'batch':
73 | buffer_index = batch_indexer(len(sorted_buffer), size)
74 | else:
75 | buffer_index = token_indexer(
76 | [[len(sample[1]), len(sample[2])] for sample in sorted_buffer], size)
77 |
78 | index_over_index = batch_indexer(len(buffer_index), 1)
79 | if shuffle: np.random.shuffle(index_over_index)
80 |
81 | for ioi in index_over_index:
82 | index = buffer_index[ioi[0]]
83 | batch = [sorted_buffer[ii] for ii in index]
84 | x, s, t = self.to_matrix(batch)
85 | yield {
86 | 'src': s,
87 | 'tgt': t,
88 | 'index': x,
89 | 'raw': batch,
90 | }
91 |
92 | buffer = self.leak_buffer
93 | self.leak_buffer = []
94 | for i, (src_ids, tgt_ids) in enumerate(self.load_data()):
95 | buffer.append((i, src_ids, tgt_ids))
96 | if len(buffer) >= buffer_size:
97 | for data in _handle_buffer(buffer):
98 | # check whether the data is tailed
99 | batch_size = len(data['raw']) if self.batch_or_token == 'batch' \
100 | else max(np.sum(data['tgt'] > 0), np.sum(data['src'] > 0))
101 | if batch_size < size * self.data_leak_ratio:
102 | self.leak_buffer += data['raw']
103 | else:
104 | yield data
105 | buffer = self.leak_buffer
106 | self.leak_buffer = []
107 |
108 | # deal with data in the buffer
109 | if len(buffer) > 0:
110 | for data in _handle_buffer(buffer):
111 | # check whether the data is tailed
112 | batch_size = len(data['raw']) if self.batch_or_token == 'batch' \
113 | else max(np.sum(data['tgt'] > 0), np.sum(data['src'] > 0))
114 | if train and batch_size < size * self.data_leak_ratio:
115 | self.leak_buffer += data['raw']
116 | else:
117 | yield data
118 |
--------------------------------------------------------------------------------
/docs/afs_speech_translation/afs_training.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/afs_speech_translation/afs_training.png
--------------------------------------------------------------------------------
/docs/afs_speech_translation/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/afs_speech_translation/example.png
--------------------------------------------------------------------------------
/docs/colactc/README.md:
--------------------------------------------------------------------------------
1 | ## Efficient CTC Regularization via Coarse Labels for End-to-End Speech Translation, EACL 2023
2 |
3 | - [paper link]()
4 | - source code is available at [st_from_scratch](https://github.com/bzhangGo/st_from_scratch)
5 |
6 | ### **Why CTC Regularization?**
7 |
8 | Speech translation (ST) requires the model to capture the semantics of an audio input,
9 | but audio carries much content-irrelevant information, such as emotion and pauses, which increases the
10 | difficulty of translation modeling.
11 |
12 | CTC Regularization offers a mechanism to dynamically align speech representations with their
13 | discrete labels, down-weighting content-irrelevant information and encouraging the learning of speech semantics. In the
14 | literature, many studies have confirmed its effectiveness on ST.
15 |
16 |
17 | ### **Why NOT CTC Regularization?**
18 |
19 | CTC Regularization requires an extra projection layer similar to the word prediction layer, which brings in
20 | many model parameters and slows things down.
21 |
22 | We are particularly interested in improving the efficiency of CTC Regularization.
23 |
24 |
25 | ### **CoLaCTC**
26 |
27 | *Why use genuine labels for CTC regularization?* Particularly considering that the CTC regularization layer will be dropped
28 | after training.
29 |
30 | Following this idea, we propose to use pseudo CTC labels at a coarser grain for CTC regularization, which offers direct
31 | control over the CTC label space and decouples it from the genuine word vocabulary.
32 | We only use simple operations to produce CoLaCTC labels, as follows:
33 |
34 |
35 |
36 |
37 | ### **How does it work?**
38 |
39 |
40 | | System | Params | BLEU | Speedup |
41 | |-------------------------------------------------------|--------|------|---------|
42 | | Baseline (no CTC) | 46.1M | 21.8 | 1.39x |
43 | | CTC Regularization + translation labels | 47.9M | 22.7 | 1.00x |
44 | | CTC Regularization + translation-based CoLaCTC labels | 46.2M | 22.7 | 1.39x |
45 | | CTC Regularization + transcript labels | 47.5M | 23.8 | 1.00x |
46 | | CTC Regularization + transcript-based CoLaCTC labels | 46.2M | 24.3 | 1.31x |
47 |
48 | (Quality on MuST-C En-De)
49 |
50 | ### **Why does it work?**
51 |
52 | We still lack a full understanding of why it works so well on ST. One observation is that the CoLaCTC label sequence is
53 | still quite informative. Using it as the source input for machine translation could achieve decent
54 | performance:
55 |
56 |
57 |
58 |
59 | ### Model Training & Evaluation
60 |
61 | We added the implementation to [st_from_scratch](https://github.com/bzhangGo/st_from_scratch). The implementation is
62 | quite simple.
63 | We change
64 | ```python
65 | seq_values.extend(sequence)
66 | ```
67 | to
68 | ```python
69 | seq_values.extend([v % self.p.cola_ctc_L for v in sequence])
70 | ```
71 | as in https://github.com/bzhangGo/st_from_scratch/blob/master/data.py#L152
72 |
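To make the mapping concrete, here is a minimal standalone sketch (the function name and values are illustrative, not taken from `st_from_scratch`), assuming genuine CTC labels are plain integer token ids and the coarse label space has size `L`:

```python
def cola_ctc_labels(token_ids, L):
    # Map genuine CTC labels (e.g., transcript subword ids) onto a coarse label
    # space of size L via modulo; the CTC loss is then computed against these
    # pseudo labels, so the CTC projection only needs L output units.
    return [token_id % L for token_id in token_ids]

# e.g., with a 32k subword vocabulary and L=256, the CTC projection shrinks
# from hidden_size x 32000 to hidden_size x 256
print(cola_ctc_labels([17, 305, 4871, 17], L=256))  # -> [17, 49, 7, 17]
```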
73 |
74 | ### Citation
75 |
76 | Please consider citing our paper as follows:
77 | >Biao Zhang; Barry Haddow; Rico Sennrich (2023). Efficient CTC Regularization via Coarse Labels for End-to-End Speech Translation. In EACL 2023.
78 | ```
79 | @inproceedings{zhang-etal-2023-colactc,
80 | title = "Efficient CTC Regularization via Coarse Labels for End-to-End Speech Translation",
81 | author = "Zhang, Biao and
82 | Haddow, Barry and
83 | Sennrich, Rico",
84 | booktitle = "Proceedings of the 17th Conference of the European Chapter of the Association for Computational Linguistics: Main Volume",
85 | month = may,
86 | year = "2023",
87 | publisher = "Association for Computational Linguistics",
88 | }
89 | ```
90 |
--------------------------------------------------------------------------------
/docs/colactc/colactc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/colactc/colactc.png
--------------------------------------------------------------------------------
/docs/colactc/mt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/colactc/mt.png
--------------------------------------------------------------------------------
/docs/conditional_language_specific_routing/README.md:
--------------------------------------------------------------------------------
1 | ## Share or Not? Learning to Schedule Language-Specific Capacity for Multilingual Translation
2 |
3 |
4 | We host our source code for our ICLR paper here.
5 |
6 | Please go to [iclr2021_clsr branch](https://github.com/bzhangGo/zero/tree/iclr2021_clsr) for more details.
7 |
--------------------------------------------------------------------------------
/docs/context_aware_st/README.md:
--------------------------------------------------------------------------------
1 | ## Beyond Sentence-Level End-to-End Speech Translation: Context Helps
2 |
3 | [**Paper**](https://aclanthology.org/2021.acl-long.200/) |
4 | [**Highlights**](#paper-highlights) |
5 | [**Overview**](#context-aware-st) |
6 | [**Results**](#results) |
7 | [**Training&Eval**](#training-and-evaluation) |
8 | [**Citation**](#citation)
9 |
10 | ### Paper highlights
11 |
12 | Contextual information carries valuable clues for translation. So far, studies on text-based context-aware translation have shown
13 | success, but whether and how context helps end-to-end speech translation (ST) is still under-studied.
14 |
15 | We believe that context would be more helpful to ST, because speech signals often contain more ambiguous expressions apart
16 | from the ones commonly occurring in text. For example, homophones, like flower and flour, are almost indistinguishable without context.
17 |
18 | We study context-aware ST in this project using a simple concatenation-based model. Our main findings are as follows:
19 | * Incorporating context improves overall translation quality (+0.18-2.61 BLEU) and benefits pronoun translation across different language pairs.
20 | * Context also improves the translation of homophones
21 | * ST models with context suffer less from (artificial) audio segmentation errors
22 | * Contextual modeling improves translation quality and reduces latency and flicker for simultaneous translation under the re-translation strategy
23 |
24 |
25 | ### Context Aware ST
26 |
27 | We use AFS to reduce the audio feature length and improve training efficiency. The figure below shows our overall framework:
28 |
29 |
30 |
31 | Note that designing novel context-aware ST architectures is not the focus of this study; we leave that to future work.
32 |
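As a rough illustration only (the segment construction below is a simplified sketch with made-up names; the actual pipeline follows the paper and the branch referenced below), concatenation-based context-aware ST builds each training example by prepending the previous segment to the current one:

```python
BREAK = "[BREAK]"  # hypothetical separator token, used only in this sketch

def make_context_example(prev_speech, cur_speech, prev_translation, cur_translation):
    # Prepend the previous segment's speech frames and translation to the
    # current ones, so the model translates the current sentence with access
    # to its left context.
    src = prev_speech + cur_speech                       # concatenate speech features along time
    tgt = prev_translation + [BREAK] + cur_translation   # mark the sentence boundary on the target side
    return src, tgt
```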
33 |
34 | ### Training and Evaluation
35 |
36 | - We implement the model in [context-aware speech_translation branch](https://github.com/bzhangGo/zero/tree/context_aware_speech_translation)
37 |
38 | Our training involves two phases, as shown below:
39 |
40 |
41 |
42 | Please refer to [our paper](https://aclanthology.org/2021.acl-long.200/) for more details.
43 |
44 |
45 | ### Results
46 |
47 | We mainly experiment with the MuST-C corpus; below we show our model outputs (and BLEU scores) for all languages.
48 |
49 | | Model | De | Es | Fr | It | Nl | Pt | Ro | Ru |
50 | |---------------|---------------------------------------------------------------------------------|---------------------------------------------------------------------------------|---------------------------------------------------------------------------------|---------------------------------------------------------------------------------|---------------------------------------------------------------------------------|---------------------------------------------------------------------------------|---------------------------------------------------------------------------------|---------------------------------------------------------------------------------|
51 | | Baseline | [22.38](http://data.statmt.org/bzhang/acl2021_context_aware_st/baseline/de.txt) | [27.04](http://data.statmt.org/bzhang/acl2021_context_aware_st/baseline/es.txt) | [33.43](http://data.statmt.org/bzhang/acl2021_context_aware_st/baseline/fr.txt) | [23.35](http://data.statmt.org/bzhang/acl2021_context_aware_st/baseline/it.txt) | [25.05](http://data.statmt.org/bzhang/acl2021_context_aware_st/baseline/nl.txt) | [26.55](http://data.statmt.org/bzhang/acl2021_context_aware_st/baseline/pt.txt) | [21.87](http://data.statmt.org/bzhang/acl2021_context_aware_st/baseline/ro.txt) | [14.92](http://data.statmt.org/bzhang/acl2021_context_aware_st/baseline/ru.txt) |
52 | | CA ST w/ SWBD | [22.7](http://data.statmt.org/bzhang/acl2021_context_aware_st/swbd/de.txt) | [27.12](http://data.statmt.org/bzhang/acl2021_context_aware_st/swbd/es.txt) | [34.23](http://data.statmt.org/bzhang/acl2021_context_aware_st/swbd/fr.txt) | [23.46](http://data.statmt.org/bzhang/acl2021_context_aware_st/swbd/it.txt) | [25.84](http://data.statmt.org/bzhang/acl2021_context_aware_st/swbd/nl.txt) | [26.63](http://data.statmt.org/bzhang/acl2021_context_aware_st/swbd/pt.txt) | [23.7](http://data.statmt.org/bzhang/acl2021_context_aware_st/swbd/ro.txt) | [15.53](http://data.statmt.org/bzhang/acl2021_context_aware_st/swbd/ru.txt) |
53 | | CA ST w/ IMED | [22.86](http://data.statmt.org/bzhang/acl2021_context_aware_st/imed/de.txt) | [27.5](http://data.statmt.org/bzhang/acl2021_context_aware_st/imed/es.txt) | [34.28](http://data.statmt.org/bzhang/acl2021_context_aware_st/imed/fr.txt) | [23.53](http://data.statmt.org/bzhang/acl2021_context_aware_st/imed/it.txt) | [26.12](http://data.statmt.org/bzhang/acl2021_context_aware_st/imed/nl.txt) | [27.37](http://data.statmt.org/bzhang/acl2021_context_aware_st/imed/pt.txt) | [24.48](http://data.statmt.org/bzhang/acl2021_context_aware_st/imed/ro.txt) | [15.95](http://data.statmt.org/bzhang/acl2021_context_aware_st/imed/ru.txt) |
54 |
55 |
56 | ### Citation
57 |
58 | Please consider citing our paper as follows:
59 | >Biao Zhang; Ivan Titov; Barry Haddow; Rico Sennrich (2021). Beyond Sentence-Level End-to-End Speech Translation: Context Helps. In Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers).
60 | ```
61 | @inproceedings{zhang-etal-2021-beyond,
62 | title = "Beyond Sentence-Level End-to-End Speech Translation: Context Helps",
63 | author = "Zhang, Biao and
64 | Titov, Ivan and
65 | Haddow, Barry and
66 | Sennrich, Rico",
67 | booktitle = "Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers)",
68 | month = aug,
69 | year = "2021",
70 | address = "Online",
71 | publisher = "Association for Computational Linguistics",
72 | url = "https://aclanthology.org/2021.acl-long.200",
73 | doi = "10.18653/v1/2021.acl-long.200",
74 | pages = "2566--2578",
75 | abstract = "Document-level contextual information has shown benefits to text-based machine translation, but whether and how context helps end-to-end (E2E) speech translation (ST) is still under-studied. We fill this gap through extensive experiments using a simple concatenation-based context-aware ST model, paired with adaptive feature selection on speech encodings for computational efficiency. We investigate several decoding approaches, and introduce in-model ensemble decoding which jointly performs document- and sentence-level translation using the same model. Our results on the MuST-C benchmark with Transformer demonstrate the effectiveness of context to E2E ST. Compared to sentence-level ST, context-aware ST obtains better translation quality (+0.18-2.61 BLEU), improves pronoun and homophone translation, shows better robustness to (artificial) audio segmentation errors, and reduces latency and flicker to deliver higher quality for simultaneous translation.",
76 | }
77 | ```
--------------------------------------------------------------------------------
/docs/context_aware_st/cast.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/context_aware_st/cast.png
--------------------------------------------------------------------------------
/docs/context_aware_st/training.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/context_aware_st/training.png
--------------------------------------------------------------------------------
/docs/depth_scale_init_and_merged_attention/README.md:
--------------------------------------------------------------------------------
1 | ## Improving Deep Transformer with Depth-Scaled Initialization and Merged Attention, EMNLP2019
2 |
3 | - [paper link](https://www.aclweb.org/anthology/D19-1083/)
4 |
5 | This paper focuses on improving deep Transformers.
6 | Our empirical observation suggests that simply stacking more Transformer layers makes training diverge.
7 | Rather than resorting to the pre-norm structure which shifts the layer normalization before modeling blocks,
8 | we analyze the reason why a vanilla deep Transformer suffers from poor convergence.
9 |
10 |
11 |
12 | Our evidence shows that it's because of *gradient vanishing* (shown above) caused by the interaction between residual connection
13 | and layer normalization. **In short, the residual connection increases the variance of its output, which (empirically) decreases the gradient
14 | backpropagated through layer normalization.**
15 |
16 | We solve this problem by proposing depth-scaled initialization (DS-Init), which decreases
17 | parameter variance at the initialization stage. DS-Init reduces output variance of residual connections so as to
18 | ease gradient back-propagation through normalization layers. In practice, DS-Init often produces slightly better
19 | translation quality than the pre-norm structure.
20 |
21 | We also care about the computational overhead introduced by deep models. To address this issue, we propose the merged
22 | attention network, which combines a simplified average attention model and the encoder-decoder attention model on
23 | the target side. The merged attention model enables the deep Transformer to match the decoding speed of its baseline
24 | with a clearly higher BLEU score.
25 |
26 | ### Approach
27 |
28 | To train a deep Transformer model for machine translation, scale your initialization for each layer as follows:
29 |
30 |
31 |
32 | where `\alpha` and `\gamma` are hyperparameters for the uniform distribution. `l` denotes the depth of the layer.
33 |
34 |
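As a rough numpy sketch of the idea (assuming the usual fan-based uniform bound; the exact expression with `\alpha` and `\gamma` is in the paper, and the repo enables it via the `deep_transformer_init=True` flag shown below):

```python
import numpy as np

def ds_init(fan_in, fan_out, layer_depth, alpha=1.0):
    # Depth-scaled initialization (sketch): shrink the uniform initialization
    # bound of layer l by 1/sqrt(l), so deeper layers start with smaller
    # parameter variance and residual outputs do not blow up.
    gamma = np.sqrt(6.0 / (fan_in + fan_out))        # fan-based uniform bound (assumed form)
    limit = alpha * gamma / np.sqrt(layer_depth)     # depth scaling, l is the 1-based layer index
    return np.random.uniform(-limit, limit, size=(fan_in, fan_out))

# e.g. weights in the 12th layer start roughly 3.5x smaller than in the 1st
W12 = ds_init(512, 512, layer_depth=12)
```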
35 | ### Model Training
36 |
37 | Train 12-layer Transformer model with the following settings:
38 | >The model class is `transformer_fuse`; merged attention is enabled by passing `fuse_mask` to the `dot_attention` function.
39 | ```
40 | python run.py --mode train --parameters=hidden_size=512,embed_size=512,filter_size=2048,\
41 | initializer="uniform_unit_scaling",initializer_gain=1.,\
42 | model_name="transformer_fuse",scope_name="transformer_fuse",\
43 | deep_transformer_init=True,\
44 | num_encoder_layer=12,\
45 | num_decoder_layer=12,\
46 | ```
47 |
48 | Other details can be found [here](../usage).
49 |
50 | ### Performance and Download
51 |
52 | We offer a range of [pretrained models](http://data.statmt.org/bzhang/emnlp19_deep_transformer/) for further study.
53 |
54 |
55 | | Task | Model | BLEU | Download |
56 | |---------------|-------------------------------------|-------| -------- |
57 | | WMT14 En-Fr | Base Transformer + 6 Layers | 39.09 | |
58 | | | Base Transformer + Ours + 12 Layers | 40.58 | |
59 | | IWSLT14 De-En | Base Transformer + 6 Layers | 34.41 | |
60 | | | Base Transformer + Ours + 12 Layers | 35.63 | |
61 | | WMT18 En-Fi   | Base Transformer + 6 Layers         | 15.5  | |
62 | | | Base Transformer + Ours + 12 Layers | 15.8 | |
63 | | WMT18 Zh-En | Base Transformer + 6 Layers | 21.1 | |
64 | | | Base Transformer + Ours + 12 Layers | 22.3 | |
65 | | WMT14 En-De | Base Transformer + 6 Layers | 27.59 | [download](http://data.statmt.org/bzhang/emnlp19_deep_transformer/model/base.tar.gz) |
66 | | | Base Transformer + Ours + 12 Layers | 28.55 | |
67 | | | Big Transformer + 6 Layers | 29.07 | [download](http://data.statmt.org/bzhang/emnlp19_deep_transformer/model/big.tar.gz) |
68 | | | Big Transformer + Ours + 12 Layers | 29.47 | |
69 | | | Base Transformer + Ours + 20 Layers | 28.67 | [download](http://data.statmt.org/bzhang/emnlp19_deep_transformer/model/base+fuse_init20.tar.gz) |
70 | | | Base Transformer + Ours + 30 Layers | 28.86 | [download](http://data.statmt.org/bzhang/emnlp19_deep_transformer/model/base+fuse_init30.tar.gz) |
71 | | | Big Transformer + Ours + 20 Layers | 29.62 | [download](http://data.statmt.org/bzhang/emnlp19_deep_transformer/model/big+fuse_init20.tar.gz) |
72 |
73 | Please go to [pretrained models](http://data.statmt.org/bzhang/emnlp19_deep_transformer/) for more details.
74 |
75 | ### Citation
76 |
77 | Please consider citing our paper as follows:
78 | >Biao Zhang; Ivan Titov; Rico Sennrich (2019). Improving Deep Transformer with Depth-Scaled Initialization and Merged Attention. In Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP). Hong Kong, China, pp. 898-909.
79 | ```
80 | @inproceedings{zhang-etal-2019-improving-deep,
81 | title = "Improving Deep Transformer with Depth-Scaled Initialization and Merged Attention",
82 | author = "Zhang, Biao and
83 | Titov, Ivan and
84 | Sennrich, Rico",
85 | booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)",
86 | month = nov,
87 | year = "2019",
88 | address = "Hong Kong, China",
89 | publisher = "Association for Computational Linguistics",
90 | url = "https://www.aclweb.org/anthology/D19-1083",
91 | doi = "10.18653/v1/D19-1083",
92 | pages = "898--909",
93 | abstract = "The general trend in NLP is towards increasing model capacity and performance via deeper neural networks. However, simply stacking more layers of the popular Transformer architecture for machine translation results in poor convergence and high computational overhead. Our empirical analysis suggests that convergence is poor due to gradient vanishing caused by the interaction between residual connection and layer normalization. We propose depth-scaled initialization (DS-Init), which decreases parameter variance at the initialization stage, and reduces output variance of residual connections so as to ease gradient back-propagation through normalization layers. To address computational cost, we propose a merged attention sublayer (MAtt) which combines a simplified average-based self-attention sublayer and the encoder-decoder attention sublayer on the decoder side. Results on WMT and IWSLT translation tasks with five translation directions show that deep Transformers with DS-Init and MAtt can substantially outperform their base counterpart in terms of BLEU (+1.1 BLEU on average for 12-layer models), while matching the decoding speed of the baseline model thanks to the efficiency improvements of MAtt. Source code for reproduction will be released soon.",
94 | }
95 | ```
96 |
--------------------------------------------------------------------------------
/docs/depth_scale_init_and_merged_attention/dsinit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/depth_scale_init_and_merged_attention/dsinit.png
--------------------------------------------------------------------------------
/docs/depth_scale_init_and_merged_attention/grad.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/depth_scale_init_and_merged_attention/grad.png
--------------------------------------------------------------------------------
/docs/interleaved_bidirectional_transformer/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/interleaved_bidirectional_transformer/overview.png
--------------------------------------------------------------------------------
/docs/iwslt2021_uoe_submission/README.md:
--------------------------------------------------------------------------------
1 | ## Edinburgh's End-to-End Multilingual Speech Translation System for IWSLT 2021
2 |
3 | [**Paper**](https://aclanthology.org/2021.iwslt-1.19/) |
4 | [**Highlights**](#paper-highlights) |
5 | [**Results**](#results) |
6 | [**Citation**](#citation)
7 |
8 | ### Paper highlights
9 |
10 | We participated in the IWSLT21 multilingual speech translation task with an integrated end-to-end system. Below is an overview of our
11 | system:
12 |
13 |
14 |
15 | For more details, please check out [our paper](https://aclanthology.org/2021.iwslt-1.19/).
16 |
17 |
18 | ### Results
19 |
20 | We release our system and outputs [here](http://data.statmt.org/bzhang/iwslt2021_uoe_system/) to facilitate other research, particularly for those who are interested in
21 | reproduction and comparison.
22 |
23 |
24 | ### Citation
25 |
26 | Please consider citing our paper as follows:
27 | >Biao Zhang; Rico Sennrich (2021). Edinburgh's End-to-End Multilingual Speech Translation System for IWSLT 2021. In Proceedings of the 18th International Conference on Spoken Language Translation (IWSLT 2021).
28 | ```
29 | @inproceedings{zhang-sennrich-2021-edinburghs,
30 | title = "{E}dinburgh{'}s End-to-End Multilingual Speech Translation System for {IWSLT} 2021",
31 | author = "Zhang, Biao and
32 | Sennrich, Rico",
33 | booktitle = "Proceedings of the 18th International Conference on Spoken Language Translation (IWSLT 2021)",
34 | month = aug,
35 | year = "2021",
36 | address = "Bangkok, Thailand (online)",
37 | publisher = "Association for Computational Linguistics",
38 | url = "https://aclanthology.org/2021.iwslt-1.19",
39 | doi = "10.18653/v1/2021.iwslt-1.19",
40 | pages = "160--168",
41 | abstract = "This paper describes Edinburgh{'}s submissions to the IWSLT2021 multilingual speech translation (ST) task. We aim at improving multilingual translation and zero-shot performance in the constrained setting (without using any extra training data) through methods that encourage transfer learning and larger capacity modeling with advanced neural components. We build our end-to-end multilingual ST model based on Transformer, integrating techniques including adaptive speech feature selection, language-specific modeling, multi-task learning, deep and big Transformer, sparsified linear attention and root mean square layer normalization. We adopt data augmentation using machine translation models for ST which converts the zero-shot problem into a zero-resource one. Experimental results show that these methods deliver substantial improvements, surpassing the official baseline by {\textgreater} 15 average BLEU and outperforming our cascading system by {\textgreater} 2 average BLEU. Our final submission achieves competitive performance (runner up).",
42 | }
43 | ```
--------------------------------------------------------------------------------
/docs/iwslt2021_uoe_submission/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/iwslt2021_uoe_submission/overview.png
--------------------------------------------------------------------------------
/docs/l0drop/README.md:
--------------------------------------------------------------------------------
1 | ## On Sparsifying Encoder Outputs in Sequence-to-Sequence Models
2 |
3 | [**Paper**](https://aclanthology.org/2021.findings-acl.255/) |
4 | [**Highlights**](#paper-highlights) |
5 | [**Overview**](#l0drop) |
6 | [**CaseStudy**](#examples) |
7 | [**Training&Eval**](#training-and-evaluation) |
8 | [**Citation**](#citation)
9 |
10 | ### Paper Highlights
11 |
12 |
13 |
14 | Standard encoder-decoder models for sequence-to-sequence learning always feed all encoder outputs to the decoder for
15 | generation. However, information in a sequence is not uniformly distributed over tokens. In translation, we have null
16 | alignments where some source tokens are not translated at all; and in summarization, source documents often contain many
17 | redundant tokens. The research questions we are interested in are:
18 |
19 | * Are encoder outputs compressible?
20 | * Can we identify those uninformative outputs and prune them out automatically?
21 | * Can we obtain higher inference speed with shortened encoding sequence?
22 |
23 | We propose L0Drop to this end, and our main findings are as follows:
24 |
25 | * We confirm that the encoder outputs can be compressed: around 40-70% of them can be dropped without large effects on
26 | the generation quality.
27 | * The resulting sparsity level differs across word types: the encodings corresponding to function words
28 | (such as determiners, prepositions) are more frequently pruned than those of content words (e.g., verbs and nouns).
29 | * L0Drop can improve decoding efficiency particularly for lengthy source inputs. We achieve a decoding speedup of up to
30 | 1.65x on document summarization tasks and 1.20x on a character-based machine translation task.
31 | * Filtering out source encodings with rule-based sparse patterns is also feasible.
32 |
33 |
34 | ### L0Drop
35 |
36 | L0Drop forces the model to route information through a subset of the encoder outputs, and the subset is learned automatically.
37 |
38 |
39 |
40 | L0Drop is different from sparse attention, which is comparably shown below.
41 |
42 |
43 |
44 | Note that L0Drop is data-driven and task-agnostic. We applied it to machine translation as well as
45 | document summarization tasks. Results on WMT14 En-De translation tasks are shown below:
46 |
47 |
48 |
49 | Please refer to [our paper](https://aclanthology.org/2021.findings-acl.255/) for more details.
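For intuition, here is a minimal numpy sketch of the hard concrete gate (Louizos et al., 2018) that L0Drop places on each encoder output; it simplifies details of the actual implementation in [l0norm](../../modules/l0norm.py):

```python
import numpy as np

def hard_concrete_gate(log_alpha, training=True, beta=2.0 / 3, gamma=-0.1, zeta=1.1):
    # Sample a gate in [0, 1] per encoder state; states whose gate is exactly 0 are dropped.
    if training:
        u = np.random.uniform(1e-6, 1.0 - 1e-6, size=np.shape(log_alpha))
        s = 1.0 / (1.0 + np.exp(-(np.log(u) - np.log(1.0 - u) + log_alpha) / beta))
    else:
        s = 1.0 / (1.0 + np.exp(-log_alpha))
    s = s * (zeta - gamma) + gamma     # stretch (0, 1) to (gamma, zeta)
    return np.clip(s, 0.0, 1.0)        # rectify: puts probability mass on exactly 0 and 1

def expected_l0_penalty(log_alpha, beta=2.0 / 3, gamma=-0.1, zeta=1.1):
    # Differentiable expectation of the number of non-zero gates; scaled by \lambda
    # (l0_norm_reg_scalar) and added to the training loss to encourage sparsity.
    return np.sum(1.0 / (1.0 + np.exp(-(log_alpha - beta * np.log(-gamma / zeta)))))

# toy usage: larger log_alpha -> gate closer to 1 (keep), smaller -> closer to 0 (drop)
gates = hard_concrete_gate(np.array([-3.0, 0.1, 4.0]))
```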
50 |
51 | ### Examples
52 |
53 | Here, we show some examples learned by L0Drop on machine translation tasks (highlighted source words are dropped after encoding):
54 |
55 | | Task | Sample |
56 | |----------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
57 | | WMT18 Zh-En | Source: `这` 一年 `来` `,` 中国 电@@ 商 出 `了` `一份` 怎样 `的` 成绩@@ 单 ? <br> Reference: what sort of report card did China 's e @-@ commerce industry receive this year ? <br> Translation: what kind of report card was produced by Chinese telec@@ om dealers during the year ? |
58 | | | Source: 中国 `在` 地区 合作 `中` `发挥` `的` 作用 一贯 `是` 积极 正面 `的` `,` `受到` 地区 国家 高度 认可 `。` <br> Reference: China has always played an active and positive role in regional cooperation , which is well recognized by regional countries . <br> Translation: China 's role in regional cooperation has always been positive and highly recognized by the countries of the region . |
59 | | WMT14 En-De | Source: `The` cause `of` `the` b@@ `last` `was` not known `,` he said `.` <br> Reference: Die Ursache der Explo@@ sion sei nicht bekannt , erklärte er . <br> Translation: Die Ursache der Explo@@ sion war nicht bekannt , sagte er . |
60 | | | Source: `The` night `was` long , `the` music loud and `the` atmosphere good `,` but `at` some `point` everyone has `to` go home `.` <br> Reference: Die Nacht war lang , die Musik laut und die Stimmung gut , aber irgendwann geht es nach Hause . <br> Translation: Die Nacht war lang , die Musik laut und die Atmosphäre gut , aber irgendwann muss jeder nach Hause gehen . |
61 |
62 |
63 | ### Training and Evaluation
64 |
65 | - We implement the model in [transformer_l0drop](../../models/transformer_l0drop.py) and [l0norm](../../modules/l0norm.py)
66 |
67 | #### Training
68 |
69 | It's possible to train a Transformer with L0Drop from scratch by setting a proper scheduler for `\lambda`,
70 | a hyper-parameter loosely controlling the sparsity rate of L0Drop. Unfortunately, the optimal scheduler is
71 | data- and task-dependent.
72 |
73 | We suggest first pre-training a normal Transformer model, and then finetuning Transformer+L0Drop. This can
74 | save a lot of effort.
75 |
76 | * Step 1. train a normal Transformer model as described [here](../../docs/usage/README.md). Below is
77 | an example on WMT14 En-De for reference:
78 | ```
79 | data_dir=the preprocessed data directory
80 | zero=the path of this code base
81 | python $zero/run.py --mode train --parameters=hidden_size=512,embed_size=512,filter_size=2048,\
82 | dropout=0.1,label_smooth=0.1,attention_dropout=0.1,\
83 | max_len=256,batch_size=80,eval_batch_size=32,\
84 | token_size=6250,batch_or_token='token',\
85 | initializer="uniform_unit_scaling",initializer_gain=1.,\
86 | model_name="transformer",scope_name="transformer",buffer_size=60000,\
87 | clip_grad_norm=0.0,\
88 | num_heads=8,\
89 | lrate=1.0,\
90 | process_num=3,\
91 | num_encoder_layer=6,\
92 | num_decoder_layer=6,\
93 | warmup_steps=4000,\
94 | lrate_strategy="noam",\
95 | epoches=5000,\
96 | update_cycle=4,\
97 | gpus=[0],\
98 | disp_freq=1,\
99 | eval_freq=5000,\
100 | sample_freq=1000,\
101 | checkpoints=5,\
102 | max_training_steps=300000,\
103 | beta1=0.9,\
104 | beta2=0.98,\
105 | epsilon=1e-8,\
106 | random_seed=1234,\
107 | src_vocab_file="$data_dir/vocab.zero.en",\
108 | tgt_vocab_file="$data_dir/vocab.zero.de",\
109 | src_train_file="$data_dir/train.32k.en.shuf",\
110 | tgt_train_file="$data_dir/train.32k.de.shuf",\
111 | src_dev_file="$data_dir/dev.32k.en",\
112 | tgt_dev_file="$data_dir/dev.32k.de",\
113 | src_test_file="$data_dir/newstest2014.32k.en",\
114 | tgt_test_file="$data_dir/newstest2014.de",\
115 | output_dir="train"
116 | ```
117 |
118 | * Step 2. finetune L0Drop using the following command:
119 | ```
120 | data_dir=the preprocessed data directory
121 | zero=the path of this code base
122 | python $zero/run.py --mode train --parameters=\
123 | l0_norm_reg_scalar=0.3,\
124 | l0_norm_warm_up=False,\
125 | model_name="transformer_l0drop",scope_name="transformer",\
126 | pretrained_model="path-to-pretrained-transformer",\
127 | max_training_steps=320000,\
128 | src_vocab_file="$data_dir/vocab.zero.en",\
129 | tgt_vocab_file="$data_dir/vocab.zero.de",\
130 | src_train_file="$data_dir/train.32k.en.shuf",\
131 | tgt_train_file="$data_dir/train.32k.de.shuf",\
132 | src_dev_file="$data_dir/dev.32k.en",\
133 | tgt_dev_file="$data_dir/dev.32k.de",\
134 | src_test_file="$data_dir/newstest2014.32k.en",\
135 | tgt_test_file="$data_dir/newstest2014.de",\
136 | output_dir="train"
137 | ```
138 | where `l0_norm_reg_scalar` is the `\lambda`; `0.2` or `0.3` worked well in our experiments.
139 |
140 | #### Evaluation
141 |
142 | The evaluation follows the same procedure as the baseline Transformer.
143 |
144 | ### Citation
145 |
146 | Please consider citing our paper as follows:
147 | >Biao Zhang; Ivan Titov; Rico Sennrich (2021). On Sparsifying Encoder Outputs in Sequence-to-Sequence Models. Findings of the Association for Computational Linguistics: ACL-IJCNLP 2021.
148 | ```
149 | @inproceedings{zhang-etal-2021-sparsifying,
150 | title = "On Sparsifying Encoder Outputs in Sequence-to-Sequence Models",
151 | author = "Zhang, Biao and
152 | Titov, Ivan and
153 | Sennrich, Rico",
154 | booktitle = "Findings of the Association for Computational Linguistics: ACL-IJCNLP 2021",
155 | month = aug,
156 | year = "2021",
157 | address = "Online",
158 | publisher = "Association for Computational Linguistics",
159 | url = "https://aclanthology.org/2021.findings-acl.255",
160 | doi = "10.18653/v1/2021.findings-acl.255",
161 | pages = "2888--2900",
162 | }
163 |
164 | ```
165 |
166 |
--------------------------------------------------------------------------------
/docs/l0drop/l0drop-att.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/l0drop/l0drop-att.png
--------------------------------------------------------------------------------
/docs/l0drop/l0drop.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/l0drop/l0drop.png
--------------------------------------------------------------------------------
/docs/l0drop/mt_ende.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/l0drop/mt_ende.png
--------------------------------------------------------------------------------
/docs/multilingual_laln_lalt/many-to-many.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/multilingual_laln_lalt/many-to-many.xlsx
--------------------------------------------------------------------------------
/docs/rela_sparse_attention/README.md:
--------------------------------------------------------------------------------
1 | ## Sparse Attention with Linear Units, EMNLP2021
2 |
3 |
4 | [**Paper**](https://arxiv.org/abs/2104.07012/) |
5 | [**Highlights**](#paper-highlights) |
6 | [**Training**](#model-training) |
7 | [**Results**](#results) |
8 | [**Citation**](#citation)
9 |
10 |
11 | ### Paper highlights
12 |
13 | Attention contributes substantially to NMT/NLP. It relies on `SoftMax` activation to produce a dense, categorical
14 | distribution as an estimate of the relevance between the input query and contexts.
15 |
16 | **But why dense? And why a distribution?**
17 |
18 |
19 | * A low attention score in dense attention doesn't necessarily mean low relevance. By contrast, sparse attention often does better.
20 | * Attention scores estimate relevance, which doesn't have to form a probability distribution, although such normalization
21 | often stabilizes training.
22 |
23 |
24 | In this work, we propose rectified linear attention (ReLA), which directly uses ReLU rather than
25 | softmax as an activation function for attention scores. ReLU naturally leads to sparse attention, and we
26 | apply [RMSNorm](https://openreview.net/pdf?id=BylmcHHgIB) to attention outputs to stabilize model training. The figure below
27 | shows the difference between ReLA and vanilla attention.
28 |
29 |
30 |
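For intuition, here is a minimal single-head numpy sketch of ReLA in its simplest form (ReLU-activated scores plus RMSNorm on the output); the actual multi-head implementation is in [ReLA](../../modules/rela.py) and differs in details:

```python
import numpy as np

def rms_norm(x, scale, eps=1e-8):
    # RMSNorm: rescale by the root mean square of the activations (no mean centering)
    rms = np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + eps)
    return scale * x / rms

def rela(q, k, v, gain):
    # q: [m, d]; k, v: [n, d]; gain: [d] learned RMSNorm scale
    scores = q.dot(k.T) / np.sqrt(q.shape[-1])   # scaled dot-product scores
    weights = np.maximum(scores, 0.0)            # ReLU instead of softmax: naturally sparse,
                                                 # not constrained to a probability distribution
    return rms_norm(weights.dot(v), gain)        # RMSNorm stabilizes the un-normalized outputs

# toy usage: 3 queries attend over 5 key/value states of dimension 8
q, k, v = np.random.randn(3, 8), np.random.randn(5, 8), np.random.randn(5, 8)
out = rela(q, k, v, gain=np.ones(8))
```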
31 | We find that:
32 | * ReLA achieves comparable translation performance to the softmax-based attention on five translation tasks, with
33 | similar running efficiency, but faster than other sparse baselines.
34 | * ReLA delivers high sparsity rate, high head diversity, and better accuracy than all baselines with respect to
35 | word alignment.
36 | * We also observe the emergence of attention heads with a high rate of null attention, only activating for certain queries.
37 |
38 |
39 | ### Model Training
40 |
41 | - We implement the model in [transformer_rela](../../models/transformer_rela.py) and [ReLA](../../modules/rela.py)
42 |
43 | The training of `ReLA` follows the baseline. You just need to change the `model_name` to `transformer_rela` as below
44 | (taking WMT14 En-De as an example):
45 | ```
46 | data_dir=the preprocessed data directory
47 | zero=the path of this code base
48 | python $zero/run.py --mode train --parameters=hidden_size=512,embed_size=512,filter_size=2048,\
49 | dropout=0.1,label_smooth=0.1,attention_dropout=0.1,\
50 | max_len=256,batch_size=80,eval_batch_size=32,\
51 | token_size=6250,batch_or_token='token',\
52 | initializer="uniform_unit_scaling",initializer_gain=1.,\
53 | model_name="transformer_rela",scope_name="transformer_rela",buffer_size=60000,\
54 | clip_grad_norm=0.0,\
55 | num_heads=8,\
56 | lrate=1.0,\
57 | process_num=3,\
58 | num_encoder_layer=6,\
59 | num_decoder_layer=6,\
60 | warmup_steps=4000,\
61 | lrate_strategy="noam",\
62 | epoches=5000,\
63 | update_cycle=4,\
64 | gpus=[0],\
65 | disp_freq=1,\
66 | eval_freq=5000,\
67 | sample_freq=1000,\
68 | checkpoints=5,\
69 | max_training_steps=300000,\
70 | beta1=0.9,\
71 | beta2=0.98,\
72 | epsilon=1e-8,\
73 | random_seed=1234,\
74 | src_vocab_file="$data_dir/vocab.zero.en",\
75 | tgt_vocab_file="$data_dir/vocab.zero.de",\
76 | src_train_file="$data_dir/train.32k.en.shuf",\
77 | tgt_train_file="$data_dir/train.32k.de.shuf",\
78 | src_dev_file="$data_dir/dev.32k.en",\
79 | tgt_dev_file="$data_dir/dev.32k.de",\
80 | src_test_file="$data_dir/newstest2014.32k.en",\
81 | tgt_test_file="$data_dir/newstest2014.de",\
82 | output_dir="train"
83 | ```
84 |
85 |
86 | ### Results
87 |
88 | * Translation performance (SacreBLEU Scores) on different WMT tasks
89 |
90 | | Model | WMT14 En-Fr | WMT18 En-Fi | WMT18 Zh-En | WMT16 Ro-En |
91 | |:----------:|:-----------:|:-----------:|:-----------:|:-----------:|
92 | | softmax | 37.2 | 15.5 | 21.1 | 32.7 |
93 | | sparsemax | 37.3 | 15.1 | 19.2 | 33.5 |
94 | | 1.5-entmax | 37.9 | 15.5 | 20.8 | 33.2 |
95 | | ReLA | 37.9 | 15.4 | 20.8 | 32.9 |
96 |
97 | * Training and decoding efficiency of ReLA (based on tensorflow 1.13)
98 |
99 | | Model | Params | Train Speedup | Decode Speedup |
100 | |:----------:|:------:|:-------------:|:--------------:|
101 | | softmax | 72.31M | 1.00x | 1.00x |
102 | | sparsemax | 72.31M | 0.26x | 0.54x |
103 | | 1.5-entmax | 72.31M | 0.27x | 0.49x |
104 | | ReLA | 72.34M | 0.93x | 0.98x |
105 |
106 |
107 | * Source-target attention of ReLA aligns better with word alignment
108 |
109 |
110 |
111 | Note that solid curves are for the best head per layer, while dashed curves are averages over heads.
112 |
113 | * ReLA enables null-attention: attend to nothing
114 |
115 |
116 |
117 |
118 | ### Citation
119 |
120 | Please consider citing our paper as follows:
121 | >Biao Zhang; Ivan Titov; Rico Sennrich (2021). Sparse Attention with Linear Units. In The 2021 Conference on Empirical Methods in Natural Language Processing. Punta Cana, Dominican Republic
122 | ```
123 | @inproceedings{zhang-etal-2021-sparse,
124 | title = "Sparse Attention with Linear Units",
125 | author = "Zhang, Biao and
126 | Titov, Ivan and
127 | Sennrich, Rico",
128 | booktitle = "The 2021 Conference on Empirical Methods in Natural Language Processing",
129 | month = nov,
130 | year = "2021",
131 | address = "Punta Cana, Dominican Republic",
132 | publisher = "Association for Computational Linguistics",
133 | eprint = "2104.07012"
134 | }
135 | ```
136 |
--------------------------------------------------------------------------------
/docs/rela_sparse_attention/aer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/rela_sparse_attention/aer.png
--------------------------------------------------------------------------------
/docs/rela_sparse_attention/null.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/rela_sparse_attention/null.png
--------------------------------------------------------------------------------
/docs/rela_sparse_attention/rela.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bzhangGo/zero/d97e2c21b1c6d0467fe821223042247bf2b46bf9/docs/rela_sparse_attention/rela.png
--------------------------------------------------------------------------------
/docs/usage/README.md:
--------------------------------------------------------------------------------
1 | ## Questions
2 | 1. What's the effective batch size for training and decoding?
3 |
4 | * When using `token-based` training (batch_or_token=token), the effective token number equals `number_gpus * update_cycles * token_size` (see the worked example below).
5 | * When using `batch-based` training (batch_or_token=batch), the effective batch size equals `number_gpus * update_cycles * batch_size`
6 | * At the decoding phase, we only use batch-based decoding with a size of `eval_batch_size`.
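* For example (illustrative numbers): training on 4 GPUs with `update_cycle=4` and `token_size=6250` gives an effective size of `4 * 4 * 6250 = 100,000` target tokens per parameter update.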
7 |
8 | 2. What's the difference between `model_name` and `scope_name`?
9 |
10 | The `model_name` specifies which model you want to train. The model name should be a registered model, found
11 | under the folder `models`. The `scope_name` denotes the TensorFlow variable scope used for the model's weights and variables.
12 |
13 | For example, when you want to train a Transformer model, you should set `model_name=transformer`. But you can use
14 | any valid scope name you want, such as transformer, nmtmodel, transformer_exp1, etc.
15 |
16 | ## How to use it?
17 |
18 | Below is a rough procedure for WMT14 En-De translation tasks.
19 |
20 | 1. Prepare your training, development and test data.
21 |
22 | For example, you can download the preprocessed WMT14 En-De dataset from [Stanford NMT](https://nlp.stanford.edu/projects/nmt/)
23 | * The training file: [train.en](https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/train.en),
24 | [train.de](https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/train.de)
25 | * The development file: [newstest12.en](https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2012.en),
26 | [newstest12.de](https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2012.de),
27 | [newstest13.en](https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2013.en),
28 | [newstest13.de](https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2013.de)
29 |
30 | Then, concatenate `newstest12.en` and `newstest13.en` into `dev.en` using a command like `cat newstest12.en newstest13.en > dev.en`.
31 | The same goes for German: `cat newstest12.de newstest13.de > dev.de`
32 | * The test file: [newstest14.en](https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2014.en),
33 | [newstest14.de](https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2014.de)
34 |
35 | 2. Preprocess your dataset.
36 |
37 | Generally, you can process your dataset with a standard pipeline as given in the WMT official site.
38 | Full content can be found at the preprocessed [datasets](http://data.statmt.org/wmt17/translation-task/preprocessed/).
39 | See the `prepare.sh` for more details.
40 |
41 | In our case, for WMT14 En-De translation, however, the dataset has already been pre-processed,
42 | so nothing needs to be done at this stage.
43 |
44 | *For some languages, such as Chinese, you need to perform word segmentation (the Chinese counterpart of tokenization) first.
45 | You can find more information about segmentation [here](https://nlp.stanford.edu/software/segmenter.shtml).*
46 |
47 | 3. Optional but strongly suggested: apply BPE (byte pair encoding)
48 |
49 | The BPE algorithm is the most popular and currently the standard way to handle rare words (OOVs). It iteratively
50 | merges the most frequent symbol pairs until a maximum number of merge operations is reached, splitting rare words into
51 | `sub-words`, such as `Bloom` => `Blo@@ om`. Another benefit of BPE is that it lets you control the vocabulary size. A toy sketch of the merge loop is given at the end of this step.
52 |
53 | - download the [subword project](https://github.com/rsennrich/subword-nmt)
54 | - learn the subword model:
55 | ```
56 | python subword-nmt/learn_joint_bpe_and_vocab.py --input train.en train.de -s 32000 -o bpe32k --write-vocabulary vocab.en vocab.de
57 | ```
58 | Notice that 32000 indicates 32k merge operations, which you can roughly think of as your vocabulary size.
59 | - Apply the subword model to all your datasets.
60 | - To training data
61 | ```
62 | python subword-nmt/apply_bpe.py --vocabulary vocab.en --vocabulary-threshold 50 -c bpe32k < train.en > train.32k.en
63 | python subword-nmt/apply_bpe.py --vocabulary vocab.de --vocabulary-threshold 50 -c bpe32k < train.de > train.32k.de
64 | ```
65 | - To dev data
66 | ```
67 | python subword-nmt/apply_bpe.py --vocabulary vocab.en --vocabulary-threshold 50 -c bpe32k < dev.en > dev.32k.en
68 | python subword-nmt/apply_bpe.py --vocabulary vocab.de --vocabulary-threshold 50 -c bpe32k < dev.de > dev.32k.de
69 | ```
70 | Notice that applying BPE to `dev.de` is not strictly required, but our model uses the BPE-processed reference during validation.
71 | - To test data
72 | ```
73 | python subword-nmt/apply_bpe.py --vocabulary vocab.en --vocabulary-threshold 50 -c bpe32k < newstest14.en > newstest14.32k.en
74 | ```
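
The following is a minimal, self-contained sketch of the BPE merge loop mentioned above. It is for illustration only: it is not the subword-nmt implementation, and the toy corpus and the number of merges are made up.
```
import collections
import re

def get_stats(vocab):
    # count how often each adjacent symbol pair occurs, weighted by word frequency
    pairs = collections.defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols) - 1):
            pairs[symbols[i], symbols[i + 1]] += freq
    return pairs

def merge_vocab(pair, vocab):
    # merge the chosen symbol pair into a single symbol wherever it occurs
    bigram = re.escape(' '.join(pair))
    pattern = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    return {pattern.sub(''.join(pair), word): freq for word, freq in vocab.items()}

# toy corpus: each word is a sequence of space-separated symbols with its frequency
vocab = {'l o w': 5, 'l o w e r': 2, 'n e w e s t': 6, 'w i d e s t': 3}
for _ in range(10):  # 10 merge operations here; the real setup above uses 32000
    pairs = get_stats(vocab)
    best = max(pairs, key=pairs.get)
    vocab = merge_vocab(best, vocab)
    print(best)
```
The real subword-nmt tool additionally handles end-of-word markers, vocabulary thresholds and joint source/target learning, which is why the commands above rely on it.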
75 |
76 | 4. Extract Vocabulary
77 |
78 | You still need to prepare the vocabulary with our code, because our vocabulary contains some special symbols.
79 | - download our project:
80 | ```git clone https://github.com/bzhangGo/zero.git```
81 | - Run the code as follows:
82 | ```
83 | python zero/vocab.py train.32k.en vocab.en
84 | python zero/vocab.py train.32k.de vocab.de
85 | ```
86 | The resulting vocabulary size will be roughly 32000.
87 |
88 | 5. Training your model.
89 |
90 | Train your model with the following settings:
91 | ```
92 | data_dir=the preprocessed data directory
93 | python zero/run.py --mode train --parameters=hidden_size=1024,embed_size=512,\
94 | dropout=0.1,label_smooth=0.1,\
95 | max_len=80,batch_size=80,eval_batch_size=240,\
96 | token_size=3000,batch_or_token='token',\
97 | model_name="rnnsearch",scope_name="rnnsearch",buffer_size=3200,\
98 | clip_grad_norm=5.0,\
99 | lrate=5e-4,\
100 | epoches=10,\
101 | update_cycle=1,\
102 | gpus=[3],\
103 | disp_freq=100,\
104 | eval_freq=10000,\
105 | sample_freq=1000,\
106 | checkpoints=5,\
107 | caencoder=True,\
108 | cell='atr',\
109 | max_training_steps=100000000,\
110 | nthreads=8,\
111 | swap_memory=True,\
112 | layer_norm=True,\
113 | max_queue_size=100,\
114 | random_seed=1234,\
115 | src_vocab_file="$data_dir/vocab.en",\
116 | tgt_vocab_file="$data_dir/vocab.de",\
117 | src_train_file="$data_dir/train.32k.en.shuf",\
118 | tgt_train_file="$data_dir/train.32k.de.shuf",\
119 | src_dev_file="$data_dir/dev.32k.en",\
120 | tgt_dev_file="$data_dir/dev.32k.de",\
121 | src_test_file="",\
122 | tgt_test_file="",\
123 | output_dir="train",\
124 | test_output=""
125 | ```
126 | The model will be saved into the directory `train`.
127 |
128 | 6. Testing your model
129 |
130 | - Average your checkpoints, which usually gives better results.
131 | ```
132 | python zero/scripts/checkpoint_averaging.py --checkpoints 5 --output avg --path ../train --gpu 0
133 | ```
134 | - Then test your model with the following command:
135 | ```
136 | data_dir=the preprocessed data directory
137 | python zero/run.py --mode test --parameters=hidden_size=1024,embed_size=512,\
138 | dropout=0.1,label_smooth=0.1,\
139 | max_len=80,batch_size=80,eval_batch_size=240,\
140 | token_size=3000,batch_or_token='token',\
141 | model_name="rnnsearch",scope_name="rnnsearch",buffer_size=3200,\
142 | clip_grad_norm=5.0,\
143 | lrate=5e-4,\
144 | epoches=10,\
145 | update_cycle=1,\
146 | gpus=[3],\
147 | disp_freq=100,\
148 | eval_freq=10000,\
149 | sample_freq=1000,\
150 | checkpoints=5,\
151 | caencoder=True,\
152 | cell='atr',\
153 | max_training_steps=100000000,\
154 | nthreads=8,\
155 | swap_memory=True,\
156 | layer_norm=True,\
157 | max_queue_size=100,\
158 | random_seed=1234,\
159 | src_vocab_file="$data_dir/vocab.en",\
160 | tgt_vocab_file="$data_dir/vocab.de",\
161 | src_train_file="$data_dir/train.32k.en.shuf",\
162 | tgt_train_file="$data_dir/train.32k.de.shuf",\
163 | src_dev_file="$data_dir/dev.32k.en",\
164 | tgt_dev_file="$data_dir/dev.32k.de",\
165 | src_test_file="$data_dir/newstest14.32k.en",\
166 | tgt_test_file="$data_dir/newstest14.de",\
167 | output_dir="avg",\
168 | test_output="newstest14.trans.bpe"
169 | ```
170 | The final translation will be dumped into `newstest14.trans.bpe`.
171 |
172 | You need to remove the BPE separators as follows: `sed -r 's/(@@ )|(@@ ?$)//g' < newstest14.trans.bpe > newstest14.trans.txt`
173 |
174 | Then evaluate the BLEU score using [multi-bleu.perl](https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/generic/multi-bleu.perl):
175 | ```perl multi-bleu.perl $data_dir/newstest14.de < newstest14.trans.txt```
176 |
177 | > Notice that the official evaluation guidelines clearly state that researchers should no longer use multi-bleu.perl, because it
178 | heavily relies on the tokenization scheme. In fact, tokenization can have a strong influence on the
179 | final BLEU score, particularly when an aggressive tokenization mode is used. Nevertheless, at the current stage, multi-bleu.perl is still
180 | the most widely used evaluation script.
181 |
182 | 7. Command line or separate configuration file
183 |
184 | In case you dislike the long command-line style, you can move the parameters into a
185 | separate `config.py`. For the training example above, you can convert the command into the following:
186 | ```
187 | python zero/run.py --mode train --config config.py
188 | ```
189 | where the `config.py` has the following structure:
190 | ```
191 | dict(
192 | hidden_size=1024,
193 | embed_size=512,
194 | dropout=0.1,
195 | label_smooth=0.1,
196 | max_len=80,
197 | batch_size=80,
198 | eval_batch_size=240,
199 | token_size=3000,
200 | batch_or_token='token',
201 | model_name="rnnsearch",
202 | scope_name="rnnsearch",
203 | buffer_size=3200,
204 | clip_grad_norm=5.0,
205 | lrate=5e-4,
206 | epoches=10,
207 | update_cycle=1,
208 | gpus=[3],
209 | disp_freq=100,
210 | eval_freq=10000,
211 | sample_freq=1000,
212 | checkpoints=5,
213 | caencoder=True,
214 | cell='atr',
215 | max_training_steps=100000000,
216 | nthreads=8,
217 | swap_memory=True,
218 | layer_norm=True,
219 | max_queue_size=100,
220 | random_seed=1234,
221 | src_vocab_file="$data_dir/vocab.en",
222 | tgt_vocab_file="$data_dir/vocab.de",
223 | src_train_file="$data_dir/train.32k.en.shuf",
224 | tgt_train_file="$data_dir/train.32k.de.shuf",
225 | src_dev_file="$data_dir/dev.32k.en",
226 | tgt_dev_file="$data_dir/dev.32k.de",
227 | src_test_file="",
228 | tgt_test_file="",
229 | output_dir="train",
230 | test_output="",
231 | )
232 | ```
233 |
234 |
235 | And that's it!
236 |
--------------------------------------------------------------------------------
/evalu.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import time
8 | import numpy as np
9 | import tensorflow as tf
10 |
11 | from utils import queuer, util, metric
12 |
13 |
14 | def decode_target_token(id_seq, vocab):
15 | """Convert sequence ids into tokens"""
16 | valid_id_seq = []
17 | for tok_id in id_seq:
18 | if tok_id == vocab.eos() \
19 | or tok_id == vocab.pad():
20 | break
21 | valid_id_seq.append(tok_id)
22 | return vocab.to_tokens(valid_id_seq)
23 |
24 |
25 | def decode_hypothesis(seqs, scores, params, mask=None):
26 | """Generate decoded sequence from seqs"""
27 | if mask is None:
28 | mask = [1.] * len(seqs)
29 |
30 | hypoes = []
31 | marks = []
32 | for _seqs, _scores, _m in zip(seqs, scores, mask):
33 | if _m < 1.: continue
34 |
35 | for seq, score in zip(_seqs, _scores):
36 | # for now, use only the top-1 (best) hypothesis
37 | best_seq = seq[0]
38 | best_score = score[0]
39 |
40 | hypo = decode_target_token(best_seq, params.tgt_vocab)
41 | mark = best_score
42 |
43 | hypoes.append(hypo)
44 | marks.append(mark)
45 |
46 | return hypoes, marks
47 |
48 |
49 | def decoding(session, features, out_seqs, out_scores, dataset, params):
50 | """Performing decoding with exising information"""
51 | translations = []
52 | scores = []
53 | indices = []
54 |
55 | eval_queue = queuer.EnQueuer(
56 | dataset.batcher(params.eval_batch_size,
57 | buffer_size=params.buffer_size,
58 | shuffle=False,
59 | train=False),
60 | lambda x: x,
61 | worker_processes_num=params.process_num,
62 | input_queue_size=params.input_queue_size,
63 | output_queue_size=params.output_queue_size,
64 | )
65 |
66 | def _predict_one_batch(_data_on_gpu):
67 | feed_dicts = {}
68 |
69 | _step_indices = []
70 | for fidx, shard_data in enumerate(_data_on_gpu):
71 | # define feed_dict
72 | _feed_dict = {
73 | features[fidx]["source"]: shard_data['src'],
74 | }
75 | feed_dicts.update(_feed_dict)
76 |
77 | # collect data indices
78 | _step_indices.extend(shard_data['index'])
79 |
80 | # pick up valid outputs
81 | data_size = len(_data_on_gpu)
82 | valid_out_seqs = out_seqs[:data_size]
83 | valid_out_scores = out_scores[:data_size]
84 |
85 | _decode_seqs, _decode_scores = session.run(
86 | [valid_out_seqs, valid_out_scores], feed_dict=feed_dicts)
87 |
88 | _step_translations, _step_scores = decode_hypothesis(
89 | _decode_seqs, _decode_scores, params
90 | )
91 |
92 | return _step_translations, _step_scores, _step_indices
93 |
94 | very_begin_time = time.time()
95 | data_on_gpu = []
96 | for bidx, data in enumerate(eval_queue):
97 | if bidx == 0:
98 | # remove the data reading time
99 | very_begin_time = time.time()
100 |
101 | data_on_gpu.append(data)
102 | # when using multiple GPUs, keep collecting batches until every GPU has a data shard
103 | if len(params.gpus) > 0 and len(data_on_gpu) < len(params.gpus):
104 | continue
105 |
106 | start_time = time.time()
107 | step_outputs = _predict_one_batch(data_on_gpu)
108 | data_on_gpu = []
109 |
110 | translations.extend(step_outputs[0])
111 | scores.extend(step_outputs[1])
112 | indices.extend(step_outputs[2])
113 |
114 | tf.logging.info(
115 | "Decoding Batch {} using {:.3f} s, translating {} "
116 | "sentences using {:.3f} s in total".format(
117 | bidx, time.time() - start_time,
118 | len(translations), time.time() - very_begin_time
119 | )
120 | )
121 |
122 | if len(data_on_gpu) > 0:
123 |
124 | start_time = time.time()
125 | step_outputs = _predict_one_batch(data_on_gpu)
126 |
127 | translations.extend(step_outputs[0])
128 | scores.extend(step_outputs[1])
129 | indices.extend(step_outputs[2])
130 |
131 | tf.logging.info(
132 | "Decoding Batch {} using {:.3f} s, translating {} "
133 | "sentences using {:.3f} s in total".format(
134 | 'final', time.time() - start_time,
135 | len(translations), time.time() - very_begin_time
136 | )
137 | )
138 |
139 | return translations, scores, indices
140 |
141 |
142 | def scoring(session, features, out_scores, dataset, params):
143 | """Performing decoding with exising information"""
144 | scores = []
145 | indices = []
146 |
147 | eval_queue = queuer.EnQueuer(
148 | dataset.batcher(params.eval_batch_size,
149 | buffer_size=params.buffer_size,
150 | shuffle=False,
151 | train=False),
152 | lambda x: x,
153 | worker_processes_num=params.process_num,
154 | input_queue_size=params.input_queue_size,
155 | output_queue_size=params.output_queue_size,
156 | )
157 |
158 | total_entropy = 0.
159 | total_tokens = 0.
160 |
161 | def _predict_one_batch(_data_on_gpu):
162 | feed_dicts = {}
163 |
164 | _step_indices = []
165 | for fidx, shard_data in enumerate(_data_on_gpu):
166 | # define feed_dict
167 | _feed_dict = {
168 | features[fidx]["source"]: shard_data['src'],
169 | features[fidx]["target"]: shard_data['tgt'],
170 | }
171 | feed_dicts.update(_feed_dict)
172 |
173 | # collect data indices
174 | _step_indices.extend(shard_data['index'])
175 |
176 | # pick up valid outputs
177 | data_size = len(_data_on_gpu)
178 | valid_out_scores = out_scores[:data_size]
179 |
180 | _decode_scores = session.run(
181 | valid_out_scores, feed_dict=feed_dicts)
182 |
183 | _batch_entropy = sum([s * float((d > 0).sum())
184 | for shard_data, shard_scores in zip(_data_on_gpu, _decode_scores)
185 | for d, s in zip(shard_data['tgt'], shard_scores.tolist())])
186 | _batch_tokens = sum([(shard_data['tgt'] > 0).sum() for shard_data in _data_on_gpu])
187 |
188 | _decode_scores = [s for _scores in _decode_scores for s in _scores]
189 |
190 | return _decode_scores, _step_indices, _batch_entropy, _batch_tokens
191 |
192 | very_begin_time = time.time()
193 | data_on_gpu = []
194 | for bidx, data in enumerate(eval_queue):
195 | if bidx == 0:
196 | # remove the data reading time
197 | very_begin_time = time.time()
198 |
199 | data_on_gpu.append(data)
200 | # when using multiple GPUs, keep collecting batches until every GPU has a data shard
201 | if len(params.gpus) > 0 and len(data_on_gpu) < len(params.gpus):
202 | continue
203 |
204 | start_time = time.time()
205 | step_outputs = _predict_one_batch(data_on_gpu)
206 | data_on_gpu = []
207 |
208 | scores.extend(step_outputs[0])
209 | indices.extend(step_outputs[1])
210 |
211 | total_entropy += step_outputs[2]
212 | total_tokens += step_outputs[3]
213 |
214 | tf.logging.info(
215 | "Decoding Batch {} using {:.3f} s, translating {} "
216 | "sentences using {:.3f} s in total".format(
217 | bidx, time.time() - start_time,
218 | len(scores), time.time() - very_begin_time
219 | )
220 | )
221 |
222 | if len(data_on_gpu) > 0:
223 |
224 | start_time = time.time()
225 | step_outputs = _predict_one_batch(data_on_gpu)
226 |
227 | scores.extend(step_outputs[0])
228 | indices.extend(step_outputs[1])
229 |
230 | total_entropy += step_outputs[2]
231 | total_tokens += step_outputs[3]
232 |
233 | tf.logging.info(
234 | "Decoding Batch {} using {:.3f} s, translating {} "
235 | "sentences using {:.3f} s in total".format(
236 | 'final', time.time() - start_time,
237 | len(scores), time.time() - very_begin_time
238 | )
239 | )
240 |
241 | scores = [data[1] for data in
242 | sorted(zip(indices, scores), key=lambda x: x[0])]
243 |
244 | ppl = np.exp(total_entropy / total_tokens)
245 |
246 | return scores, ppl
247 |
248 |
249 | def eval_metric(trans, target_file, indices=None):
250 | """BLEU Evaluate """
251 | target_valid_files = util.fetch_valid_ref_files(target_file)
252 | if target_valid_files is None:
253 | return 0.0
254 |
255 | if indices is not None:
256 | trans = [data[1] for data in sorted(zip(indices, trans), key=lambda x: x[0])]
257 |
258 | references = []
259 | for ref_file in target_valid_files:
260 | cur_refs = tf.gfile.Open(ref_file).readlines()
261 | cur_refs = [line.strip().split() for line in cur_refs]
262 | references.append(cur_refs)
263 |
264 | references = list(zip(*references))
265 |
266 | return metric.bleu(trans, references)
267 |
268 |
269 | def dump_tanslation(tranes, output, indices=None):
270 | """save translation"""
271 | if indices is not None:
272 | tranes = [data[1] for data in
273 | sorted(zip(indices, tranes), key=lambda x: x[0])]
274 | with tf.gfile.Open(output, 'w') as writer:
275 | for hypo in tranes:
276 | if isinstance(hypo, list):
277 | writer.write(' '.join(hypo) + "\n")
278 | else:
279 | writer.write(str(hypo) + "\n")
280 | tf.logging.info("Saving translations into {}".format(output))
281 |
--------------------------------------------------------------------------------
/lrs/__init__.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from lrs import vanillalr, noamlr, scorelr, gnmtplr, epochlr, cosinelr
4 |
5 |
6 | def get_lr(params):
7 |
8 | strategy = params.lrate_strategy.lower()
9 |
10 | if strategy == "noam":
11 | return noamlr.NoamDecayLr(
12 | params.lrate,
13 | params.min_lrate,
14 | params.max_lrate,
15 | params.warmup_steps,
16 | params.hidden_size
17 | )
18 | elif strategy == "gnmt+":
19 | return gnmtplr.GNMTPDecayLr(
20 | params.lrate,
21 | params.min_lrate,
22 | params.max_lrate,
23 | params.warmup_steps,
24 | params.nstable,
25 | params.lrdecay_start,
26 | params.lrdecay_end
27 | )
28 | elif strategy == "epoch":
29 | return epochlr.EpochDecayLr(
30 | params.lrate,
31 | params.min_lrate,
32 | params.max_lrate,
33 | params.lrate_decay,
34 | )
35 | elif strategy == "score":
36 | return scorelr.ScoreDecayLr(
37 | params.lrate,
38 | params.min_lrate,
39 | params.max_lrate,
40 | history_scores=[v[1] for v in params.recorder.valid_script_scores],
41 | decay=params.lrate_decay,
42 | patience=params.lrate_patience,
43 | )
44 | elif strategy == "vanilla":
45 | return vanillalr.VanillaLR(
46 | params.lrate,
47 | params.min_lrate,
48 | params.max_lrate,
49 | )
50 | elif strategy == "cosine":
51 | return cosinelr.CosineDecayLr(
52 | params.lrate,
53 | params.min_lrate,
54 | params.max_lrate,
55 | params.warmup_steps,
56 | params.lrate_decay,
57 | t_mult=params.cosine_factor,
58 | update_period=params.cosine_period
59 | )
60 | else:
61 | raise NotImplementedError(
62 | "{} is not supported".format(strategy))
63 |
--------------------------------------------------------------------------------
/lrs/cosinelr.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import math
8 |
9 | from lrs import lr
10 |
11 |
12 | class CosineDecayLr(lr.Lr):
13 | """Decay the learning rate during each training step, follows FairSeq"""
14 | def __init__(self,
15 | init_lr, # initial learning rate => warmup_init_lr
16 | min_lr, # minimum learning rate
17 | max_lr, # maximum learning rate
18 | warmup_steps, # warmup step => warmup_updates
19 | decay, # learning rate shrink factor for annealing
20 | t_mult=1, # factor to grow the length of each period
21 | update_period=5000, # initial number of updates per period
22 | name="cosine_decay_lr" # model name, no use
23 | ):
24 | super(CosineDecayLr, self).__init__(init_lr, min_lr, max_lr, name=name)
25 |
26 | self.warmup_steps = warmup_steps
27 |
28 | self.warmup_init_lr = init_lr
29 | self.warmup_end_lr = max_lr
30 | self.t_mult = t_mult
31 | self.period = update_period
32 |
33 | if self.warmup_steps > 0:
34 | self.lr_step = (self.warmup_end_lr - self.warmup_init_lr) / self.warmup_steps
35 | else:
36 | self.lr_step = 1.
37 |
38 | self.decay = decay
39 |
40 | # initial learning rate
41 | self.lrate = init_lr
42 |
43 | def step(self, step):
44 | if step < self.warmup_steps:
45 | self.lrate = self.warmup_init_lr + step * self.lr_step
46 | else:
47 | curr_updates = step - self.warmup_steps
48 | if self.t_mult != 1:
49 | i = math.floor(math.log(1 - curr_updates / self.period * (1 - self.t_mult), self.t_mult))
50 | t_i = self.t_mult ** i * self.period
51 | t_curr = curr_updates - (1 - self.t_mult ** i) / (1 - self.t_mult) * self.period
52 | else:
53 | i = math.floor(curr_updates / self.period)
54 | t_i = self.period
55 | t_curr = curr_updates - (self.period * i)
56 |
57 | lr_shrink = self.decay ** i
58 | min_lr = self.min_lrate * lr_shrink
59 | max_lr = self.max_lrate * lr_shrink
60 |
61 | self.lrate = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / t_i))
62 |
63 | return self.lrate
64 |
--------------------------------------------------------------------------------
/lrs/epochlr.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 |
8 | from lrs import lr
9 |
10 |
11 | class EpochDecayLr(lr.Lr):
12 | """Decay the learning rate after each epoch"""
13 | def __init__(self,
14 | init_lr,
15 | min_lr, # minimum learning rate
16 | max_lr, # maximum learning rate
17 | decay=0.5, # learning rate decay rate
18 | name="epoch_decay_lr"
19 | ):
20 | super(EpochDecayLr, self).__init__(init_lr, min_lr, max_lr, name=name)
21 |
22 | self.decay = decay
23 |
24 | def after_epoch(self, eidx=None):
25 | if eidx is None:
26 | self.lrate = self.init_lrate * self.decay
27 | else:
28 | self.lrate = self.init_lrate * self.decay ** int(eidx)
29 |
--------------------------------------------------------------------------------
/lrs/gnmtplr.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import numpy as np
8 |
9 | from lrs import lr
10 |
11 |
12 | class GNMTPDecayLr(lr.Lr):
13 | """Decay the learning rate during each training step, follows GNMT+"""
14 | def __init__(self,
15 | init_lr, # initial learning rate
16 | min_lr, # minimum learning rate
17 | max_lr, # maximum learning rate
18 | warmup_steps, # warmup step
19 | nstable, # number of replica
20 | lrdecay_start, # start of learning rate decay
21 | lrdecay_end, # end of learning rate decay
22 | name="gnmtp_decay_lr" # model name, no use
23 | ):
24 | super(GNMTPDecayLr, self).__init__(init_lr, min_lr, max_lr, name=name)
25 |
26 | self.warmup_steps = warmup_steps
27 | self.nstable = nstable
28 | self.lrdecay_start = lrdecay_start
29 | self.lrdecay_end = lrdecay_end
30 |
31 | if nstable < 1:
32 | raise Exception("Stabled Lrate Value should "
33 | "greater than 0, but is {}".format(nstable))
34 |
35 | def step(self, step):
36 | t = float(step)
37 | p = float(self.warmup_steps)
38 | n = float(self.nstable)
39 | s = float(self.lrdecay_start)
40 | e = float(self.lrdecay_end)
41 |
42 | decay = np.minimum(1. + t * (n - 1) / (n * p), n)
43 | decay = np.minimum(decay, n * (2 * n) ** ((s - n * t) / (e - s)))
44 |
45 | self.lrate = self.init_lrate * decay
46 |
--------------------------------------------------------------------------------
/lrs/lr.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 |
8 | # This is an abstract class that deals with
9 | # different learning rate decay strategies.
10 | # Generally, the learning rate could be decayed inside the GPU computation graph.
11 | # Here, however, we simply decay the learning rate
12 | # on the CPU and feed the decayed lr into the GPU for
13 | # optimization
14 | class Lr(object):
15 | def __init__(self,
16 | init_lrate, # initial learning rate
17 | min_lrate, # minimum learning rate
18 | max_lrate, # maximum learning rate
19 | name="lr", # learning rate name, no use
20 | ):
21 | self.name = name
22 | self.init_lrate = init_lrate # just record the init learning rate
23 | self.lrate = init_lrate # active learning rate, change with training
24 | self.min_lrate = min_lrate
25 | self.max_lrate = max_lrate
26 |
27 | assert self.max_lrate > self.min_lrate, "Minimum learning rate " \
28 | "should be less than maximum learning rate"
29 |
30 | # suppose the eidx starts from 1
31 | def before_epoch(self, eidx=None):
32 | pass
33 |
34 | def after_epoch(self, eidx=None):
35 | pass
36 |
37 | def step(self, step):
38 | pass
39 |
40 | def after_eval(self, eval_score):
41 | pass
42 |
43 | def get_lr(self):
44 | """Return the learning rate whenever you want"""
45 | return max(min(self.lrate, self.max_lrate), self.min_lrate)
46 |
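# A minimal usage sketch (hypothetical; the actual training loop lives elsewhere in
# this project and may differ):
#
#   lrate = get_lr(params)              # factory in lrs/__init__.py
#   for step in range(max_steps):
#       lrate.step(step)                # decay the rate on the CPU
#       lr_value = lrate.get_lr()       # value clipped to [min_lrate, max_lrate]
#       # feed lr_value into the graph (e.g. via a placeholder) and run the train op
#   lrate.after_eval(valid_score)       # let score-based schedules react to validation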
--------------------------------------------------------------------------------
/lrs/noamlr.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import numpy as np
8 |
9 | from lrs import lr
10 |
11 |
12 | class NoamDecayLr(lr.Lr):
13 | """Decay the learning rate during each training step, follows Transformer"""
14 | def __init__(self,
15 | init_lr, # initial learning rate
16 | min_lr, # minimum learning rate
17 | max_lr, # maximum learning rate
18 | warmup_steps, # warmup step
19 | hidden_size, # model hidden size
20 | name="noam_decay_lr" # model name, no use
21 | ):
22 | super(NoamDecayLr, self).__init__(init_lr, min_lr, max_lr, name=name)
23 |
24 | self.warmup_steps = warmup_steps
25 | self.hidden_size = hidden_size
26 |
27 | def step(self, step):
28 | step = float(step)
29 | warmup_steps = float(self.warmup_steps)
30 |
31 | multiplier = float(self.hidden_size) ** -0.5
32 | decay = multiplier * np.minimum((step + 1) * (warmup_steps ** -1.5),
33 | (step + 1) ** -0.5)
34 | self.lrate = self.init_lrate * decay
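
# In closed form (descriptive comment, matching the code above):
#   lrate = init_lrate * hidden_size**-0.5
#           * min((step + 1) * warmup_steps**-1.5, (step + 1)**-0.5)
# i.e. a linear warm-up for `warmup_steps` updates followed by inverse-square-root
# decay, the schedule used for the original Transformer.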
35 |
--------------------------------------------------------------------------------
/lrs/scorelr.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 |
8 | from lrs import lr
9 |
10 |
11 | class ScoreDecayLr(lr.Lr):
12 | """Decay the learning rate after each evaluation"""
13 | def __init__(self,
14 | init_lr,
15 | min_lr, # minimum learning rate
16 | max_lr, # maximum learning rate
17 | history_scores=None, # evaluation history metric scores, such as BLEU
18 | decay=0.5, # learning rate decay rate
19 | patience=1, # decay after this number of bad counter
20 | name="score_decay_lr" # model name, no use
21 | ):
22 | super(ScoreDecayLr, self).__init__(init_lr, min_lr, max_lr, name=name)
23 |
24 | self.decay = decay
25 | self.patience = patience
26 | self.bad_counter = 0
27 | self.best_score = -1e9
28 |
29 | if history_scores is not None:
30 | for score in history_scores:
31 | self.after_eval(score)
32 |
33 | def after_eval(self, eval_score):
34 | if eval_score > self.best_score:
35 | self.best_score = eval_score
36 | self.bad_counter = 0
37 | else:
38 | self.bad_counter += 1
39 | if self.bad_counter >= self.patience:
40 | self.lrate = self.lrate * self.decay
41 |
42 | self.bad_counter = 0
43 |
--------------------------------------------------------------------------------
/lrs/vanillalr.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 |
8 | from lrs import lr
9 |
10 |
11 | class VanillaLR(lr.Lr):
12 | """Very basic learning rate, constant learning rate"""
13 | def __init__(self,
14 | init_lr, # learning rate
15 | min_lr, # minimum learning rate
16 | max_lr, # maximum learning rate
17 | name="vanilla_lr"
18 | ):
19 | super(VanillaLR, self).__init__(init_lr, min_lr, max_lr, name=name)
20 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import tensorflow as tf
8 | from collections import namedtuple
9 |
10 | # global models defined in Zero
11 | _total_models = {}
12 |
13 |
14 | class ModelWrapper(namedtuple("ModelTupleWrapper",
15 | ("train_fn", "score_fn", "infer_fn"))):
16 | pass
17 |
18 |
19 | # you need to register your model yourself
20 | def model_register(model_name, train_fn, score_fn, infer_fn):
21 | model_name = model_name.lower()
22 |
23 | if model_name in _total_models:
24 | raise Exception("Conflict Model Name: {}".format(model_name))
25 |
26 | tf.logging.info("Registering model: {}".format(model_name))
27 |
28 | _total_models[model_name] = ModelWrapper(
29 | train_fn=train_fn,
30 | score_fn=score_fn,
31 | infer_fn=infer_fn,
32 | )
33 |
34 |
35 | def get_model(model_name):
36 | model_name = model_name.lower()
37 |
38 | if model_name in _total_models:
39 | return _total_models[model_name]
40 |
41 | raise Exception("No supported model {}".format(model_name))
42 |
--------------------------------------------------------------------------------
/models/rnnsearch.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import copy
8 | import tensorflow as tf
9 |
10 | from func import linear
11 | from models import model
12 | from utils import util, dtype
13 | from rnns import rnn
14 |
15 |
16 | def encoder(source, params):
17 | mask = dtype.tf_to_float(tf.cast(source, tf.bool))
18 | hidden_size = params.hidden_size
19 |
20 | source, mask = util.remove_invalid_seq(source, mask)
21 |
22 | embed_name = "embedding" if params.shared_source_target_embedding \
23 | else "src_embedding"
24 | src_emb = tf.get_variable(embed_name,
25 | [params.src_vocab.size(), params.embed_size])
26 | src_bias = tf.get_variable("bias", [params.embed_size])
27 |
28 | inputs = tf.gather(src_emb, source)
29 | inputs = tf.nn.bias_add(inputs, src_bias)
30 |
31 | inputs = util.valid_apply_dropout(inputs, params.dropout)
32 |
33 | with tf.variable_scope("encoder"):
34 | # forward rnn
35 | with tf.variable_scope('forward'):
36 | outputs = rnn.rnn(params.cell, inputs, hidden_size, mask=mask,
37 | ln=params.layer_norm, sm=params.swap_memory)
38 | output_fw, state_fw = outputs[1]
39 | # backward rnn
40 | with tf.variable_scope('backward'):
41 | if not params.caencoder:
42 | outputs = rnn.rnn(params.cell, tf.reverse(inputs, [1]),
43 | hidden_size, mask=tf.reverse(mask, [1]),
44 | ln=params.layer_norm, sm=params.swap_memory)
45 | output_bw, state_bw = outputs[1]
46 | else:
47 | outputs = rnn.cond_rnn(params.cell, tf.reverse(inputs, [1]),
48 | tf.reverse(output_fw, [1]), hidden_size,
49 | mask=tf.reverse(mask, [1]),
50 | ln=params.layer_norm,
51 | sm=params.swap_memory,
52 | num_heads=params.num_heads,
53 | one2one=True)
54 | output_bw, state_bw = outputs[1]
55 |
56 | output_bw = tf.reverse(output_bw, [1])
57 |
58 | if not params.caencoder:
59 | source_encodes = tf.concat([output_fw, output_bw], -1)
60 | source_feature = tf.concat([state_fw, state_bw], -1)
61 | else:
62 | source_encodes = output_bw
63 | source_feature = state_bw
64 |
65 | with tf.variable_scope("decoder_initializer"):
66 | decoder_init = rnn.get_cell(
67 | params.cell, hidden_size, ln=params.layer_norm
68 | ).get_init_state(x=source_feature)
69 | decoder_init = tf.tanh(decoder_init)
70 |
71 | return {
72 | "encodes": source_encodes,
73 | "decoder_initializer": decoder_init,
74 | "mask": mask
75 | }
76 |
77 |
78 | def decoder(target, state, params):
79 | mask = dtype.tf_to_float(tf.cast(target, tf.bool))
80 | hidden_size = params.hidden_size
81 |
82 | is_training = ('decoder' not in state)
83 |
84 | if is_training:
85 | target, mask = util.remove_invalid_seq(target, mask)
86 |
87 | embed_name = "embedding" if params.shared_source_target_embedding \
88 | else "tgt_embedding"
89 | tgt_emb = tf.get_variable(embed_name,
90 | [params.tgt_vocab.size(), params.embed_size])
91 | tgt_bias = tf.get_variable("bias", [params.embed_size])
92 |
93 | inputs = tf.gather(tgt_emb, target)
94 | inputs = tf.nn.bias_add(inputs, tgt_bias)
95 |
96 | # shift
97 | if is_training:
98 | inputs = tf.pad(inputs, [[0, 0], [1, 0], [0, 0]])
99 | inputs = inputs[:, :-1, :]
100 | else:
101 | inputs = tf.cond(tf.reduce_all(tf.equal(target, params.tgt_vocab.pad())),
102 | lambda: tf.zeros_like(inputs),
103 | lambda: inputs)
104 | mask = tf.ones_like(mask)
105 |
106 | inputs = util.valid_apply_dropout(inputs, params.dropout)
107 |
108 | with tf.variable_scope("decoder"):
109 | init_state = state["decoder_initializer"]
110 | if not is_training:
111 | init_state = state["decoder"]["state"]
112 | returns = rnn.cond_rnn(params.cell, inputs, state["encodes"], hidden_size,
113 | init_state=init_state, mask=mask,
114 | mem_mask=state["mask"], ln=params.layer_norm,
115 | sm=params.swap_memory, one2one=False)
116 | (_, hidden_state), (outputs, _), contexts, attentions = returns
117 |
118 | feature = linear([outputs, contexts, inputs], params.embed_size,
119 | ln=params.layer_norm, scope="pre_logits")
120 | if 'dev_decode' in state:
121 | feature = feature[:, -1, :]
122 |
123 | feature = tf.tanh(feature)
124 | feature = util.valid_apply_dropout(feature, params.dropout)
125 |
126 | embed_name = "tgt_embedding" if params.shared_target_softmax_embedding \
127 | else "softmax_embedding"
128 | embed_name = "embedding" if params.shared_source_target_embedding \
129 | else embed_name
130 | softmax_emb = tf.get_variable(embed_name,
131 | [params.tgt_vocab.size(), params.embed_size])
132 | feature = tf.reshape(feature, [-1, params.embed_size])
133 | logits = tf.matmul(feature, softmax_emb, False, True)
134 |
135 | logits = tf.cast(logits, tf.float32)
136 |
137 | soft_label, normalizer = util.label_smooth(
138 | target,
139 | util.shape_list(logits)[-1],
140 | factor=params.label_smooth)
141 | centropy = tf.nn.softmax_cross_entropy_with_logits_v2(
142 | logits=logits,
143 | labels=soft_label
144 | )
145 | centropy -= normalizer
146 | centropy = tf.reshape(centropy, tf.shape(target))
147 |
148 | mask = tf.cast(mask, tf.float32)
149 | per_sample_loss = tf.reduce_sum(centropy * mask, -1) / tf.reduce_sum(mask, -1)
150 | loss = tf.reduce_mean(per_sample_loss)
151 |
152 | # these mask tricks mainly used to deal with zero shapes, such as [0, 1]
153 | loss = tf.cond(tf.equal(tf.shape(target)[0], 0),
154 | lambda: tf.constant(0, dtype=tf.float32),
155 | lambda: loss)
156 |
157 | if not is_training:
158 | state['decoder']['state'] = hidden_state
159 |
160 | return loss, logits, state, per_sample_loss
161 |
162 |
163 | def train_fn(features, params, initializer=None):
164 | with tf.variable_scope(params.scope_name or "model",
165 | initializer=initializer,
166 | reuse=tf.AUTO_REUSE,
167 | dtype=tf.as_dtype(dtype.floatx()),
168 | custom_getter=dtype.float32_variable_storage_getter):
169 | state = encoder(features['source'], params)
170 | loss, logits, state, _ = decoder(features['target'], state, params)
171 |
172 | return {
173 | "loss": loss
174 | }
175 |
176 |
177 | def score_fn(features, params, initializer=None):
178 | params = copy.copy(params)
179 | params = util.closing_dropout(params)
180 | params.label_smooth = 0.0
181 | with tf.variable_scope(params.scope_name or "model",
182 | initializer=initializer,
183 | reuse=tf.AUTO_REUSE,
184 | dtype=tf.as_dtype(dtype.floatx()),
185 | custom_getter=dtype.float32_variable_storage_getter):
186 | state = encoder(features['source'], params)
187 | _, _, _, scores = decoder(features['target'], state, params)
188 |
189 | return {
190 | "score": scores
191 | }
192 |
193 |
194 | def infer_fn(params):
195 | params = copy.copy(params)
196 | params = util.closing_dropout(params)
197 |
198 | def encoding_fn(source):
199 | with tf.variable_scope(params.scope_name or "model",
200 | reuse=tf.AUTO_REUSE,
201 | dtype=tf.as_dtype(dtype.floatx()),
202 | custom_getter=dtype.float32_variable_storage_getter):
203 | state = encoder(source, params)
204 | state["decoder"] = {
205 | "state": state["decoder_initializer"]
206 | }
207 | return state
208 |
209 | def decoding_fn(target, state, time):
210 | with tf.variable_scope(params.scope_name or "model",
211 | reuse=tf.AUTO_REUSE,
212 | dtype=tf.as_dtype(dtype.floatx()),
213 | custom_getter=dtype.float32_variable_storage_getter):
214 | if params.search_mode == "cache":
215 | step_loss, step_logits, step_state, _ = decoder(
216 | target, state, params)
217 | else:
218 | estate = encoder(state, params)
219 | estate['dev_decode'] = True
220 | _, step_logits, _, _ = decoder(target, estate, params)
221 | step_state = state
222 |
223 | return step_logits, step_state
224 |
225 | return encoding_fn, decoding_fn
226 |
227 |
228 | # register the model, with a unique name
229 | model.model_register("rnnsearch", train_fn, score_fn, infer_fn)
230 |
--------------------------------------------------------------------------------
/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
--------------------------------------------------------------------------------
/modules/fixup.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import math
8 | import tensorflow as tf
9 |
10 | import func
11 | from utils import util, dtype
12 | from modules import rpr, initializer
13 |
14 |
15 | def shift_layer(x, scope="shift"):
16 | with tf.variable_scope(scope or "shift"):
17 | offset = tf.get_variable("offset", [1], initializer=tf.zeros_initializer())
18 | return x - offset
19 |
20 |
21 | def scale_layer(x, init=1., scope="scale"):
22 | with tf.variable_scope(scope or "scale"):
23 | scale = tf.get_variable(
24 | "scale", [1],
25 | initializer=initializer.scale_initializer(init, tf.ones_initializer()))
26 | return x * scale
27 |
28 |
29 | def ffn_layer(x, d, d_o, dropout=None, scope=None, numblocks=None):
30 | """
31 | FFN layer in Transformer
32 | :param numblocks: size of 'L' in fixup paper
33 | :param scope:
34 | """
35 | with tf.variable_scope(scope or "ffn_layer",
36 | dtype=tf.as_dtype(dtype.floatx())) as scope:
37 | assert numblocks is not None, 'Fixup requires the total model depth L'
38 |
39 | in_initializer = initializer.scale_initializer(
40 | math.pow(numblocks, -1. / 2.), scope.initializer)
41 |
42 | x = shift_layer(x)
43 | hidden = func.linear(x, d, scope="enlarge",
44 | weight_initializer=in_initializer, bias=False)
45 | hidden = shift_layer(hidden)
46 | hidden = tf.nn.relu(hidden)
47 |
48 | hidden = util.valid_apply_dropout(hidden, dropout)
49 |
50 | hidden = shift_layer(hidden)
51 | output = func.linear(hidden, d_o, scope="output", bias=False,
52 | weight_initializer=tf.zeros_initializer())
53 | output = scale_layer(output)
54 |
55 | return output
56 |
57 |
58 | def dot_attention(query, memory, mem_mask, hidden_size,
59 | ln=False, num_heads=1, cache=None, dropout=None,
60 | use_relative_pos=False, max_relative_position=16,
61 | out_map=True, scope=None, fuse_mask=None,
62 | decode_step=None, numblocks=None):
63 | """
64 | dotted attention model
65 | :param query: [batch_size, qey_len, dim]
66 | :param memory: [batch_size, seq_len, mem_dim] or None
67 | :param mem_mask: [batch_size, seq_len]
68 | :param hidden_size: attention space dimension
69 | :param ln: whether use layer normalization
70 | :param num_heads: attention head number
71 | :param dropout: attention dropout, default disable
72 | :param out_map: output additional mapping
73 | :param cache: cache-based decoding
74 | :param fuse_mask: aan mask during training, and timestep for testing
75 | :param max_relative_position: maximum position considered for relative embedding
76 | :param use_relative_pos: whether use relative position information
77 | :param decode_step: the time step of current decoding, 0-based
78 | :param numblocks: size of 'L' in fixup paper
79 | :param scope:
80 | :return: a value matrix, [batch_size, qey_len, mem_dim]
81 | """
82 | with tf.variable_scope(scope or "dot_attention", reuse=tf.AUTO_REUSE,
83 | dtype=tf.as_dtype(dtype.floatx())) as scope:
84 | if fuse_mask:
85 | assert memory is not None, 'Fuse mechanism is only applied with cross-attention'
86 | if cache and use_relative_pos:
87 | assert decode_step is not None, 'decode_step must be provided when using relative position encoding'
88 |
89 | assert numblocks is not None, 'Fixup requires the total model depth L'
90 |
91 | scale_base = 6. if fuse_mask is None else 8.
92 | in_initializer = initializer.scale_initializer(
93 | math.pow(numblocks, -1. / scale_base), scope.initializer)
94 |
95 | if memory is None:
96 | # suppose self-attention from queries alone
97 | h = func.linear(query, hidden_size * 3, ln=ln, scope="qkv_map",
98 | weight_initializer=in_initializer, bias=False)
99 | q, k, v = tf.split(h, 3, -1)
100 |
101 | if cache is not None:
102 | k = tf.concat([cache['k'], k], axis=1)
103 | v = tf.concat([cache['v'], v], axis=1)
104 | cache = {
105 | 'k': k,
106 | 'v': v,
107 | }
108 | else:
109 | q = func.linear(query, hidden_size, ln=ln, scope="q_map",
110 | weight_initializer=in_initializer, bias=False)
111 | if cache is not None and ('mk' in cache and 'mv' in cache):
112 | k, v = cache['mk'], cache['mv']
113 | else:
114 | k = func.linear(memory, hidden_size, ln=ln, scope="k_map",
115 | weight_initializer=in_initializer, bias=False)
116 | v = func.linear(memory, hidden_size, ln=ln, scope="v_map",
117 | weight_initializer=in_initializer, bias=False)
118 |
119 | if cache is not None:
120 | cache['mk'] = k
121 | cache['mv'] = v
122 |
123 | q = func.split_heads(q, num_heads)
124 | k = func.split_heads(k, num_heads)
125 | v = func.split_heads(v, num_heads)
126 |
127 | q *= (hidden_size // num_heads) ** (-0.5)
128 |
129 | q_shp = util.shape_list(q)
130 | k_shp = util.shape_list(k)
131 | v_shp = util.shape_list(v)
132 |
133 | q_len = q_shp[2] if decode_step is None else decode_step + 1
134 | r_lst = None if decode_step is None else 1
135 |
136 | # q * k => attention weights
137 | if use_relative_pos:
138 | r = rpr.get_relative_positions_embeddings(
139 | q_len, k_shp[2], k_shp[3],
140 | max_relative_position, name="rpr_keys", last=r_lst)
141 | logits = rpr.relative_attention_inner(q, k, r, transpose=True)
142 | else:
143 | logits = tf.matmul(q, k, transpose_b=True)
144 |
145 | if mem_mask is not None:
146 | logits += mem_mask
147 |
148 | weights = tf.nn.softmax(logits)
149 |
150 | dweights = util.valid_apply_dropout(weights, dropout)
151 |
152 | # weights * v => attention vectors
153 | if use_relative_pos:
154 | r = rpr.get_relative_positions_embeddings(
155 | q_len, k_shp[2], v_shp[3],
156 | max_relative_position, name="rpr_values", last=r_lst)
157 | o = rpr.relative_attention_inner(dweights, v, r, transpose=False)
158 | else:
159 | o = tf.matmul(dweights, v)
160 |
161 | o = func.combine_heads(o)
162 |
163 | if fuse_mask is not None:
164 | # This is for AAN, the important part is sharing v_map
165 | v_q = func.linear(query, hidden_size, ln=ln, scope="v_map",
166 | weight_initializer=in_initializer, bias=False)
167 |
168 | if cache is not None and 'aan' in cache:
169 | aan_o = (v_q + cache['aan']) / dtype.tf_to_float(fuse_mask + 1)
170 | else:
171 | # Simplified Average Attention Network
172 | aan_o = tf.matmul(fuse_mask, v_q)
173 |
174 | if cache is not None:
175 | if 'aan' not in cache:
176 | cache['aan'] = v_q
177 | else:
178 | cache['aan'] = v_q + cache['aan']
179 |
180 | # Directly sum both self-attention and cross attention
181 | o = o + aan_o
182 |
183 | if out_map:
184 | o = func.linear(o, hidden_size, ln=ln, scope="o_map",
185 | weight_initializer=tf.zeros_initializer(), bias=False)
186 |
187 | results = {
188 | 'weights': weights,
189 | 'output': o,
190 | 'cache': cache
191 | }
192 |
193 | return results
194 |
--------------------------------------------------------------------------------
/modules/initializer.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import tensorflow as tf
8 | from utils import dtype
9 |
10 |
11 | def get_initializer(initializer, initializer_gain):
12 | tfdtype = tf.as_dtype(dtype.floatx())
13 |
14 | if initializer == "uniform":
15 | max_val = initializer_gain
16 | return tf.random_uniform_initializer(-max_val, max_val, dtype=tfdtype)
17 | elif initializer == "normal":
18 | return tf.random_normal_initializer(0.0, initializer_gain, dtype=tfdtype)
19 | elif initializer == "normal_unit_scaling":
20 | return tf.variance_scaling_initializer(initializer_gain,
21 | mode="fan_avg",
22 | distribution="normal",
23 | dtype=tfdtype)
24 | elif initializer == "uniform_unit_scaling":
25 | return tf.variance_scaling_initializer(initializer_gain,
26 | mode="fan_avg",
27 | distribution="uniform",
28 | dtype=tfdtype)
29 | else:
30 | tf.logging.warn("Unrecognized initializer: %s" % initializer)
31 | tf.logging.warn("Return to default initializer: glorot_uniform_initializer")
32 | return tf.glorot_uniform_initializer(dtype=tfdtype)
33 |
34 |
35 | def scale_initializer(scale, initializer):
36 | """Rescale the value given by initializer"""
37 | tfdtype = tf.as_dtype(dtype.floatx())
38 |
39 | def _initializer(shape, dtype=tfdtype, partition_info=None):
40 | value = initializer(shape, dtype=dtype, partition_info=partition_info)
41 | value *= scale
42 |
43 | return value
44 |
45 | return _initializer
46 |
--------------------------------------------------------------------------------
/modules/l0norm.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Defines common utilities for l0-regularization layers."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 | # Small constant value to add when taking logs or sqrts to avoid NaNs
24 | EPSILON = 1e-8
25 |
26 | # The default hard-concrete distribution parameters
27 | BETA = 2.0 / 3.0
28 | GAMMA = -0.1
29 | ZETA = 1.1
30 |
31 |
32 | def hard_concrete_sample(
33 | log_alpha,
34 | beta=BETA,
35 | gamma=GAMMA,
36 | zeta=ZETA,
37 | eps=EPSILON):
38 | """Sample values from the hard concrete distribution.
39 |
40 | The hard concrete distribution is described in
41 | https://arxiv.org/abs/1712.01312.
42 |
43 | Args:
44 | log_alpha: The log alpha parameters that control the "location" of the
45 | distribution.
46 | beta: The beta parameter, which controls the "temperature" of
47 | the distribution. Defaults to 2/3 from the above paper.
48 | gamma: The gamma parameter, which controls the lower bound of the
49 | stretched distribution. Defaults to -0.1 from the above paper.
50 | zeta: The zeta parameters, which controls the upper bound of the
51 | stretched distribution. Defaults to 1.1 from the above paper.
52 | eps: A small constant value to add to logs and sqrts to avoid NaNs.
53 |
54 | Returns:
55 | A tf.Tensor representing the output of the sampling operation.
56 | """
57 | random_noise = tf.random_uniform(
58 | tf.shape(log_alpha),
59 | minval=0.0,
60 | maxval=1.0)
61 |
62 | # NOTE: We add a small constant value to the noise before taking the
63 | # log to avoid NaNs if a noise value is exactly zero. We sample values
64 | # in the range [0, 1), so the right log is not at risk of NaNs.
65 | gate_inputs = tf.log(random_noise + eps) - tf.log(1.0 - random_noise)
66 | gate_inputs = tf.sigmoid((gate_inputs + log_alpha) / beta)
67 | stretched_values = gate_inputs * (zeta - gamma) + gamma
68 |
69 | return tf.clip_by_value(
70 | stretched_values,
71 | clip_value_max=1.0,
72 | clip_value_min=0.0)
73 |
74 |
75 | def hard_concrete_mean(log_alpha, gamma=GAMMA, zeta=ZETA):
76 | """Calculate the mean of the hard concrete distribution.
77 |
78 | The hard concrete distribution is described in
79 | https://arxiv.org/abs/1712.01312.
80 |
81 | Args:
82 | log_alpha: The log alpha parameters that control the "location" of the
83 | distribution.
84 | gamma: The gamma parameter, which controls the lower bound of the
85 | stretched distribution. Defaults to -0.1 from the above paper.
86 | zeta: The zeta parameters, which controls the upper bound of the
87 | stretched distribution. Defaults to 1.1 from the above paper.
88 |
89 | Returns:
90 | A tf.Tensor representing the calculated means.
91 | """
92 | stretched_values = tf.sigmoid(log_alpha) * (zeta - gamma) + gamma
93 | return tf.clip_by_value(
94 | stretched_values,
95 | clip_value_max=1.0,
96 | clip_value_min=0.0)
97 |
98 |
99 | def l0_norm(
100 | log_alpha,
101 | beta=BETA,
102 | gamma=GAMMA,
103 | zeta=ZETA):
104 | """Calculate the l0-regularization contribution to the loss.
105 | Args:
106 | log_alpha: Tensor of the log alpha parameters for the hard concrete
107 | distribution.
108 | beta: The beta parameter, which controls the "temperature" of
109 | the distribution. Defaults to 2/3 from the above paper.
110 | gamma: The gamma parameter, which controls the lower bound of the
111 | stretched distribution. Defaults to -0.1 from the above paper.
112 | zeta: The zeta parameters, which controls the upper bound of the
113 | stretched distribution. Defaults to 1.1 from the above paper.
114 | Returns:
115 | Scalar tensor containing the unweighted l0-regularization term contribution
116 | to the loss.
117 | """
118 | # Value of the CDF of the hard-concrete distribution evaluated at 0
119 | reg_per_weight = tf.sigmoid(log_alpha - beta * tf.log(-gamma / zeta))
120 | return reg_per_weight
121 |
122 |
123 | def var_train(
124 | weight_parameters,
125 | beta=BETA,
126 | gamma=GAMMA,
127 | zeta=ZETA,
128 | eps=EPSILON):
129 | """Model training, sampling hard concrete variables"""
130 | theta, log_alpha = weight_parameters
131 |
132 | # Sample the z values from the hard-concrete distribution
133 | weight_noise = hard_concrete_sample(
134 | log_alpha,
135 | beta,
136 | gamma,
137 | zeta,
138 | eps)
139 | weights = theta * weight_noise
140 |
141 | return weights, weight_noise
142 |
143 |
144 | def l0_regularization_loss(l0_norm_loss,
145 | reg_scalar=1.0,
146 | start_reg_ramp_up=0,
147 | end_reg_ramp_up=1000,
148 | warm_up=True):
149 | """Calculate the l0-norm weight for this iteration"""
150 | step = tf.train.get_or_create_global_step()
151 | current_step_reg = tf.maximum(
152 | 0.0,
153 | tf.cast(step - start_reg_ramp_up, tf.float32))
154 |
155 | fraction_ramp_up_completed = tf.minimum(
156 | current_step_reg / (end_reg_ramp_up - start_reg_ramp_up), 1.0)
157 |
158 | if warm_up:
159 | # regularizer intensifies over the course of ramp-up
160 | reg_scalar = fraction_ramp_up_completed * reg_scalar
161 |
162 | l0_norm_loss = reg_scalar * l0_norm_loss
163 | return l0_norm_loss
164 |
165 |
166 | def var_eval(
167 | weight_parameters,
168 | gamma=GAMMA,
169 | zeta=ZETA):
170 | """Model evaluation, obtain mean value"""
171 | theta, log_alpha = weight_parameters
172 |
173 | # Use the mean of the learned hard-concrete distribution as the
174 | # deterministic weight noise at evaluation time
175 | weight_noise = hard_concrete_mean(log_alpha, gamma, zeta)
176 | weights = theta * weight_noise
177 | return weights, weight_noise
178 |
--------------------------------------------------------------------------------
/modules/rela.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import tensorflow as tf
8 |
9 | import func
10 | from utils import util, dtype
11 |
12 |
13 | def dot_attention(query, memory, mem_mask, hidden_size,
14 | ln=False, num_heads=1, cache=None, dropout=None,
15 | out_map=True, scope=None):
16 | """
17 | dotted attention model
18 | :param query: [batch_size, qey_len, dim]
19 | :param memory: [batch_size, seq_len, mem_dim] or None
20 | :param mem_mask: [batch_size, seq_len]
21 | :param hidden_size: attention space dimension
22 | :param ln: whether use layer normalization
23 | :param num_heads: attention head number
24 | :param dropout: attention dropout, default disable
25 | :param out_map: output additional mapping
26 | :param cache: cache-based decoding
27 | :param scope:
28 | :return: a value matrix, [batch_size, qey_len, mem_dim]
29 | """
30 | with tf.variable_scope(scope or "dot_attention", reuse=tf.AUTO_REUSE,
31 | dtype=tf.as_dtype(dtype.floatx())):
32 | if memory is None:
33 | # suppose self-attention from queries alone
34 | h = func.linear(query, hidden_size * 3, ln=ln, scope="qkv_map")
35 | q, k, v = tf.split(h, 3, -1)
36 |
37 | if cache is not None:
38 | k = tf.concat([cache['k'], k], axis=1)
39 | v = tf.concat([cache['v'], v], axis=1)
40 | cache = {
41 | 'k': k,
42 | 'v': v,
43 | }
44 | else:
45 | q = func.linear(query, hidden_size, ln=ln, scope="q_map")
46 | if cache is not None and ('mk' in cache and 'mv' in cache):
47 | k, v = cache['mk'], cache['mv']
48 | else:
49 | k = func.linear(memory, hidden_size, ln=ln, scope="k_map")
50 | v = func.linear(memory, hidden_size, ln=ln, scope="v_map")
51 |
52 | if cache is not None:
53 | cache['mk'] = k
54 | cache['mv'] = v
55 |
56 | q = func.split_heads(q, num_heads)
57 | k = func.split_heads(k, num_heads)
58 | v = func.split_heads(v, num_heads)
59 |
60 | q *= (hidden_size // num_heads) ** (-0.5)
61 |
62 | # q * k => attention weights
63 | logits = tf.matmul(q, k, transpose_b=True)
64 |
65 | # convert the mask to 0-1 form and multiply to logits
66 | if mem_mask is not None:
67 | zero_one_mask = tf.to_float(tf.equal(mem_mask, 0.0))
68 | logits *= zero_one_mask
69 |
70 | # replace softmax with relu
71 | # weights = tf.nn.softmax(logits)
72 | weights = tf.nn.relu(logits)
73 |
74 | dweights = util.valid_apply_dropout(weights, dropout)
75 |
76 | # weights * v => attention vectors
77 | o = tf.matmul(dweights, v)
78 | o = func.combine_heads(o)
79 |
80 | # perform RMSNorm to stabilize running
81 | o = gated_rms_norm(o, scope="post")
82 |
83 | if out_map:
84 | o = func.linear(o, hidden_size, ln=ln, scope="o_map")
85 |
86 | results = {
87 | 'weights': weights,
88 | 'output': o,
89 | 'cache': cache
90 | }
91 |
92 | return results
93 |
94 |
95 | def gated_rms_norm(x, eps=None, scope=None):
96 | """RMS-based Layer normalization layer"""
97 | if eps is None:
98 | eps = dtype.epsilon()
99 | with tf.variable_scope(scope or "rms_norm",
100 | dtype=tf.as_dtype(dtype.floatx())):
101 | layer_size = util.shape_list(x)[-1]
102 |
103 | scale = tf.get_variable("scale", [layer_size], initializer=tf.ones_initializer())
104 | gate = tf.get_variable("gate", [layer_size], initializer=None)
105 |
106 | ms = tf.reduce_mean(x ** 2, -1, keep_dims=True)
107 |
108 | # adding gating here which slightly improves quality
109 | return scale * x * tf.rsqrt(ms + eps) * tf.nn.sigmoid(gate * x)
110 |
--------------------------------------------------------------------------------
/modules/rpr.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import tensorflow as tf
8 |
9 |
10 | def relative_attention_inner(x, y, z=None, transpose=False):
11 | """Relative position-aware dot-product attention inner calculation.
12 | This batches matrix multiply calculations to avoid unnecessary broadcasting.
13 |
14 | Args:
15 | x: Tensor with shape [batch_size, heads, length, length or depth].
16 | y: Tensor with shape [batch_size, heads, length, depth].
17 | z: Tensor with shape [length, length, depth].
18 | transpose: Whether to transpose inner matrices of y and z. Should be true if
19 | last dimension of x is depth, not length.
20 |
21 | Returns:
22 | A Tensor with shape [batch_size, heads, length, length or depth].
23 | """
24 | batch_size = tf.shape(x)[0]
25 | heads = x.get_shape().as_list()[1]
26 | length = tf.shape(x)[2]
27 |
28 | # xy_matmul is [batch_size, heads, length, length or depth]
29 | xy_matmul = tf.matmul(x, y, transpose_b=transpose)
30 | if z is not None:
31 | # x_t is [length, batch_size, heads, length or depth]
32 | x_t = tf.transpose(x, [2, 0, 1, 3])
33 | # x_t_r is [length, batch_size * heads, length or depth]
34 | x_t_r = tf.reshape(x_t, [length, heads * batch_size, -1])
35 | # x_tz_matmul is [length, batch_size * heads, length or depth]
36 | x_tz_matmul = tf.matmul(x_t_r, z, transpose_b=transpose)
37 | # x_tz_matmul_r is [length, batch_size, heads, length or depth]
38 | x_tz_matmul_r = tf.reshape(x_tz_matmul, [length, batch_size, heads, -1])
39 | # x_tz_matmul_r_t is [batch_size, heads, length, length or depth]
40 | x_tz_matmul_r_t = tf.transpose(x_tz_matmul_r, [1, 2, 0, 3])
41 | return xy_matmul + x_tz_matmul_r_t
42 | else:
43 | return xy_matmul
44 |
45 |
46 | def get_relative_positions_embeddings(length_x, length_y,
47 | depth, max_relative_position, name=None, last=None):
48 | """Generates tensor of size [length_x, length_y, depth]."""
49 | with tf.variable_scope(name or "rpr"):
50 | relative_positions_matrix = get_relative_positions_matrix(
51 | length_x, length_y, max_relative_position)
52 | # to handle cached decoding, where target-token incrementally grows
53 | if last is not None:
54 | relative_positions_matrix = relative_positions_matrix[-last:]
55 | vocab_size = max_relative_position * 2 + 1
56 | # Generates embedding for each relative position of dimension depth.
57 | embeddings_table = tf.get_variable("embeddings", [vocab_size, depth])
58 | embeddings = tf.gather(embeddings_table, relative_positions_matrix)
59 | return embeddings
60 |
61 |
62 | def get_relative_positions_matrix(length_x, length_y, max_relative_position):
63 | """Generates matrix of relative positions between inputs."""
64 | range_vec_x = tf.range(length_x)
65 | range_vec_y = tf.range(length_y)
66 |
67 | # shape: [length_x, length_y]
68 | distance_mat = tf.expand_dims(range_vec_x, -1) - tf.expand_dims(range_vec_y, 0)
69 | distance_mat_clipped = tf.clip_by_value(distance_mat, -max_relative_position,
70 | max_relative_position)
71 |
72 | # Shift values to be >= 0. Each integer still uniquely identifies a relative
73 | # position difference.
74 | final_mat = distance_mat_clipped + max_relative_position
75 | return final_mat
76 |
--------------------------------------------------------------------------------
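
To see what get_relative_positions_matrix produces, here is the same logic as a NumPy sketch for length 4 and max_relative_position 2 (a worked example, not repository code): distances are clipped so that positions further apart than the maximum share an embedding, and the shift maps them into valid table indices.

import numpy as np

def relative_positions_matrix(length_x, length_y, max_relative_position):
    distance = np.arange(length_x)[:, None] - np.arange(length_y)[None, :]
    clipped = np.clip(distance, -max_relative_position, max_relative_position)
    return clipped + max_relative_position   # indices into a table of size 2*max+1

print(relative_positions_matrix(4, 4, max_relative_position=2))
# [[2 1 0 0]
#  [3 2 1 0]
#  [4 3 2 1]
#  [4 4 3 2]]
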
/rnns/__init__.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from rnns import gru, lstm, atr, sru, lrn, olrn
4 |
5 |
6 | def get_cell(cell_name, hidden_size, ln=False, scope=None):
7 | """Convert the cell_name into cell instance."""
8 | cell_name = cell_name.lower()
9 |
10 | if cell_name == "gru":
11 | return gru.gru(hidden_size, ln=ln, scope=scope or "gru")
12 | elif cell_name == "lstm":
13 | return lstm.lstm(hidden_size, ln=ln, scope=scope or "lstm")
14 | elif cell_name == "atr":
15 | return atr.atr(hidden_size, ln=ln, scope=scope or "atr")
16 | elif cell_name == "sru":
17 | return sru.sru(hidden_size, ln=ln, scope=scope or "sru")
18 | elif cell_name == "lrn":
19 | return lrn.lrn(hidden_size, ln=ln, scope=scope or "lrn")
20 | elif cell_name == "olrn":
21 | return olrn.olrn(hidden_size, ln=ln, scope=scope or "olrn")
22 | else:
23 | raise NotImplementedError("{} is not supported".format(cell_name))
24 |
--------------------------------------------------------------------------------
/rnns/atr.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import tensorflow as tf
8 |
9 | from func import linear
10 | from rnns import cell as cell
11 |
12 |
13 | class atr(cell.Cell):
14 | """The Addition-Subtraction Twin-Gated Recurrent Unit."""
15 |
16 | def __init__(self, d, ln=False, twin=True, scope='atr'):
17 | super(atr, self).__init__(d, ln=ln, scope=scope)
18 |
19 | self.twin = twin
20 |
21 | def get_init_state(self, shape=None, x=None, scope=None):
22 | return self._get_init_state(
23 | self.d, shape=shape, x=x, scope=scope)
24 |
25 | def fetch_states(self, x):
26 | with tf.variable_scope(
27 | "fetch_state_{}".format(self.scope or "atr")):
28 | h = linear(x, self.d,
29 | bias=False, ln=self.ln, scope="hide_x")
30 | return (h, )
31 |
32 | def __call__(self, h_, x):
33 | # h_: the previous hidden state
34 | # x: the current input state
35 | """
36 | p = W x
37 | q = U h_
38 | i = sigmoid(p + q)
39 | f = sigmoid(p - q)
40 | h = i * p + f * h_
41 | """
42 | if isinstance(x, (list, tuple)):
43 | x = x[0]
44 |
45 | with tf.variable_scope(
46 | "cell_{}".format(self.scope or "atr")):
47 | q = linear(h_, self.d,
48 | ln=self.ln, scope="hide_h")
49 | p = x
50 |
51 | f = tf.sigmoid(p - q)
52 | if self.twin:
53 | i = tf.sigmoid(p + q)
54 | # we empirically find that the following simple form is more stable.
55 | else:
56 | i = 1. - f
57 |
58 | h = i * p + f * h_
59 |
60 | return h
61 |
--------------------------------------------------------------------------------
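
To make the ATR docstring equations concrete, a toy NumPy step of the cell follows; W, U and the sizes here are illustrative assumptions (in the repository the projections come from func.linear):

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def atr_step(h_prev, x, W, U):
    p = x @ W                 # projected input
    q = h_prev @ U            # projected previous state
    i = sigmoid(p + q)        # input gate (addition)
    f = sigmoid(p - q)        # forget gate (subtraction), the twin gate
    return i * p + f * h_prev

h = np.zeros((1, 4))
x = np.random.randn(1, 3)
W, U = np.random.randn(3, 4), np.random.randn(4, 4)
print(atr_step(h, x, W, U).shape)  # (1, 4)
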
/rnns/cell.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import abc
8 | import tensorflow as tf
9 | from func import linear
10 | from utils import dtype
11 |
12 |
13 | # This is an abstract class that deals with
14 | # recurrent cells, e.g. GRU, LSTM, ATR
15 | class Cell(object):
16 | def __init__(self,
17 | d, # hidden state dimension
18 | ln=False, # whether use layer normalization
19 | scope=None, # the name scope for this cell
20 | ):
21 | self.d = d
22 | self.scope = scope
23 | self.ln = ln
24 |
25 | def _get_init_state(self, d, shape=None, x=None, scope=None):
26 | # gen init state vector
27 | # if no evidence x is provided, use zero initialization
28 | if x is None:
29 | assert shape is not None, "you should provide shape"
30 | if not isinstance(shape, (tuple, list)):
31 | shape = [shape]
32 | shape = shape + [d]
33 | return dtype.tf_to_float(tf.zeros(shape))
34 | else:
35 | return linear(
36 | x, d, bias=True, ln=self.ln,
37 | scope="{}_init".format(scope or self.scope)
38 | )
39 |
40 | def get_hidden(self, x):
41 | return x
42 |
43 | @abc.abstractmethod
44 | def get_init_state(self, shape=None, x=None, scope=None):
45 | raise NotImplementedError("Not Supported")
46 |
47 | @abc.abstractmethod
48 | def __call__(self, h_, x):
49 | raise NotImplementedError("Not Supported")
50 |
51 | @abc.abstractmethod
52 | def fetch_states(self, x):
53 | raise NotImplementedError("Not Supported")
54 |
--------------------------------------------------------------------------------
/rnns/gru.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import tensorflow as tf
8 |
9 | from func import linear
10 | from rnns import cell as cell
11 |
12 |
13 | class gru(cell.Cell):
14 | """The Gated Recurrent Unit."""
15 |
16 | def __init__(self, d, ln=False, scope='gru'):
17 | super(gru, self).__init__(d, ln=ln, scope=scope)
18 |
19 | def get_init_state(self, shape=None, x=None, scope=None):
20 | return self._get_init_state(
21 | self.d, shape=shape, x=x, scope=scope)
22 |
23 | def fetch_states(self, x):
24 | with tf.variable_scope(
25 | "fetch_state_{}".format(self.scope or "gru")):
26 | g = linear(x, self.d * 2,
27 | bias=False, ln=self.ln, scope="gate_x")
28 | h = linear(x, self.d,
29 | bias=False, ln=self.ln, scope="hide_x")
30 | return g, h
31 |
32 | def __call__(self, h_, x):
33 | # h_: the previous hidden state
34 | # x_g/x: the current input state for gate
35 | # x_h/x: the current input state for hidden
36 | """
37 | z = sigmoid(h_, x)
38 | r = sigmoid(h_, x)
39 | h' = tanh(x, r * h_)
40 | h = z * h_ + (1. - z) * h'
41 | """
42 | with tf.variable_scope(
43 | "cell_{}".format(self.scope or "gru")):
44 | x_g, x_h = x
45 |
46 | h_g = linear(h_, self.d * 2,
47 | ln=self.ln, scope="gate_h")
48 | z, r = tf.split(
49 | tf.sigmoid(x_g + h_g), 2, -1)
50 |
51 | h_h = linear(h_ * r, self.d,
52 | ln=self.ln, scope="hide_h")
53 | h = tf.tanh(x_h + h_h)
54 |
55 | h = z * h_ + (1. - z) * h
56 |
57 | return h
58 |
--------------------------------------------------------------------------------
/rnns/lrn.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import tensorflow as tf
8 |
9 | from func import linear
10 | from rnns import cell as cell
11 |
12 |
13 | class lrn(cell.Cell):
14 | """The Recurrence-Free Addition-Subtraction Twin-Gated Recurrent Unit.
15 | Or, lightweight recurrent network
16 | """
17 |
18 | def __init__(self, d, ln=False, scope='lrn'):
19 | super(lrn, self).__init__(d, ln=ln, scope=scope)
20 |
21 | def get_init_state(self, shape=None, x=None, scope=None):
22 | return self._get_init_state(
23 | self.d, shape=shape, x=x, scope=scope)
24 |
25 | def fetch_states(self, x):
26 | with tf.variable_scope(
27 | "fetch_state_{}".format(self.scope or "lrn")):
28 | h = linear(x, self.d * 3,
29 | bias=False, ln=self.ln, scope="hide_x")
30 | return (h, )
31 |
32 | def __call__(self, h_, x):
33 | # h_: the previous hidden state
34 | # p,q,r/x: the current input state
35 | """
36 | p, q, r = W x
37 | i = sigmoid(p + h_)
38 | f = sigmoid(q - h_)
39 | h = i * r + f * h_
40 | """
41 | if isinstance(x, (list, tuple)):
42 | x = x[0]
43 |
44 | with tf.variable_scope(
45 |                 "cell_{}".format(self.scope or "lrn")):
46 | p, q, r = tf.split(x, 3, -1)
47 |
48 | i = tf.sigmoid(p + h_)
49 | f = tf.sigmoid(q - h_)
50 |
51 | h = i * r + f * h_
52 |
53 | return h
54 |
--------------------------------------------------------------------------------
/rnns/lstm.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import tensorflow as tf
8 |
9 | from func import linear
10 | from rnns import cell as cell
11 |
12 |
13 | class lstm(cell.Cell):
14 | """The Long-Short Term Memory Unit."""
15 |
16 | def __init__(self, d, ln=False, scope='lstm'):
17 | super(lstm, self).__init__(d, ln=ln, scope=scope)
18 |
19 | def get_init_state(self, shape=None, x=None, scope=None):
20 | return self._get_init_state(
21 | self.d * 2, shape=shape, x=x, scope=scope)
22 |
23 | def get_hidden(self, x):
24 | return tf.split(x, 2, -1)[0]
25 |
26 | def fetch_states(self, x):
27 | with tf.variable_scope(
28 | "fetch_state_{}".format(self.scope or "lstm")):
29 | g = linear(x, self.d * 3,
30 | bias=False, ln=self.ln, scope="gate_x")
31 | c = linear(x, self.d,
32 | bias=False, ln=self.ln, scope="hide_x")
33 | return g, c
34 |
35 | def __call__(self, h_, x):
36 | # h_: the concatenation of previous hidden state
37 | # and memory cell state
38 | # x_i/x: the current input state for input gate
39 | # x_f/x: the current input state for forget gate
40 | # x_o/x: the current input state for output gate
41 | # x_c/x: the current input state for candidate cell
42 | """
43 | f = sigmoid(h_, x)
44 | i = sigmoid(h_, x)
45 | o = sigmoid(h_, x)
46 | c' = tanh(h_, x)
47 | c = f * c_ + i * c'
48 | h = o * tanh(c)
49 | """
50 | with tf.variable_scope(
51 | "cell_{}".format(self.scope or "lstm")):
52 | x_g, x_c = x
53 | h_, c_ = tf.split(h_, 2, -1)
54 |
55 | h_g = linear(h_, self.d * 3,
56 | ln=self.ln, scope="gate_h")
57 | i, f, o = tf.split(
58 | tf.sigmoid(x_g + h_g), 3, -1)
59 |
60 | h_c = linear(h_, self.d,
61 | ln=self.ln, scope="hide_h")
62 | h_c = tf.tanh(x_c + h_c)
63 |
64 | c = i * h_c + f * c_
65 |
66 | h = o * tf.tanh(c)
67 |
68 | return tf.concat([h, c], -1)
69 |
--------------------------------------------------------------------------------
/rnns/olrn.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import tensorflow as tf
8 |
9 | from func import linear
10 | from rnns import cell as cell
11 |
12 |
13 | class olrn(cell.Cell):
14 | """The Recurrence-Free Addition-Subtraction Twin-Gated Recurrent Unit.
15 | Or, output-gated lightweight recurrent network
16 | """
17 |
18 | def __init__(self, d, ln=False, scope='olrn'):
19 | super(olrn, self).__init__(d, ln=ln, scope=scope)
20 |
21 | def get_init_state(self, shape=None, x=None, scope=None):
22 | return self._get_init_state(
23 | self.d, shape=shape, x=x, scope=scope)
24 |
25 | def fetch_states(self, x):
26 | with tf.variable_scope(
27 | "fetch_state_{}".format(self.scope or "olrn")):
28 | h = linear(x, self.d * 4,
29 | bias=False, ln=self.ln, scope="hide_x")
30 | return (h, )
31 |
32 | def __call__(self, h_, x):
33 | # h_: the previous hidden state
34 | # p,q,r,s/x: the current input state
35 | """
36 | p, q, r, s = W x
37 | i = sigmoid(p + h_)
38 | f = sigmoid(q - h_)
39 | h = i * r + f * h_
40 |         o = sigmoid(s - h)
41 | h = o * h
42 | """
43 | if isinstance(x, (list, tuple)):
44 | x = x[0]
45 |
46 | with tf.variable_scope(
47 |                 "cell_{}".format(self.scope or "olrn")):
48 | p, q, r, s = tf.split(x, 4, -1)
49 |
50 | i = tf.sigmoid(p + h_)
51 | f = tf.sigmoid(q - h_)
52 |
53 | h = i * r + f * h_
54 |
55 | o = tf.nn.sigmoid(s - h)
56 | h = o * h
57 |
58 | return h
59 |
--------------------------------------------------------------------------------
/rnns/rnn.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import tensorflow as tf
8 |
9 | from utils import util, dtype
10 | from rnns import get_cell
11 | from func import linear, additive_attention
12 |
13 |
14 | def rnn(cell_name, x, d, mask=None, ln=False, init_state=None, sm=True):
15 | """Self implemented RNN procedure, supporting mask trick"""
16 | # cell_name: gru, lstm or atr
17 | # x: input sequence embedding matrix, [batch, seq_len, dim]
18 | # d: hidden dimension for rnn
19 | # mask: mask matrix, [batch, seq_len]
20 | # ln: whether use layer normalization
21 | # init_state: the initial hidden states, for cache purpose
22 | # sm: whether apply swap memory during rnn scan
23 | # dp: variational dropout
24 |
25 | in_shape = util.shape_list(x)
26 | batch_size, time_steps = in_shape[:2]
27 |
28 | cell = get_cell(cell_name, d, ln=ln)
29 |
30 | if init_state is None:
31 | init_state = cell.get_init_state(shape=[batch_size])
32 | if mask is None:
33 | mask = dtype.tf_to_float(tf.ones([batch_size, time_steps]))
34 |
35 | # prepare projected input
36 | cache_inputs = cell.fetch_states(x)
37 | cache_inputs = [tf.transpose(v, [1, 0, 2])
38 | for v in list(cache_inputs)]
39 | mask_ta = tf.transpose(tf.expand_dims(mask, -1), [1, 0, 2])
40 |
41 | def _step_fn(prev, x):
42 | t, h_ = prev
43 | m = x[-1]
44 | v = x[:-1]
45 |
46 | h = cell(h_, v)
47 | h = m * h + (1. - m) * h_
48 |
49 | return t + 1, h
50 |
51 | time = tf.constant(0, dtype=tf.int32, name="time")
52 | step_states = (time, init_state)
53 | step_vars = cache_inputs + [mask_ta]
54 |
55 | outputs = tf.scan(_step_fn,
56 | step_vars,
57 | initializer=step_states,
58 | parallel_iterations=32,
59 | swap_memory=sm)
60 |
61 | output_ta = outputs[1]
62 | output_state = outputs[1][-1]
63 |
64 | outputs = tf.transpose(output_ta, [1, 0, 2])
65 |
66 | return (outputs, output_state), \
67 | (cell.get_hidden(outputs), cell.get_hidden(output_state))
68 |
69 |
70 | def cond_rnn(cell_name, x, memory, d, init_state=None,
71 | mask=None, mem_mask=None, ln=False, sm=True,
72 | one2one=False, num_heads=1):
73 | """Self implemented conditional-RNN procedure, supporting mask trick"""
74 | # cell_name: gru, lstm or atr
75 | # x: input sequence embedding matrix, [batch, seq_len, dim]
76 | # memory: the conditional part
77 | # d: hidden dimension for rnn
78 | # mask: mask matrix, [batch, seq_len]
79 | # mem_mask: memory mask matrix, [batch, mem_seq_len]
80 | # ln: whether use layer normalization
81 | # init_state: the initial hidden states, for cache purpose
82 | # sm: whether apply swap memory during rnn scan
83 | # one2one: whether the memory is one-to-one mapping for x
84 | # num_heads: number of attention heads, multi-head attention
85 | # dp: variational dropout
86 |
87 | in_shape = util.shape_list(x)
88 | batch_size, time_steps = in_shape[:2]
89 | mem_shape = util.shape_list(memory)
90 |
91 | cell_lower = get_cell(cell_name, d, ln=ln,
92 | scope="{}_lower".format(cell_name))
93 | cell_higher = get_cell(cell_name, d, ln=ln,
94 | scope="{}_higher".format(cell_name))
95 |
96 | if init_state is None:
97 | init_state = cell_lower.get_init_state(shape=[batch_size])
98 | if mask is None:
99 | mask = dtype.tf_to_float(tf.ones([batch_size, time_steps]))
100 | if mem_mask is None:
101 | mem_mask = dtype.tf_to_float(tf.ones([batch_size, mem_shape[1]]))
102 |
103 | # prepare projected encodes and inputs
104 | cache_inputs = cell_lower.fetch_states(x)
105 | cache_inputs = [tf.transpose(v, [1, 0, 2])
106 | for v in list(cache_inputs)]
107 | if not one2one:
108 | proj_memories = linear(memory, mem_shape[-1], bias=False,
109 | ln=ln, scope="context_att")
110 | else:
111 | cache_memories = cell_higher.fetch_states(memory)
112 | cache_memories = [tf.transpose(v, [1, 0, 2])
113 | for v in list(cache_memories)]
114 | mask_ta = tf.transpose(tf.expand_dims(mask, -1), [1, 0, 2])
115 | init_context = dtype.tf_to_float(tf.zeros([batch_size, mem_shape[-1]]))
116 | init_weight = dtype.tf_to_float(tf.zeros([batch_size, num_heads, mem_shape[1]]))
117 | mask_pos = len(cache_inputs)
118 |
119 | def _step_fn(prev, x):
120 | t, h_, c_, a_ = prev
121 |
122 | if not one2one:
123 | m, v = x[mask_pos], x[:mask_pos]
124 | else:
125 | c, c_c, m, v = x[-1], x[mask_pos+1:-1], x[mask_pos], x[:mask_pos]
126 |
127 | s = cell_lower(h_, v)
128 | s = m * s + (1. - m) * h_
129 |
130 | if not one2one:
131 | vle = additive_attention(
132 | cell_lower.get_hidden(s), memory, mem_mask,
133 | mem_shape[-1], ln=ln, num_heads=num_heads,
134 | proj_memory=proj_memories, scope="attention")
135 | a, c = vle['weights'], vle['output']
136 | c_c = cell_higher.fetch_states(c)
137 | else:
138 | a = tf.tile(tf.expand_dims(tf.range(time_steps), 0), [batch_size, 1])
139 | a = dtype.tf_to_float(tf.equal(a, t))
140 | a = tf.tile(tf.expand_dims(a, 1), [1, num_heads, 1])
141 | a = tf.reshape(a, tf.shape(init_weight))
142 |
143 | h = cell_higher(s, c_c)
144 | h = m * h + (1. - m) * s
145 |
146 | return t + 1, h, c, a
147 |
148 | time = tf.constant(0, dtype=tf.int32, name="time")
149 | step_states = (time, init_state, init_context, init_weight)
150 | step_vars = cache_inputs + [mask_ta]
151 | if one2one:
152 | step_vars += cache_memories + [memory]
153 |
154 | outputs = tf.scan(_step_fn,
155 | step_vars,
156 | initializer=step_states,
157 | parallel_iterations=32,
158 | swap_memory=sm)
159 |
160 | output_ta = outputs[1]
161 | context_ta = outputs[2]
162 | attention_ta = outputs[3]
163 |
164 | outputs = tf.transpose(output_ta, [1, 0, 2])
165 | output_states = outputs[:, -1]
166 | contexts = tf.transpose(context_ta, [1, 0, 2])
167 | attentions = tf.transpose(attention_ta, [1, 2, 0, 3])
168 |
169 | return (outputs, output_states), \
170 | (cell_higher.get_hidden(outputs), cell_higher.get_hidden(output_states)), \
171 | contexts, attentions
172 |
--------------------------------------------------------------------------------
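
A hedged usage sketch for rnn() above, assuming a TF1 environment with the repository root on the Python path (the placeholder shapes follow the function's own comments):

import tensorflow as tf
from rnns.rnn import rnn

x = tf.placeholder(tf.float32, [None, None, 512])   # [batch, seq_len, dim]
mask = tf.placeholder(tf.float32, [None, None])     # 1 for real tokens, 0 for padding

(states, last_state), (hidden, last_hidden) = rnn("gru", x, 256, mask=mask, ln=True)
# states, hidden: [batch, seq_len, 256]; last_state, last_hidden: [batch, 256]
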
/rnns/sru.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import tensorflow as tf
8 |
9 | from func import linear
10 | from rnns import cell as cell
11 |
12 |
13 | class sru(cell.Cell):
14 | """The Simple Recurrent Unit."""
15 |
16 | def __init__(self, d, ln=False, scope='sru'):
17 | super(sru, self).__init__(d, ln=ln, scope=scope)
18 |
19 | def get_init_state(self, shape=None, x=None, scope=None):
20 | return self._get_init_state(
21 | self.d * 2, shape=shape, x=x, scope=scope)
22 |
23 | def get_hidden(self, x):
24 | return tf.split(x, 2, -1)[0]
25 |
26 | def fetch_states(self, x):
27 | with tf.variable_scope(
28 | "fetch_state_{}".format(self.scope or "sru")):
29 | h = linear(x, self.d * 4,
30 | bias=False, ln=self.ln, scope="hide_x")
31 | return (h, )
32 |
33 | def __call__(self, h_, x):
34 | # h_: the concatenation of previous hidden state
35 | # and memory cell state
36 | # x_r/x: the current input state for r gate
37 | # x_f/x: the current input state for f gate
38 | # x_c/x: the current input state for candidate cell
39 | # x_h/x: the current input state for hidden output
40 | # we increase this because we do not assume that
41 | # the input dimension equals the output dimension
42 | """
43 | f = sigmoid(Wx, vf * c_)
44 | c = f * c_ + (1 - f) * Wx
45 | r = sigmoid(Wx, vr * c_)
46 | h = r * c + (1 - r) * Ux
47 | """
48 | if isinstance(x, (list, tuple)):
49 | x = x[0]
50 |
51 | with tf.variable_scope(
52 | "cell_{}".format(self.scope or "sru")):
53 | x_r, x_f, x_c, x_h = tf.split(x, 4, -1)
54 | h_, c_ = tf.split(h_, 2, -1)
55 |
56 | v_f = tf.get_variable("v_f", [1, self.d])
57 | v_r = tf.get_variable("v_r", [1, self.d])
58 |
59 | f = tf.sigmoid(x_f + v_f * c_)
60 | c = f * c_ + (1. - f) * x_c
61 | r = tf.sigmoid(x_r + v_r * c_)
62 | h = r * c + (1. - r) * x_h
63 |
64 | return tf.concat([h, c], -1)
65 |
--------------------------------------------------------------------------------
/scripts/bleu_over_length.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import math
8 | import argparse
9 |
10 | from collections import Counter
11 |
12 |
13 | def closest_length(candidate, references):
14 | clen = len(candidate)
15 | closest_diff = 9999
16 | closest_len = 9999
17 |
18 | for reference in references:
19 | rlen = len(reference)
20 | diff = abs(rlen - clen)
21 |
22 | if diff < closest_diff:
23 | closest_diff = diff
24 | closest_len = rlen
25 | elif diff == closest_diff:
26 | closest_len = rlen if rlen < closest_len else closest_len
27 |
28 | return closest_len
29 |
30 |
31 | def shortest_length(references):
32 | return min([len(ref) for ref in references])
33 |
34 |
35 | def modified_precision(candidate, references, n):
36 | tngrams = len(candidate) + 1 - n
37 | counts = Counter([tuple(candidate[i:i+n]) for i in range(tngrams)])
38 |
39 | if len(counts) == 0:
40 | return 0, 0
41 |
42 | max_counts = {}
43 | for reference in references:
44 | rngrams = len(reference) + 1 - n
45 | ngrams = [tuple(reference[i:i+n]) for i in range(rngrams)]
46 | ref_counts = Counter(ngrams)
47 | for ngram in counts:
48 | mcount = 0 if ngram not in max_counts else max_counts[ngram]
49 | rcount = 0 if ngram not in ref_counts else ref_counts[ngram]
50 | max_counts[ngram] = max(mcount, rcount)
51 |
52 | clipped_counts = {}
53 |
54 | for ngram, count in counts.items():
55 | clipped_counts[ngram] = min(count, max_counts[ngram])
56 |
57 | return float(sum(clipped_counts.values())), float(sum(counts.values()))
58 |
59 |
60 | def brevity_penalty(trans, refs, mode="closest"):
61 | bp_c = 0.0
62 | bp_r = 0.0
63 |
64 | for candidate, references in zip(trans, refs):
65 | bp_c += len(candidate)
66 |
67 | if mode == "shortest":
68 | bp_r += shortest_length(references)
69 | else:
70 | bp_r += closest_length(candidate, references)
71 |
72 | # Prevent zero divide
73 | bp_c = bp_c or 1.0
74 |
75 | return math.exp(min(0, 1.0 - bp_r / bp_c))
76 |
77 |
78 | def bleu(trans, refs, bp="closest", smooth=False, n=4, weights=None):
79 | p_norm = [0 for _ in range(n)]
80 | p_denorm = [0 for _ in range(n)]
81 |
82 | for candidate, references in zip(trans, refs):
83 | for i in range(n):
84 | ccount, tcount = modified_precision(candidate, references, i + 1)
85 | p_norm[i] += ccount
86 | p_denorm[i] += tcount
87 |
88 | bleu_n = [0 for _ in range(n)]
89 |
90 | for i in range(n):
91 | # add one smoothing
92 | if smooth and i > 0:
93 | p_norm[i] += 1
94 | p_denorm[i] += 1
95 |
96 | if p_norm[i] == 0 or p_denorm[i] == 0:
97 | bleu_n[i] = -9999
98 | else:
99 | bleu_n[i] = math.log(float(p_norm[i]) / float(p_denorm[i]))
100 |
101 | if weights:
102 | if len(weights) != n:
103 | raise ValueError("len(weights) != n: invalid weight number")
104 | log_precision = sum([bleu_n[i] * weights[i] for i in range(n)])
105 | else:
106 | log_precision = sum(bleu_n) / float(n)
107 |
108 | bp = brevity_penalty(trans, refs, bp)
109 |
110 | score = bp * math.exp(log_precision)
111 |
112 | return score
113 |
114 |
115 | def read(f, lc=False):
116 | with open(f, 'rU') as reader:
117 | return [line.strip().split() if not lc else line.strip().lower().split()
118 | for line in reader.readlines()]
119 |
120 |
121 | if __name__ == "__main__":
122 | parser = argparse.ArgumentParser(
123 | description='BLEU score over source sentence length')
124 | parser.add_argument('-lc', help='Lowercase, i.e case-insensitive setting', action='store_true')
125 | parser.add_argument('-bp', help='Length penalty', default='closest', choices=['shortest', 'closest'])
126 | parser.add_argument('-n', type=int, default=4, help="ngram-based BLEU")
127 | parser.add_argument('-g', type=int, default=1, help="sentence groups for evaluation")
128 | parser.add_argument('-source', type=str, required=True, help='The source file')
129 | parser.add_argument('-candidate', type=str, required=True, help='The candidate translation generated by MT system')
130 | parser.add_argument('-reference', type=str, nargs='+', required=True,
131 | help='The references like reference or reference0, reference1, ...')
132 |
133 | args = parser.parse_args()
134 |
135 | cand = args.candidate
136 | refs = args.reference
137 | src = args.source
138 |
139 | src_sentences = read(src, args.lc)
140 | cand_sentences = read(cand, args.lc)
141 | refs_sentences = [read(ref, args.lc) for ref in refs]
142 |
143 | assert len(cand_sentences) == len(refs_sentences[0]), \
144 | 'ERROR: the length of candidate and reference must be the same.'
145 |
146 | refs_sentences = list(zip(*refs_sentences))
147 |
148 | sorted_candidate_sentences = sorted(zip(src_sentences, cand_sentences), key=lambda x: len(x[0]))
149 | sorted_reference_sentences = sorted(zip(src_sentences, refs_sentences), key=lambda x: len(x[0]))
150 |
151 | sorted_source_sentences = [v[0] for v in sorted_candidate_sentences]
152 | sorted_candidate_sentences = [v[1] for v in sorted_candidate_sentences]
153 | sorted_reference_sentences = [v[1] for v in sorted_reference_sentences]
154 |
155 | groups = args.g
156 | elements_per_group = len(sorted_source_sentences) // groups
157 |
158 | scores = []
159 | for gidx in range(groups):
160 | group_candidate = sorted_candidate_sentences[gidx * elements_per_group: (gidx + 1) * elements_per_group]
161 | group_reference = sorted_reference_sentences[gidx * elements_per_group: (gidx + 1) * elements_per_group]
162 | group_source = sorted_source_sentences[gidx * elements_per_group: (gidx + 1) * elements_per_group]
163 |
164 | group_average_source = float(sum([len(v) for v in group_source])) / float(len(group_source))
165 | bleu_score = bleu(group_candidate, group_reference, bp=args.bp, n=args.n)
166 |
167 |         print("Group Idx {} Avg Source Length {} BLEU Score {}".format(gidx, group_average_source, bleu_score))
168 |
169 | scores.append((group_average_source, bleu_score))
170 |
171 | print('AvgLength: [{}]'.format(','.join([str(s[0]) for s in scores])))
172 | print('BLEU Score: [{}]'.format(','.join([str(s[1]) for s in scores])))
173 |
174 |
--------------------------------------------------------------------------------
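
The script is normally driven from the command line with -source/-candidate/-reference, but the bleu() helper can also be called directly; a toy call with invented sentences (the import path assumes scripts/ is on the Python path):

from bleu_over_length import bleu

trans = [["the", "cat", "sat", "on", "the", "mat"]]
refs = [[["the", "cat", "sat", "on", "the", "mat"],
         ["a", "cat", "was", "sitting", "on", "the", "mat"]]]

print(bleu(trans, refs, n=4))  # 1.0: exact match with the first reference, so BP = 1
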
/scripts/checkpoint_averaging.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import argparse
8 | import operator
9 | import os
10 |
11 | import numpy as np
12 | import tensorflow as tf
13 |
14 |
15 | def parseargs():
16 | msg = "Average checkpoints"
17 |     usage = "average.py [<args>] [-h | --help]"
18 | parser = argparse.ArgumentParser(description=msg, usage=usage)
19 |
20 | parser.add_argument("--path", type=str, required=True,
21 | help="checkpoint dir")
22 | parser.add_argument("--checkpoints", type=int, required=True,
23 | help="number of checkpoints to use")
24 | parser.add_argument("--output", type=str, help="output path")
25 | parser.add_argument("--gpu", type=int, default=0,
26 | help="the default gpu device index")
27 |
28 | return parser.parse_args()
29 |
30 |
31 | def get_checkpoints(path):
32 | if not tf.gfile.Exists(os.path.join(path, "checkpoint")):
33 | raise ValueError("Cannot find checkpoints in %s" % path)
34 |
35 | checkpoint_names = []
36 |
37 | with tf.gfile.GFile(os.path.join(path, "checkpoint")) as fd:
38 | # Skip the first line
39 | fd.readline()
40 | for line in fd:
41 | name = line.strip().split(":")[-1].strip()[1:-1]
42 | key = int(name.split("-")[-1])
43 | checkpoint_names.append((key, os.path.join(path, name)))
44 |
45 | sorted_names = sorted(checkpoint_names, key=operator.itemgetter(0),
46 | reverse=True)
47 |
48 | return [item[-1] for item in sorted_names]
49 |
50 |
51 | def checkpoint_exists(path):
52 | return (tf.gfile.Exists(path) or tf.gfile.Exists(path + ".meta") or
53 | tf.gfile.Exists(path + ".index"))
54 |
55 |
56 | def main(_):
57 | tf.logging.set_verbosity(tf.logging.INFO)
58 | checkpoints = get_checkpoints(FLAGS.path)
59 | checkpoints = checkpoints[:FLAGS.checkpoints]
60 |
61 | if not checkpoints:
62 | raise ValueError("No checkpoints provided for averaging.")
63 |
64 | checkpoints = [c for c in checkpoints if checkpoint_exists(c)]
65 |
66 | if not checkpoints:
67 | raise ValueError(
68 | "None of the provided checkpoints exist. %s" % FLAGS.checkpoints
69 | )
70 |
71 | var_list = tf.contrib.framework.list_variables(checkpoints[0])
72 | var_values, var_dtypes = {}, {}
73 |
74 | for (name, shape) in var_list:
75 | if not name.startswith("global_step"):
76 | var_values[name] = np.zeros(shape)
77 |
78 | for checkpoint in checkpoints:
79 | reader = tf.contrib.framework.load_checkpoint(checkpoint)
80 | for name in var_values:
81 | tensor = reader.get_tensor(name)
82 | var_dtypes[name] = tensor.dtype
83 | var_values[name] += tensor
84 | tf.logging.info("Read from checkpoint %s", checkpoint)
85 |
86 | # Average checkpoints
87 | for name in var_values:
88 | var_values[name] /= len(checkpoints)
89 |
90 | tf_vars = [
91 | tf.get_variable(name, shape=var_values[name].shape,
92 | dtype=var_dtypes[name]) for name in var_values
93 | ]
94 | placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
95 | assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
96 | global_step = tf.Variable(0, name="global_step", trainable=False,
97 | dtype=tf.int64)
98 | saver = tf.train.Saver(tf.global_variables())
99 |
100 | sess_config = tf.ConfigProto(allow_soft_placement=True)
101 | sess_config.gpu_options.allow_growth = True
102 | sess_config.gpu_options.visible_device_list = "%s" % FLAGS.gpu
103 |
104 | with tf.Session(config=sess_config) as sess:
105 | sess.run(tf.global_variables_initializer())
106 | for p, assign_op, (name, value) in zip(placeholders, assign_ops,
107 |                                                var_values.items()):
108 | sess.run(assign_op, {p: value})
109 | saved_name = os.path.join(FLAGS.output, "average")
110 | saver.save(sess, saved_name, global_step=global_step)
111 |
112 | tf.logging.info("Averaged checkpoints saved in %s", saved_name)
113 |
114 | params_pattern = os.path.join(FLAGS.path, "*.json")
115 | params_files = tf.gfile.Glob(params_pattern)
116 |
117 | for name in params_files:
118 | new_name = name.replace(FLAGS.path.rstrip("/"),
119 | FLAGS.output.rstrip("/"))
120 | tf.gfile.Copy(name, new_name, overwrite=True)
121 |
122 |
123 | if __name__ == "__main__":
124 | FLAGS = parseargs()
125 | tf.app.run()
126 |
--------------------------------------------------------------------------------
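
The core of the script above, in miniature: each variable is summed over the selected checkpoints and divided by their count. A NumPy sketch (the arrays stand in for restored tensors):

import numpy as np

ckpt_values = [np.array([1.0, 2.0]), np.array([3.0, 4.0]), np.array([5.0, 6.0])]
average = sum(ckpt_values) / len(ckpt_values)
print(average)  # [3. 4.]
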
/scripts/chrF.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # Author: Rico Sennrich
4 |
5 | """Compute chrF3 for machine translation evaluation
6 |
7 | Reference:
8 | Maja Popović (2015). chrF: character n-gram F-score for automatic MT evaluation. In Proceedings of the Tenth Workshop on Statistical Machine Translation, pages 392–395, Lisbon, Portugal.
9 | """
10 |
11 | from __future__ import print_function, unicode_literals, division
12 |
13 | import sys
14 | import codecs
15 | import io
16 | import argparse
17 |
18 | from collections import defaultdict
19 |
20 | # hack for python2/3 compatibility
21 | from io import open
22 | argparse.open = open
23 |
24 | def create_parser():
25 | parser = argparse.ArgumentParser(
26 | formatter_class=argparse.RawDescriptionHelpFormatter,
27 |         description="compute chrF3 for machine translation evaluation")
28 |
29 | parser.add_argument(
30 | '--ref', '-r', type=argparse.FileType('r'), required=True,
31 | metavar='PATH',
32 | help="Reference file")
33 | parser.add_argument(
34 | '--hyp', type=argparse.FileType('r'), metavar='PATH',
35 | default=sys.stdin,
36 | help="Hypothesis file (default: stdin).")
37 | parser.add_argument(
38 | '--beta', '-b', type=float, default=3,
39 | metavar='FLOAT',
40 | help="beta parameter (default: '%(default)s')")
41 | parser.add_argument(
42 | '--ngram', '-n', type=int, default=6,
43 | metavar='INT',
44 | help="ngram order (default: '%(default)s')")
45 | parser.add_argument(
46 | '--space', '-s', action='store_true',
47 | help="take spaces into account (default: '%(default)s')")
48 | parser.add_argument(
49 | '--precision', action='store_true',
50 | help="report precision (default: '%(default)s')")
51 | parser.add_argument(
52 | '--recall', action='store_true',
53 | help="report recall (default: '%(default)s')")
54 |
55 | return parser
56 |
57 | def extract_ngrams(words, max_length=4, spaces=False):
58 |
59 | if not spaces:
60 | words = ''.join(words.split())
61 | else:
62 | words = words.strip()
63 |
64 | results = defaultdict(lambda: defaultdict(int))
65 | for length in range(max_length):
66 | for start_pos in range(len(words)):
67 | end_pos = start_pos + length + 1
68 | if end_pos <= len(words):
69 | results[length][tuple(words[start_pos: end_pos])] += 1
70 | return results
71 |
72 |
73 | def get_correct(ngrams_ref, ngrams_test, correct, total):
74 |
75 | for rank in ngrams_test:
76 | for chain in ngrams_test[rank]:
77 | total[rank] += ngrams_test[rank][chain]
78 | if chain in ngrams_ref[rank]:
79 | correct[rank] += min(ngrams_test[rank][chain], ngrams_ref[rank][chain])
80 |
81 | return correct, total
82 |
83 |
84 | def f1(correct, total_hyp, total_ref, max_length, beta=3, smooth=0):
85 |
86 | precision = 0
87 | recall = 0
88 |
89 | for i in range(max_length):
90 | if total_hyp[i] + smooth and total_ref[i] + smooth:
91 | precision += (correct[i] + smooth) / (total_hyp[i] + smooth)
92 | recall += (correct[i] + smooth) / (total_ref[i] + smooth)
93 |
94 | precision /= max_length
95 | recall /= max_length
96 |
97 | return (1 + beta**2) * (precision*recall) / ((beta**2 * precision) + recall), precision, recall
98 |
99 | def main(args):
100 |
101 | correct = [0]*args.ngram
102 | total = [0]*args.ngram
103 | total_ref = [0]*args.ngram
104 | for line in args.ref:
105 | line2 = args.hyp.readline()
106 |
107 | ngrams_ref = extract_ngrams(line, max_length=args.ngram, spaces=args.space)
108 | ngrams_test = extract_ngrams(line2, max_length=args.ngram, spaces=args.space)
109 |
110 | get_correct(ngrams_ref, ngrams_test, correct, total)
111 |
112 | for rank in ngrams_ref:
113 | for chain in ngrams_ref[rank]:
114 | total_ref[rank] += ngrams_ref[rank][chain]
115 |
116 | chrf, precision, recall = f1(correct, total, total_ref, args.ngram, args.beta)
117 |
118 | print('chrF3: {0:.4f}'.format(chrf))
119 | if args.precision:
120 | print('chrPrec: {0:.4f}'.format(precision))
121 | if args.recall:
122 | print('chrRec: {0:.4f}'.format(recall))
123 |
124 | if __name__ == '__main__':
125 |
126 | # python 2/3 compatibility
127 | if sys.version_info < (3, 0):
128 | sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
129 | sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
130 | sys.stdin = codecs.getreader('UTF-8')(sys.stdin)
131 | else:
132 | sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8')
133 | sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
134 | sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', write_through=True, line_buffering=True)
135 |
136 | parser = create_parser()
137 | args = parser.parse_args()
138 |
139 | main(args)
140 |
--------------------------------------------------------------------------------
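
For one n-gram order, f1() above combines precision and recall as chrF_beta = (1 + beta^2) * P * R / (beta^2 * P + R); with beta = 3 the score is weighted toward recall. A small numeric check with made-up values:

beta = 3.0
P, R = 0.62, 0.55                       # hypothetical averaged precision / recall
chrf = (1 + beta ** 2) * P * R / (beta ** 2 * P + R)
print(round(chrf, 4))                   # 0.5563
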
/scripts/evaluate_pos_translation_rate.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import argparse
8 |
9 | import sys
10 | from collections import Counter
11 |
12 |
13 | def parseargs():
14 |     msg = "Evaluate P/R/F score for particular POS-tagged tokens"
15 |     usage = "{} [<args>] [-h | --help]".format(sys.argv[0])
16 | parser = argparse.ArgumentParser(description=msg, usage=usage)
17 |
18 | parser.add_argument("--trans", type=str, required=True,
19 | help="model translation")
20 | parser.add_argument("--refs", type=str, required=True, nargs="+",
21 | help="gold reference, one or more")
22 | parser.add_argument("--ngram", type=int, default=4,
23 | help="the maximum n for n-gram")
24 |
25 |     group = parser.add_argument_group("POS setting")
26 |     group.add_argument("--noun", type=str, default="NN",
27 |                        help="the pos label for noun")
28 |     group.add_argument("--verb", type=str, default="VB",
29 |                        help="the pos label for verb")
30 |     group.add_argument("--adj", type=str, default="JJ",
31 |                        help="the pos label for adjective")
32 |     group.add_argument("--adv", type=str, default="RB",
33 |                        help="the pos label for adverb")
34 |
35 | parser.add_argument("--spliter", type=str, default="_",
36 |                         help="the separator between word and pos label")
37 |
38 | return parser.parse_args()
39 |
40 |
41 | # POS conversion module
42 | def prepare_ngram(txt, pos, ngram):
43 | tokens = txt.strip().split()
44 |
45 | words = []
46 | for token in tokens:
47 | if type(pos) is not list and pos in token:
48 | segs = token.strip().split('_')
49 | word = '_'.join(segs[:-1])
50 | words.append(word)
51 | elif type(pos) is list:
52 | cvt = False
53 | for p in pos:
54 | if p in token:
55 | cvt = True
56 | break
57 | if cvt:
58 | segs = token.strip().split('_')
59 | word = '_'.join(segs[:-1])
60 | words.append(word)
61 | else:
62 | words.append('')
63 |
64 | _ngram_list = []
65 | for ngidx in range(ngram, len(words)):
66 | _ngram_list.append(' '.join(words[ngidx - ngram:ngidx]))
67 |     ngram_list = [ng for ng in _ngram_list if '' not in ng.split(' ')]  # drop n-grams containing a skipped token
68 |
69 | return Counter(ngram_list)
70 |
71 |
72 | def convert_corpus(dataset, pos, ngram):
73 | return [prepare_ngram(data, pos, ngram) for data in dataset]
74 |
75 |
76 | def score(trans, refs):
77 |
78 | def _precision_recall_fvalue(_trans, _ref):
79 | t_cngrams = 0.
80 | t_rngrams = 0.
81 | m_ngrams = 0.
82 |
83 | for cngrams, rngrams in zip(_trans, _ref):
84 |
85 | t_cngrams += sum(cngrams.values())
86 | t_rngrams += sum(rngrams.values())
87 |
88 | for ngram in cngrams:
89 | if ngram in rngrams:
90 | m_ngrams += min(cngrams[ngram], rngrams[ngram])
91 |
92 | precision = m_ngrams / t_cngrams if t_cngrams > 0 else 0.
93 | recall = m_ngrams / t_rngrams if t_rngrams > 0 else 0.
94 | fvalue = 2 * (recall * precision) / (recall + precision + 1e-8)
95 |
96 | return precision, recall, fvalue
97 |
98 | eval_scores = [_precision_recall_fvalue(trans, ref) for ref in refs]
99 | eval_scores = list(zip(*eval_scores))
100 | return [sum(v) / len(v) for v in eval_scores]
101 |
102 |
103 | def evaluate_the_rate_of_specific_gram(ref, trs, pos, ngram):
104 | # ref: reference corpus
105 | # trs: translation corpus
106 | # pos: part-of-speech tag
107 | # ngram: n-gram number
108 |
109 | references = [convert_corpus(r, pos, ngram) for r in ref]
110 | candidate = convert_corpus(trs, pos, ngram)
111 |
112 | result = score(candidate, references)
113 |
114 | return pos, ngram, result
115 |
116 |
117 | if __name__ == "__main__":
118 | params = parseargs()
119 |
120 | # loading the reference corpus
121 | corpus = []
122 | for trans_txt in params.refs:
123 | with open(trans_txt, 'rU') as reader:
124 | corpus.append(reader.readlines())
125 | if len(corpus) > 1:
126 | for cidx in range(1, len(corpus)):
127 | assert len(corpus[cidx]) == len(corpus[cidx - 1]), 'the length of each reference text must be the same'
128 |
129 | # the focused translation corpus
130 | with open(params.trans, 'rU') as reader:
131 | test = reader.readlines()
132 | assert len(test) == len(corpus[0]), \
133 | 'the length of translation text should be the same as that of reference text'
134 |
135 | poses = [params.noun,
136 | params.verb,
137 | params.adj,
138 | params.adv,
139 | [params.noun, params.verb],
140 | [params.noun, params.verb, params.adj]]
141 | ngrams = range(params.ngram)
142 | for pos in poses:
143 | for ngram in ngrams:
144 | pos, ngram, evals = evaluate_the_rate_of_specific_gram(corpus, test, pos, ngram)
145 | print('Pos: %s, Ngram: %s, Score %s' % (pos, ngram + 1, str(evals)))
146 |
--------------------------------------------------------------------------------
/scripts/multi-bleu-detok.perl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env perl
2 | #
3 | # This file is part of moses. Its use is licensed under the GNU Lesser General
4 | # Public License version 2.1 or, at your option, any later version.
5 |
6 | # This file uses the internal tokenization of mteval-v13a.pl,
7 | # giving the exact same (case-sensitive) results on untokenized text.
8 | # Using this script with detokenized output and untokenized references is
9 | # preferable over multi-bleu.perl, since scores aren't affected by tokenization differences.
10 | #
11 | # like multi-bleu.perl , it supports plain text input and multiple references.
12 |
13 | # $Id$
14 | use warnings;
15 | use strict;
16 |
17 | binmode(STDIN, ":utf8");
18 | use open ':encoding(UTF-8)';
19 |
20 | my $lowercase = 0;
21 | if ($ARGV[0] eq "-lc") {
22 | $lowercase = 1;
23 | shift;
24 | }
25 |
26 | my $stem = $ARGV[0];
27 | if (!defined $stem) {
28 | print STDERR "usage: multi-bleu-detok.pl [-lc] reference < hypothesis\n";
29 | print STDERR "Reads the references from reference or reference0, reference1, ...\n";
30 | exit(1);
31 | }
32 |
33 | $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0";
34 |
35 | my @REF;
36 | my $ref=0;
37 | while(-e "$stem$ref") {
38 | &add_to_ref("$stem$ref",\@REF);
39 | $ref++;
40 | }
41 | &add_to_ref($stem,\@REF) if -e $stem;
42 | die("ERROR: could not find reference file $stem") unless scalar @REF;
43 |
44 | # add additional references explicitly specified on the command line
45 | shift;
46 | foreach my $stem (@ARGV) {
47 | &add_to_ref($stem,\@REF) if -e $stem;
48 | }
49 |
50 |
51 |
52 | sub add_to_ref {
53 | my ($file,$REF) = @_;
54 | my $s=0;
55 | if ($file =~ /.gz$/) {
56 | open(REF,"gzip -dc $file|") or die "Can't read $file";
57 | } else {
58 | open(REF,$file) or die "Can't read $file";
59 | }
60 | while(<REF>) {
61 | chop;
62 | $_ = tokenization($_);
63 | push @{$$REF[$s++]}, $_;
64 | }
65 | close(REF);
66 | }
67 |
68 | my(@CORRECT,@TOTAL,$length_translation,$length_reference);
69 | my $s=0;
70 | while(<STDIN>) {
71 | chop;
72 | $_ = lc if $lowercase;
73 | $_ = tokenization($_);
74 | my @WORD = split;
75 | my %REF_NGRAM = ();
76 | my $length_translation_this_sentence = scalar(@WORD);
77 | my ($closest_diff,$closest_length) = (9999,9999);
78 | foreach my $reference (@{$REF[$s]}) {
79 | # print "$s $_ <=> $reference\n";
80 | $reference = lc($reference) if $lowercase;
81 | my @WORD = split(' ',$reference);
82 | my $length = scalar(@WORD);
83 | my $diff = abs($length_translation_this_sentence-$length);
84 | if ($diff < $closest_diff) {
85 | $closest_diff = $diff;
86 | $closest_length = $length;
87 | # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n";
88 | } elsif ($diff == $closest_diff) {
89 | $closest_length = $length if $length < $closest_length;
90 | # from two references with the same closeness to me
91 | # take the *shorter* into account, not the "first" one.
92 | }
93 | for(my $n=1;$n<=4;$n++) {
94 | my %REF_NGRAM_N = ();
95 | for(my $start=0;$start<=$#WORD-($n-1);$start++) {
96 | my $ngram = "$n";
97 | for(my $w=0;$w<$n;$w++) {
98 | $ngram .= " ".$WORD[$start+$w];
99 | }
100 | $REF_NGRAM_N{$ngram}++;
101 | }
102 | foreach my $ngram (keys %REF_NGRAM_N) {
103 | if (!defined($REF_NGRAM{$ngram}) ||
104 | $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) {
105 | $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram};
106 |           # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}<BR>\n";
107 | }
108 | }
109 | }
110 | }
111 | $length_translation += $length_translation_this_sentence;
112 | $length_reference += $closest_length;
113 | for(my $n=1;$n<=4;$n++) {
114 | my %T_NGRAM = ();
115 | for(my $start=0;$start<=$#WORD-($n-1);$start++) {
116 | my $ngram = "$n";
117 | for(my $w=0;$w<$n;$w++) {
118 | $ngram .= " ".$WORD[$start+$w];
119 | }
120 | $T_NGRAM{$ngram}++;
121 | }
122 | foreach my $ngram (keys %T_NGRAM) {
123 | $ngram =~ /^(\d+) /;
124 | my $n = $1;
125 | # my $corr = 0;
126 |       # print "$i e $ngram $T_NGRAM{$ngram}<BR>\n";
127 | $TOTAL[$n] += $T_NGRAM{$ngram};
128 | if (defined($REF_NGRAM{$ngram})) {
129 | if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) {
130 | $CORRECT[$n] += $T_NGRAM{$ngram};
131 | # $corr = $T_NGRAM{$ngram};
132 |          # print "$i e correct1 $T_NGRAM{$ngram}<BR>\n";
133 | }
134 | else {
135 | $CORRECT[$n] += $REF_NGRAM{$ngram};
136 | # $corr = $REF_NGRAM{$ngram};
137 |          # print "$i e correct2 $REF_NGRAM{$ngram}<BR>\n";
138 | }
139 | }
140 | # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram};
141 | # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n"
142 | }
143 | }
144 | $s++;
145 | }
146 | my $brevity_penalty = 1;
147 | my $bleu = 0;
148 |
149 | my @bleu=();
150 |
151 | for(my $n=1;$n<=4;$n++) {
152 | if (defined ($TOTAL[$n])){
153 | $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0;
154 | # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n";
155 | }else{
156 | $bleu[$n]=0;
157 | }
158 | }
159 |
160 | if ($length_reference==0){
161 | printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n";
162 | exit(1);
163 | }
164 |
165 | if ($length_translation<$length_reference) {
166 | $brevity_penalty = exp(1-$length_reference/$length_translation);
167 | }
168 | $bleu = $brevity_penalty * exp((my_log( $bleu[1] ) +
169 | my_log( $bleu[2] ) +
170 | my_log( $bleu[3] ) +
171 | my_log( $bleu[4] ) ) / 4) ;
172 | printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n",
173 | 100*$bleu,
174 | 100*$bleu[1],
175 | 100*$bleu[2],
176 | 100*$bleu[3],
177 | 100*$bleu[4],
178 | $brevity_penalty,
179 | $length_translation / $length_reference,
180 | $length_translation,
181 | $length_reference;
182 |
183 | sub my_log {
184 | return -9999999999 unless $_[0];
185 | return log($_[0]);
186 | }
187 |
188 |
189 |
190 | sub tokenization
191 | {
192 | my ($norm_text) = @_;
193 |
194 | # language-independent part:
195 | $norm_text =~ s/<skipped>//g; # strip "skipped" tags
196 | $norm_text =~ s/-\n//g; # strip end-of-line hyphenation and join lines
197 | $norm_text =~ s/\n/ /g; # join lines
198 | $norm_text =~ s/&quot;/"/g; # convert SGML tag for quote to "
199 | $norm_text =~ s/&amp;/&/g; # convert SGML tag for ampersand to &
200 | $norm_text =~ s/&lt;/</g; # convert SGML tag for less-than to <
201 | $norm_text =~ s/&gt;/>/g; # convert SGML tag for greater-than to >
202 |
203 | # language-dependent part (assuming Western languages):
204 | $norm_text = " $norm_text ";
205 | $norm_text =~ s/([\{-\~\[-\` -\&\(-\+\:-\@\/])/ $1 /g; # tokenize punctuation
206 | $norm_text =~ s/([^0-9])([\.,])/$1 $2 /g; # tokenize period and comma unless preceded by a digit
207 | $norm_text =~ s/([\.,])([^0-9])/ $1 $2/g; # tokenize period and comma unless followed by a digit
208 | $norm_text =~ s/([0-9])(-)/$1 $2 /g; # tokenize dash when preceded by a digit
209 | $norm_text =~ s/\s+/ /g; # one space only between words
210 | $norm_text =~ s/^\s+//; # no leading space
211 | $norm_text =~ s/\s+$//; # no trailing space
212 |
213 | return $norm_text;
214 | }
215 |
--------------------------------------------------------------------------------
/scripts/multi-bleu.perl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env perl
2 | #
3 | # This file is part of moses. Its use is licensed under the GNU Lesser General
4 | # Public License version 2.1 or, at your option, any later version.
5 |
6 | # $Id$
7 | use warnings;
8 | use strict;
9 |
10 | my $lowercase = 0;
11 | if ($ARGV[0] eq "-lc") {
12 | $lowercase = 1;
13 | shift;
14 | }
15 |
16 | my $stem = $ARGV[0];
17 | if (!defined $stem) {
18 | print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n";
19 | print STDERR "Reads the references from reference or reference0, reference1, ...\n";
20 | exit(1);
21 | }
22 |
23 | $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0";
24 |
25 | my @REF;
26 | my $ref=0;
27 | while(-e "$stem$ref") {
28 | &add_to_ref("$stem$ref",\@REF);
29 | $ref++;
30 | }
31 | &add_to_ref($stem,\@REF) if -e $stem;
32 | die("ERROR: could not find reference file $stem") unless scalar @REF;
33 |
34 | # add additional references explicitly specified on the command line
35 | shift;
36 | foreach my $stem (@ARGV) {
37 | &add_to_ref($stem,\@REF) if -e $stem;
38 | }
39 |
40 |
41 |
42 | sub add_to_ref {
43 | my ($file,$REF) = @_;
44 | my $s=0;
45 | if ($file =~ /.gz$/) {
46 | open(REF,"gzip -dc $file|") or die "Can't read $file";
47 | } else {
48 | open(REF,$file) or die "Can't read $file";
49 | }
50 | while(<REF>) {
51 | chomp;
52 | push @{$$REF[$s++]}, $_;
53 | }
54 | close(REF);
55 | }
56 |
57 | my(@CORRECT,@TOTAL,$length_translation,$length_reference);
58 | my $s=0;
59 | while(<STDIN>) {
60 | chomp;
61 | $_ = lc if $lowercase;
62 | my @WORD = split;
63 | my %REF_NGRAM = ();
64 | my $length_translation_this_sentence = scalar(@WORD);
65 | my ($closest_diff,$closest_length) = (9999,9999);
66 | foreach my $reference (@{$REF[$s]}) {
67 | # print "$s $_ <=> $reference\n";
68 | $reference = lc($reference) if $lowercase;
69 | my @WORD = split(' ',$reference);
70 | my $length = scalar(@WORD);
71 | my $diff = abs($length_translation_this_sentence-$length);
72 | if ($diff < $closest_diff) {
73 | $closest_diff = $diff;
74 | $closest_length = $length;
75 | # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n";
76 | } elsif ($diff == $closest_diff) {
77 | $closest_length = $length if $length < $closest_length;
78 | # from two references with the same closeness to me
79 | # take the *shorter* into account, not the "first" one.
80 | }
81 | for(my $n=1;$n<=4;$n++) {
82 | my %REF_NGRAM_N = ();
83 | for(my $start=0;$start<=$#WORD-($n-1);$start++) {
84 | my $ngram = "$n";
85 | for(my $w=0;$w<$n;$w++) {
86 | $ngram .= " ".$WORD[$start+$w];
87 | }
88 | $REF_NGRAM_N{$ngram}++;
89 | }
90 | foreach my $ngram (keys %REF_NGRAM_N) {
91 | if (!defined($REF_NGRAM{$ngram}) ||
92 | $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) {
93 | $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram};
94 |     # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}<BR>\n";
95 | }
96 | }
97 | }
98 | }
99 | $length_translation += $length_translation_this_sentence;
100 | $length_reference += $closest_length;
101 | for(my $n=1;$n<=4;$n++) {
102 | my %T_NGRAM = ();
103 | for(my $start=0;$start<=$#WORD-($n-1);$start++) {
104 | my $ngram = "$n";
105 | for(my $w=0;$w<$n;$w++) {
106 | $ngram .= " ".$WORD[$start+$w];
107 | }
108 | $T_NGRAM{$ngram}++;
109 | }
110 | foreach my $ngram (keys %T_NGRAM) {
111 | $ngram =~ /^(\d+) /;
112 | my $n = $1;
113 | # my $corr = 0;
114 |       # print "$i e $ngram $T_NGRAM{$ngram}<BR>\n";
115 | $TOTAL[$n] += $T_NGRAM{$ngram};
116 | if (defined($REF_NGRAM{$ngram})) {
117 | if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) {
118 | $CORRECT[$n] += $T_NGRAM{$ngram};
119 | # $corr = $T_NGRAM{$ngram};
120 |          # print "$i e correct1 $T_NGRAM{$ngram}<BR>\n";
121 | }
122 | else {
123 | $CORRECT[$n] += $REF_NGRAM{$ngram};
124 | # $corr = $REF_NGRAM{$ngram};
125 |          # print "$i e correct2 $REF_NGRAM{$ngram}<BR>\n";
126 | }
127 | }
128 | # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram};
129 | # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n"
130 | }
131 | }
132 | $s++;
133 | }
134 | my $brevity_penalty = 1;
135 | my $bleu = 0;
136 |
137 | my @bleu=();
138 |
139 | for(my $n=1;$n<=4;$n++) {
140 | if (defined ($TOTAL[$n])){
141 | $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0;
142 | # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n";
143 | }else{
144 | $bleu[$n]=0;
145 | }
146 | }
147 |
148 | if ($length_reference==0){
149 | printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n";
150 | exit(1);
151 | }
152 |
153 | if ($length_translation<$length_reference) {
154 | $brevity_penalty = exp(1-$length_reference/$length_translation);
155 | }
156 | $bleu = $brevity_penalty * exp((my_log( $bleu[1] ) +
157 | my_log( $bleu[2] ) +
158 | my_log( $bleu[3] ) +
159 | my_log( $bleu[4] ) ) / 4) ;
160 | printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n",
161 | 100*$bleu,
162 | 100*$bleu[1],
163 | 100*$bleu[2],
164 | 100*$bleu[3],
165 | 100*$bleu[4],
166 | $brevity_penalty,
167 | $length_translation / $length_reference,
168 | $length_translation,
169 | $length_reference;
170 |
171 |
172 | print STDERR "It is not advisable to publish scores from multi-bleu.perl. The scores depend on your tokenizer, which is unlikely to be reproducible from your paper or consistent across research groups. Instead you should detokenize then use mteval-v14.pl, which has a standard tokenization. Scores from multi-bleu.perl can still be used for internal purposes when you have a consistent tokenizer.\n";
173 |
174 | sub my_log {
175 | return -9999999999 unless $_[0];
176 | return log($_[0]);
177 | }
178 |
--------------------------------------------------------------------------------
/scripts/shuffle_corpus.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import argparse
8 | import numpy
9 |
10 | """Copyright 2018 The THUMT Authors"""
11 |
12 |
13 | def parseargs():
14 | parser = argparse.ArgumentParser(description="Shuffle corpus")
15 |
16 | parser.add_argument("--corpus", nargs="+", required=True,
17 | help="input corpora")
18 | parser.add_argument("--suffix", type=str, default="shuf",
19 | help="Suffix of output files")
20 | parser.add_argument("--seed", type=int, help="Random seed")
21 |
22 | return parser.parse_args()
23 |
24 |
25 | def main(args):
26 | name = args.corpus
27 | suffix = "." + args.suffix
28 | stream = [open(item, "r") for item in name]
29 | data = [fd.readlines() for fd in stream]
30 | minlen = min([len(lines) for lines in data])
31 |
32 | if args.seed:
33 | numpy.random.seed(args.seed)
34 |
35 | indices = numpy.arange(minlen)
36 | numpy.random.shuffle(indices)
37 |
38 | newstream = [open(item + suffix, "w") for item in name]
39 |
40 | for idx in indices.tolist():
41 | lines = [item[idx] for item in data]
42 |
43 | for line, fd in zip(lines, newstream):
44 | fd.write(line)
45 |
46 | for fdr, fdw in zip(stream, newstream):
47 | fdr.close()
48 | fdw.close()
49 |
50 |
51 | if __name__ == "__main__":
52 | parsed_args = parseargs()
53 | main(parsed_args)
54 |
--------------------------------------------------------------------------------
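
The idea behind shuffle_corpus.py in a few lines: a single random permutation is applied to every parallel file, so line i of the shuffled source still corresponds to line i of the shuffled target (the toy lists below are illustrative):

import numpy as np

src = ["s1\n", "s2\n", "s3\n"]
tgt = ["t1\n", "t2\n", "t3\n"]
perm = np.random.permutation(len(src))
src_shuf = [src[i] for i in perm]
tgt_shuf = [tgt[i] for i in perm]
# the sentence pairs remain aligned after shuffling
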
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
--------------------------------------------------------------------------------
/utils/cycle.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import tensorflow as tf
8 | from utils import dtype
9 |
10 |
11 | def _zero_variables(variables, name=None):
12 | ops = []
13 |
14 | for var in variables:
15 | with tf.device(var.device):
16 | op = var.assign(tf.zeros_like(var))
17 | ops.append(op)
18 |
19 | return tf.group(*ops, name=name or "zero_variables")
20 |
21 |
22 | def _replicate_variables(variables, device=None, suffix="Replica"):
23 | new_vars = []
24 |
25 | for var in variables:
26 | device = device or var.device
27 | with tf.device(device):
28 | name = var.op.name + "/{}".format(suffix)
29 | new_vars.append(tf.Variable(tf.zeros_like(var),
30 | name=name, trainable=False))
31 |
32 | return new_vars
33 |
34 |
35 | def _collect_gradients(gradients, variables):
36 | ops = []
37 |
38 | for grad, var in zip(gradients, variables):
39 | if isinstance(grad, tf.Tensor):
40 | ops.append(tf.assign_add(var, grad))
41 | else:
42 | ops.append(tf.scatter_add(var, grad.indices, grad.values))
43 |
44 | return tf.group(*ops, name="collect_gradients")
45 |
46 |
47 | def create_train_op(named_scalars, grads_and_vars, optimizer, global_step, params):
48 | tf.get_variable_scope().set_dtype(tf.as_dtype(dtype.floatx()))
49 |
50 | gradients = [item[0] for item in grads_and_vars]
51 | variables = [item[1] for item in grads_and_vars]
52 |
53 | if params.update_cycle == 1:
54 | zero_variables_op = tf.no_op("zero_variables")
55 | collect_op = tf.no_op("collect_op")
56 | else:
57 | named_vars = {}
58 | for name in named_scalars:
59 | named_var = tf.Variable(tf.zeros([], dtype=tf.float32),
60 | name="{}/CTrainOpReplica".format(name),
61 | trainable=False)
62 | named_vars[name] = named_var
63 | count_var = tf.Variable(tf.zeros([], dtype=tf.as_dtype(dtype.floatx())),
64 | name="count/CTrainOpReplica",
65 | trainable=False)
66 | slot_variables = _replicate_variables(variables, suffix='CTrainOpReplica')
67 | zero_variables_op = _zero_variables(
68 | slot_variables + [count_var] + list(named_vars.values()))
69 |
70 | collect_ops = []
71 | # collect gradients
72 | collect_grads_op = _collect_gradients(gradients, slot_variables)
73 | collect_ops.append(collect_grads_op)
74 |
75 | # collect other scalars
76 | for name in named_scalars:
77 | scalar = named_scalars[name]
78 | named_var = named_vars[name]
79 | collect_op = tf.assign_add(named_var, scalar)
80 | collect_ops.append(collect_op)
81 | # collect counting variable
82 | collect_count_op = tf.assign_add(count_var, 1.0)
83 | collect_ops.append(collect_count_op)
84 |
85 | collect_op = tf.group(*collect_ops, name="collect_op")
86 | scale = 1.0 / (tf.cast(count_var, tf.float32) + 1.0)
87 | gradients = [scale * (g + s)
88 | for (g, s) in zip(gradients, slot_variables)]
89 |
90 | for name in named_scalars:
91 | named_scalars[name] = scale * (
92 | named_scalars[name] + named_vars[name])
93 |
94 | grand_norm = tf.global_norm(gradients)
95 | param_norm = tf.global_norm(variables)
96 |
97 | # Gradient clipping
98 | if isinstance(params.clip_grad_norm or None, float):
99 | gradients, _ = tf.clip_by_global_norm(gradients,
100 | params.clip_grad_norm,
101 | use_norm=grand_norm)
102 |
103 | # Update variables
104 | grads_and_vars = list(zip(gradients, variables))
105 | train_op = optimizer.apply_gradients(grads_and_vars, global_step)
106 |
107 | ops = {
108 | "zero_op": zero_variables_op,
109 | "collect_op": collect_op,
110 | "train_op": train_op
111 | }
112 |
113 | # apply ema
114 | if params.ema_decay > 0.:
115 | tf.logging.info('Using Exp Moving Average to train the model with decay {}.'.format(params.ema_decay))
116 | ema = tf.train.ExponentialMovingAverage(decay=params.ema_decay, num_updates=global_step)
117 | ema_op = ema.apply(variables)
118 | with tf.control_dependencies([ops['train_op']]):
119 | ops['train_op'] = tf.group(ema_op)
120 | bck_vars = _replicate_variables(variables, suffix="CTrainOpBackUpReplica")
121 |
122 | ops['ema_backup_op'] = tf.group(*(tf.assign(bck, var.read_value())
123 | for bck, var in zip(bck_vars, variables)))
124 | ops['ema_restore_op'] = tf.group(*(tf.assign(var, bck.read_value())
125 | for bck, var in zip(bck_vars, variables)))
126 | ops['ema_assign_op'] = tf.group(*(tf.assign(var, ema.average(var).read_value())
127 | for var in variables))
128 |
129 | ret = named_scalars
130 | ret.update({
131 | "gradient_norm": grand_norm,
132 | "parameter_norm": param_norm,
133 | })
134 |
135 | return ret, ops
136 |
--------------------------------------------------------------------------------
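The three ops returned by create_train_op implement gradient accumulation over params.update_cycle micro-batches. A hedged sketch of one possible call pattern (sess and next_feed are placeholders for the real session and feed builder):

    sess.run(ops["zero_op"])                                  # reset the accumulators
    for _ in range(params.update_cycle - 1):
        sess.run(ops["collect_op"], feed_dict=next_feed())    # accumulate gradients only
    sess.run(ops["train_op"], feed_dict=next_feed())          # average over the cycle and apply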
/utils/dtype.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import numpy as np
8 | import tensorflow as tf
9 |
10 | # Copied from Keras
11 |
12 | # the type of float to use throughout the session.
13 | _FLOATX = 'float32'
14 | _EPSILON = 1e-8
15 | _INF = 1e8
16 |
17 |
18 | def epsilon():
19 | return _EPSILON
20 |
21 |
22 | def set_epsilon(e):
23 | global _EPSILON
24 | _EPSILON = e
25 |
26 |
27 | def inf():
28 | return _INF
29 |
30 |
31 | def set_inf(e):
32 | global _INF
33 | _INF = e
34 |
35 |
36 | def floatx():
37 | return _FLOATX
38 |
39 |
40 | def set_floatx(floatx):
41 | global _FLOATX
42 | if floatx not in {'float16', 'float32', 'float64'}:
43 | raise ValueError('Unknown floatx type: ' + str(floatx))
44 | _FLOATX = str(floatx)
45 |
46 |
47 | def np_to_float(x):
48 | return np.asarray(x, dtype=_FLOATX)
49 |
50 |
51 | def tf_to_float(x):
52 | return tf.cast(x, tf.as_dtype(floatx()))
53 |
54 |
55 | def float32_variable_storage_getter(getter, name, shape=None, dtype=None,
56 | initializer=None, regularizer=None,
57 | trainable=True,
58 | *args, **kwargs):
59 | """Custom variable getter that forces trainable variables to be stored in
60 | float32 precision and then casts them to the training precision.
61 | """
62 | storage_dtype = tf.float32 if trainable else dtype
63 | variable = getter(name, shape, dtype=storage_dtype,
64 | initializer=initializer, regularizer=regularizer,
65 | trainable=trainable,
66 | *args, **kwargs)
67 | if trainable and dtype != tf.float32:
68 | variable = tf.cast(variable, dtype)
69 | return variable
70 |
--------------------------------------------------------------------------------
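For half-precision training, one way to use float32_variable_storage_getter is as a custom variable-scope getter, so parameters are stored in float32 but returned to the graph in the lower precision. A minimal sketch (scope name and shape are arbitrary):

    import tensorflow as tf
    from utils import dtype

    dtype.set_floatx("float16")
    with tf.variable_scope("model",
                           custom_getter=dtype.float32_variable_storage_getter,
                           dtype=tf.as_dtype(dtype.floatx())):
        # stored as float32, cast to float16 for the forward/backward pass
        w = tf.get_variable("w", shape=[512, 512])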
/utils/parallel.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import six
8 | import tensorflow as tf
9 | import tensorflow.contrib as tc
10 |
11 | from tensorflow.python.training import device_setter
12 | from tensorflow.python.framework import device as pydev
13 | from tensorflow.core.framework import node_def_pb2
14 |
15 | from utils import util, dtype
16 |
17 |
18 | def local_device_setter(num_devices=1,
19 | ps_device_type='cpu',
20 | worker_device='/cpu:0',
21 | ps_ops=None,
22 | ps_strategy=None):
23 | if ps_ops is None:
24 | ps_ops = ['Variable', 'VariableV2', 'VarHandleOp']
25 |
26 | if ps_strategy is None:
27 | ps_strategy = device_setter._RoundRobinStrategy(num_devices)
28 | if not six.callable(ps_strategy):
29 | raise TypeError("ps_strategy must be callable")
30 |
31 | def _local_device_chooser(op):
32 | current_device = pydev.DeviceSpec.from_string(op.device or "")
33 |
34 | node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
35 | if node_def.op in ps_ops:
36 | ps_device_spec = pydev.DeviceSpec.from_string(
37 | '/{}:{}'.format(ps_device_type, ps_strategy(op)))
38 |
39 | ps_device_spec.merge_from(current_device)
40 | return ps_device_spec.to_string()
41 | else:
42 | worker_device_spec = pydev.DeviceSpec.from_string(worker_device or "")
43 | worker_device_spec.merge_from(current_device)
44 | return worker_device_spec.to_string()
45 |
46 | return _local_device_chooser
47 |
48 |
49 | def _maybe_repeat(x, n):
50 | if isinstance(x, list):
51 | assert len(x) == n
52 | return x
53 | else:
54 | return [x] * n
55 |
56 |
57 | def _reshape_output(outputs):
58 |     # assumption: the per-tower outputs are either all lists/tuples
59 |     # of tensors, or all dictionaries
60 | if isinstance(outputs[0], (tuple, list)):
61 | outputs = list(zip(*outputs))
62 | outputs = tuple([list(o) for o in outputs])
63 | else:
64 | if not isinstance(outputs[0], dict):
65 | return outputs
66 |
67 | assert isinstance(outputs[0], dict), \
68 | 'invalid data type %s' % type(outputs[0])
69 |
70 | combine_outputs = {}
71 | for key in outputs[0]:
72 | combine_outputs[key] = [o[key] for o in outputs]
73 | outputs = combine_outputs
74 |
75 | return outputs
76 |
77 |
78 | # Data-level parallelism
79 | def data_parallelism(device_type, num_devices, fn, *args, **kwargs):
80 | # Replicate args and kwargs
81 | if args:
82 | new_args = [_maybe_repeat(arg, num_devices) for arg in args]
83 | # Transpose
84 | new_args = [list(x) for x in zip(*new_args)]
85 | else:
86 | new_args = [[] for _ in range(num_devices)]
87 |
88 | new_kwargs = [{} for _ in range(num_devices)]
89 |
90 | for k, v in kwargs.items():
91 | vals = _maybe_repeat(v, num_devices)
92 |
93 | for i in range(num_devices):
94 | new_kwargs[i][k] = vals[i]
95 |
96 | fns = _maybe_repeat(fn, num_devices)
97 |
98 | # Now make the parallel call.
99 | outputs = []
100 | for i in range(num_devices):
101 | worker = "/{}:{}".format(device_type, i)
102 | if device_type == 'cpu':
103 | _device_setter = local_device_setter(worker_device=worker)
104 | else:
105 | _device_setter = local_device_setter(
106 | ps_device_type='gpu',
107 | worker_device=worker,
108 | ps_strategy=tc.training.GreedyLoadBalancingStrategy(
109 | num_devices, tc.training.byte_size_load_fn)
110 | )
111 |
112 | with tf.variable_scope(tf.get_variable_scope(), reuse=bool(i != 0),
113 | dtype=tf.as_dtype(dtype.floatx())):
114 | with tf.name_scope("tower_%d" % i):
115 | with tf.device(_device_setter):
116 | outputs.append(fns[i](*new_args[i], **new_kwargs[i]))
117 |
118 | return _reshape_output(outputs)
119 |
120 |
121 | def parallel_model(model_fn, features, devices, use_cpu=False):
122 | device_type = 'gpu'
123 | num_devices = len(devices)
124 |
125 | if use_cpu:
126 | device_type = 'cpu'
127 | num_devices = 1
128 |
129 | outputs = data_parallelism(device_type, num_devices, model_fn, features)
130 |
131 | return outputs
132 |
133 |
134 | def average_gradients(tower_grads, mask=None):
135 | """Modified from Bilm"""
136 |
137 | # optimizer for single device
138 | if len(tower_grads) == 1:
139 | return tower_grads[0]
140 |
141 | # calculate average gradient for each shared variable across all GPUs
142 | def _deduplicate_indexed_slices(values, indices):
143 | """Sums `values` associated with any non-unique `indices`."""
144 | unique_indices, new_index_positions = tf.unique(indices)
145 | summed_values = tf.unsorted_segment_sum(
146 | values, new_index_positions,
147 | tf.shape(unique_indices)[0])
148 | return summed_values, unique_indices
149 |
150 | average_grads = []
151 | for grad_and_vars in zip(*tower_grads):
152 | # Note that each grad_and_vars looks like the following:
153 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
154 | # We need to average the gradients across each GPU.
155 |
156 | g0, v0 = grad_and_vars[0]
157 |
158 | if g0 is None:
159 | # no gradient for this variable, skip it
160 | tf.logging.warn("{} has no gradient".format(v0.name))
161 | average_grads.append((g0, v0))
162 | continue
163 |
164 | if isinstance(g0, tf.IndexedSlices):
165 | # If the gradient is type IndexedSlices then this is a sparse
166 | # gradient with attributes indices and values.
167 | # To average, need to concat them individually then create
168 | # a new IndexedSlices object.
169 | indices = []
170 | values = []
171 | for g, v in grad_and_vars:
172 | indices.append(g.indices)
173 | values.append(g.values)
174 | all_indices = tf.concat(indices, 0)
175 | if mask is None:
176 | avg_values = tf.concat(values, 0) / len(grad_and_vars)
177 | else:
178 | avg_values = tf.concat(values, 0) / tf.reduce_sum(mask)
179 | # deduplicate across indices
180 | av, ai = _deduplicate_indexed_slices(avg_values, all_indices)
181 | grad = tf.IndexedSlices(av, ai, dense_shape=g0.dense_shape)
182 | else:
183 | # a normal tensor can just do a simple average
184 | grads = []
185 | for g, v in grad_and_vars:
186 | # Add 0 dimension to the gradients to represent the tower.
187 | expanded_g = tf.expand_dims(g, 0)
188 | # Append on a 'tower' dimension which we will average over
189 | grads.append(expanded_g)
190 |
191 | # Average over the 'tower' dimension.
192 | grad = tf.concat(grads, 0)
193 | if mask is not None:
194 | grad = tf.boolean_mask(
195 | grad, tf.cast(mask, tf.bool), axis=0)
196 | grad = tf.reduce_mean(grad, 0)
197 |
198 | # the Variables are redundant because they are shared
199 | # across towers. So.. just return the first tower's pointer to
200 | # the Variable.
201 | v = grad_and_vars[0][1]
202 | grad_and_var = (grad, v)
203 |
204 | average_grads.append(grad_and_var)
205 |
206 | assert len(average_grads) == len(list(zip(*tower_grads)))
207 |
208 | return average_grads
209 |
--------------------------------------------------------------------------------
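A hedged sketch of how these helpers can fit together (model_fn, features, and the "gradient" key are illustrative; the actual output keys depend on the model function):

    # build one replica of the graph per GPU; variables are shared across towers
    outputs = parallel_model(model_fn, features, devices=[0, 1])
    # _reshape_output turns the per-tower dicts into a dict of per-tower lists
    tower_grads = outputs["gradient"]            # one (grad, var) list per tower
    avg_grads = average_gradients(tower_grads)   # averaged (grad, var) pairs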
/utils/queuer.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | """
4 | The queue utilities here handle reading and preparing the dataset in a multi-processing manner.
5 | We did not use the built-in TensorFlow Dataset API because it lacks flexibility.
6 | The functions defined below are mainly inspired by https://github.com/ixlan/machine-learning-data-pipeline.
7 | """
8 |
9 | from __future__ import absolute_import
10 | from __future__ import division
11 | from __future__ import print_function
12 |
13 | from multiprocessing import Process, Queue
14 |
15 | TERMINATION_TOKEN = ""
16 |
17 |
18 | def create_iter_from_queue(queue, term_token):
19 |
20 | while True:
21 | input_data_chunk = queue.get()
22 | if input_data_chunk == term_token:
23 |             # put it back on the queue so that other processes feeding
24 |             # from the same queue know that they should also break
25 | queue.put(term_token)
26 | break
27 | else:
28 | yield input_data_chunk
29 |
30 |
31 | def combine_reader_to_processor(reader, preprocessor):
32 | for data_chunk in reader:
33 | yield preprocessor(data_chunk)
34 |
35 |
36 | class EnQueuer(object):
37 | def __init__(self,
38 | reader,
39 | preprocessor,
40 | worker_processes_num=1,
41 | input_queue_size=5,
42 | output_queue_size=5
43 | ):
44 | if worker_processes_num < 0:
45 | raise ValueError("worker_processes_num must be a "
46 | "non-negative integer.")
47 |
48 | self.worker_processes_number = worker_processes_num
49 | self.preprocessor = preprocessor
50 | self.input_queue_size = input_queue_size
51 | self.output_queue_size = output_queue_size
52 | self.reader = reader
53 |
54 | # make the queue iterable
55 | def __iter__(self):
56 | return self._create_processed_data_chunks_gen(self.reader)
57 |
58 | def _create_processed_data_chunks_gen(self, reader_gen):
59 | if self.worker_processes_number == 0:
60 | itr = self._create_single_process_gen(reader_gen)
61 | else:
62 | itr = self._create_multi_process_gen(reader_gen)
63 | return itr
64 |
65 | def _create_single_process_gen(self, data_producer):
66 | return combine_reader_to_processor(data_producer, self.preprocessor)
67 |
68 | def _create_multi_process_gen(self, reader_gen):
69 | term_tokens_received = 0
70 | output_queue = Queue(self.output_queue_size)
71 | workers = []
72 |
73 | if self.worker_processes_number > 1:
74 | term_tokens_expected = self.worker_processes_number - 1
75 | input_queue = Queue(self.input_queue_size)
76 | reader_worker = _ParallelWorker(reader_gen, input_queue)
77 | workers.append(reader_worker)
78 |
79 | # adding workers that will process the data
80 | for _ in range(self.worker_processes_number - 1):
81 | # since data-chunks will appear in the queue, making an iterable
82 | # object over it
83 | queue_iter = create_iter_from_queue(input_queue,
84 | TERMINATION_TOKEN)
85 |
86 | data_itr = combine_reader_to_processor(queue_iter, self.preprocessor)
87 | proc_worker = _ParallelWorker(data_chunk_iter=data_itr,
88 | queue=output_queue)
89 | workers.append(proc_worker)
90 | else:
91 | term_tokens_expected = 1
92 |
93 | data_itr = combine_reader_to_processor(reader_gen, self.preprocessor)
94 | proc_worker = _ParallelWorker(data_chunk_iter=data_itr,
95 | queue=output_queue)
96 | workers.append(proc_worker)
97 |
98 | for pr in workers:
99 | pr.daemon = True
100 | pr.start()
101 |
102 | while True:
103 | data_chunk = output_queue.get()
104 | if data_chunk == TERMINATION_TOKEN:
105 | term_tokens_received += 1
106 |                 # need to receive all tokens in order to be sure that
107 | # all data has been processed
108 | if term_tokens_received == term_tokens_expected:
109 | for pr in workers:
110 | pr.join()
111 | break
112 | continue
113 | yield data_chunk
114 |
115 |
116 | class _ParallelWorker(Process):
117 | """Worker to execute data reading or processing on a separate process."""
118 |
119 | def __init__(self, data_chunk_iter, queue):
120 | super(_ParallelWorker, self).__init__()
121 | self._data_chunk_iterable = data_chunk_iter
122 | self._queue = queue
123 |
124 | def run(self):
125 | for data_chunk in self._data_chunk_iterable:
126 | self._queue.put(data_chunk)
127 | self._queue.put(TERMINATION_TOKEN)
128 |
--------------------------------------------------------------------------------
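A minimal usage sketch of the queue pipeline (toy_reader and toy_preprocessor are stand-ins for the real corpus reader and batching function):

    def toy_reader():
        # yields raw data chunks, e.g. small lists of sentences
        for chunk in [["a b c", "d e"], ["f g h"]]:
            yield chunk

    def toy_preprocessor(chunk):
        # turns a raw chunk into a model-ready batch
        return [sentence.split() for sentence in chunk]

    enqueuer = EnQueuer(toy_reader(), toy_preprocessor, worker_processes_num=2)
    for batch in enqueuer:
        print(batch)    # consume batches as they come off the output queue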
/utils/recorder.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import json
8 | import tensorflow as tf
9 |
10 |
11 | class Recorder(object):
12 | """To save training processes, inspired by Nematus"""
13 |
14 | def load_from_json(self, file_name):
15 |         tf.logging.info("Loading recorder file from {}".format(file_name))
16 | record = json.load(open(file_name, 'rb'))
17 | record = dict((key.encode("UTF-8"), value) for (key, value) in record.items())
18 | self.__dict__.update(record)
19 |
20 | def save_to_json(self, file_name):
21 | tf.logging.info("Saving recorder file into {}".format(file_name))
22 | with open(file_name, 'wb') as writer:
23 | writer.write(json.dumps(self.__dict__, indent=2).encode("utf-8"))
24 | writer.close()
25 |
--------------------------------------------------------------------------------
/utils/saver.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import os
8 | import tensorflow as tf
9 |
10 |
11 | class Saver(object):
12 | def __init__(self,
13 |                  checkpoints=5, # number of latest checkpoints to keep
14 | output_dir=None, # the output directory
15 |                  best_score=-1, # the best BLEU score seen so far
16 |                  best_checkpoints=1, # number of best checkpoints kept in the best-checkpoint directory
17 | ):
18 | if output_dir is None:
19 | output_dir = "./output"
20 | self.output_dir = output_dir
21 | self.output_best_dir = os.path.join(output_dir, "best")
22 |
23 | self.saver = tf.train.Saver(
24 | max_to_keep=checkpoints
25 | )
26 | # handle disrupted checkpoints
27 | if tf.gfile.Exists(self.output_dir):
28 | ckpt = tf.train.get_checkpoint_state(self.output_dir)
29 | if ckpt and ckpt.all_model_checkpoint_paths:
30 | self.saver.recover_last_checkpoints(list(ckpt.all_model_checkpoint_paths))
31 |
32 | self.best_saver = tf.train.Saver(
33 | max_to_keep=best_checkpoints,
34 | )
35 | # handle disrupted checkpoints
36 | if tf.gfile.Exists(self.output_best_dir):
37 | ckpt = tf.train.get_checkpoint_state(self.output_best_dir)
38 | if ckpt and ckpt.all_model_checkpoint_paths:
39 | self.best_saver.recover_last_checkpoints(list(ckpt.all_model_checkpoint_paths))
40 |
41 | self.best_score = best_score
42 | # check best bleu result
43 | metric_dir = os.path.join(self.output_best_dir, "metric.log")
44 | if tf.gfile.Exists(metric_dir):
45 | metric_lines = open(metric_dir).readlines()
46 | if len(metric_lines) > 0:
47 | best_score_line = metric_lines[-1]
48 | self.best_score = float(best_score_line.strip().split()[-1])
49 |
50 | # check the top_k_best list and results
51 | self.topk_scores = []
52 | topk_dir = os.path.join(self.output_best_dir, "topk_checkpoint")
53 | ckpt_dir = os.path.join(self.output_best_dir, "checkpoint")
54 |         # directly load the top-k information from the topk_checkpoint file
55 | if tf.gfile.Exists(topk_dir):
56 | with tf.gfile.Open(topk_dir) as reader:
57 | for line in reader:
58 | model_name, score = line.strip().split("\t")
59 | self.topk_scores.append((model_name, float(score)))
60 |         # fall back to the normal checkpoint file and the recorded best score
61 | elif tf.gfile.Exists(ckpt_dir):
62 | latest_checkpoint = tf.gfile.Open(ckpt_dir).readline()
63 | model_name = latest_checkpoint.strip().split(":")[1].strip()
64 | model_name = model_name[1:-1] # remove ""
65 | self.topk_scores.append((model_name, self.best_score))
66 | self.best_checkpoints = best_checkpoints
67 |
68 | self.score_record = tf.gfile.Open(metric_dir, mode="a+")
69 |
70 | def save(self, session, step, metric_score=None):
71 | if not tf.gfile.Exists(self.output_dir):
72 | tf.gfile.MkDir(self.output_dir)
73 | if not tf.gfile.Exists(self.output_best_dir):
74 | tf.gfile.MkDir(self.output_best_dir)
75 |
76 | self.saver.save(session, os.path.join(self.output_dir, "model"), global_step=step)
77 |
78 | def _move(path, new_path):
79 | if tf.gfile.Exists(path):
80 | if tf.gfile.Exists(new_path):
81 | tf.gfile.Remove(new_path)
82 | tf.gfile.Copy(path, new_path)
83 |
84 | if metric_score is not None and metric_score > self.best_score:
85 | self.best_score = metric_score
86 |
87 | _move(os.path.join(self.output_dir, "param.json"),
88 | os.path.join(self.output_best_dir, "param.json"))
89 | _move(os.path.join(self.output_dir, "record.json"),
90 | os.path.join(self.output_best_dir, "record.json"))
91 |
92 |         # this recorder only records best scores
93 | self.score_record.write("Steps {}, Metric Score {}\n".format(step, metric_score))
94 | self.score_record.flush()
95 |
96 |         # either no best model has been saved yet, or the current metric score is better than the minimum stored one
97 | if metric_score is not None and \
98 | (len(self.topk_scores) == 0 or len(self.topk_scores) < self.best_checkpoints or
99 | metric_score > min([v[1] for v in self.topk_scores])):
100 |             # rewrite the 'checkpoint' file so its entries are ordered by score
101 | ckpt_dir = os.path.join(self.output_best_dir, "checkpoint")
102 | if len(self.topk_scores) > 0:
103 | sorted_topk_scores = sorted(self.topk_scores, key=lambda x: x[1])
104 | with tf.gfile.Open(ckpt_dir, mode='w') as writer:
105 | best_ckpt = sorted_topk_scores[-1]
106 | writer.write("model_checkpoint_path: \"{}\"\n".format(best_ckpt[0]))
107 | for model_name, _ in sorted_topk_scores:
108 | writer.write("all_model_checkpoint_paths: \"{}\"\n".format(model_name))
109 | writer.flush()
110 |
111 | # update best_saver internal checkpoints status
112 | ckpt = tf.train.get_checkpoint_state(self.output_best_dir)
113 | if ckpt and ckpt.all_model_checkpoint_paths:
114 | self.best_saver.recover_last_checkpoints(list(ckpt.all_model_checkpoint_paths))
115 |
116 |             # motivated by the observation that, for some datasets, the best
117 |             # performance is achieved by averaging the top-k checkpoints
118 | self.best_saver.save(
119 | session, os.path.join(self.output_best_dir, "model"), global_step=step)
120 |
121 | # handle topk scores
122 | self.topk_scores.append(("model-{}".format(int(step)), float(metric_score)))
123 | sorted_topk_scores = sorted(self.topk_scores, key=lambda x: x[1])
124 | self.topk_scores = sorted_topk_scores[-self.best_checkpoints:]
125 | topk_dir = os.path.join(self.output_best_dir, "topk_checkpoint")
126 | with tf.gfile.Open(topk_dir, mode='w') as writer:
127 | for model_name, score in self.topk_scores:
128 | writer.write("{}\t{}\n".format(model_name, score))
129 | writer.flush()
130 |
131 | def restore(self, session, path=None):
132 | if path is not None and tf.gfile.Exists(path):
133 | check_dir = path
134 | else:
135 | check_dir = self.output_dir
136 |
137 | checkpoint = os.path.join(check_dir, "checkpoint")
138 | if not tf.gfile.Exists(checkpoint):
139 | tf.logging.warn("No Existing Model detected")
140 | else:
141 | latest_checkpoint = tf.gfile.Open(checkpoint).readline()
142 | model_name = latest_checkpoint.strip().split(":")[1].strip()
143 | model_name = model_name[1:-1] # remove ""
144 | model_path = os.path.join(check_dir, model_name)
145 | model_path = os.path.abspath(model_path)
146 | if not tf.gfile.Exists(model_path+".meta"):
147 |                 tf.logging.error("model '{}' does not exist"
148 | .format(model_path))
149 | else:
150 | try:
151 | self.saver.restore(session, model_path)
152 | except tf.errors.NotFoundError:
153 |                     # In this case, we simply assume that the update-cycle
154 |                     # variables are mismatched, i.e. the replica variables are missing.
155 |                     # This can happen when you switch from non-cycle mode
156 |                     # to cycle mode.
157 | tf.logging.warn("Starting Backup Restore")
158 | ops = []
159 | reader = tf.train.load_checkpoint(model_path)
160 | for var in tf.global_variables():
161 | name = var.op.name
162 |
163 | if reader.has_tensor(name):
164 |                             tf.logging.info('{} gets initialization from {}'
165 | .format(name, name))
166 | ops.append(
167 | tf.assign(var, reader.get_tensor(name)))
168 | else:
169 |                         tf.logging.warn("{} is missing".format(name))
170 | restore_op = tf.group(*ops, name="restore_global_vars")
171 | session.run(restore_op)
172 |
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import os
8 | import time
9 | import pkgutil
10 | import collections
11 | import numpy as np
12 | import tensorflow as tf
13 |
14 | from utils import dtype
15 |
16 |
17 | def batch_indexer(datasize, batch_size):
18 |     """Split the data indices into batches of batch_size"""
19 | dataindex = np.arange(datasize).tolist()
20 |
21 | batchindex = []
22 | for i in range(datasize // batch_size):
23 | batchindex.append(dataindex[i * batch_size: (i + 1) * batch_size])
24 | if datasize % batch_size > 0:
25 | batchindex.append(dataindex[-(datasize % batch_size):])
26 |
27 | return batchindex
28 |
29 |
30 | def token_indexer(dataset, token_size):
31 |     """Divide the dataset into token-based batches"""
32 | # assume dataset format: [(len1, len2, ..., lenN)]
33 | dataindex = np.arange(len(dataset)).tolist()
34 |
35 | batchindex = []
36 |
37 | _batcher = [0.] * len(dataset[0])
38 | _counter = 0
39 | i = 0
40 | while True:
41 | if i >= len(dataset): break
42 |
43 |         # attempt to put this datapoint into the batch
44 | _batcher = [max(max_l, l)
45 | for max_l, l in zip(_batcher, dataset[i])]
46 | _counter += 1
47 | for l in _batcher:
48 | if _counter * l >= token_size:
49 |                 # when an extremely long instance occurs, handle it by making a batch of size 1
50 | if _counter > 1:
51 | batchindex.append(dataindex[i-_counter+1: i])
52 | i -= 1
53 | else:
54 | batchindex.append(dataindex[i: i+1])
55 |
56 | _counter = 0
57 | _batcher = [0.] * len(dataset[0])
58 | break
59 |
60 | i += 1
61 |
62 | _counter = sum([len(slice) for slice in batchindex])
63 | if _counter != len(dataset):
64 | batchindex.append(dataindex[_counter:])
65 | return batchindex
66 |
67 |
68 | def mask_scale(value, mask, scale=None):
69 |     """Prepare values for masked softmax: push masked positions towards -inf"""
70 | if scale is None:
71 | scale = dtype.inf()
72 | return value + (1. - mask) * (-scale)
73 |
74 |
75 | def valid_apply_dropout(x, dropout):
76 |     """Apply dropout only when the dropout rate is valid"""
77 | if dropout is not None and 0. <= dropout <= 1.:
78 | return tf.nn.dropout(x, 1. - dropout)
79 | return x
80 |
81 |
82 | def layer_dropout(dropped, no_dropped, dropout_rate):
83 | """Layer Dropout"""
84 | pred = tf.random_uniform([]) < dropout_rate
85 | return tf.cond(pred, lambda: dropped, lambda: no_dropped)
86 |
87 |
88 | def label_smooth(labels, vocab_size, factor=0.1):
89 | """Smooth the gold label distribution"""
90 | if 0. < factor < 1.:
91 | n = tf.cast(vocab_size - 1, tf.float32)
92 | p = 1. - factor
93 | q = factor / n
94 |
95 | t = tf.one_hot(tf.cast(tf.reshape(labels, [-1]), tf.int32),
96 | depth=vocab_size, on_value=p, off_value=q)
97 | normalizing = -(p * tf.log(p) + n * q * tf.log(q + 1e-20))
98 | else:
99 | t = tf.one_hot(tf.cast(tf.reshape(labels, [-1]), tf.int32),
100 | depth=vocab_size)
101 | normalizing = 0.
102 |
103 | return t, normalizing
104 |
105 |
106 | def closing_dropout(params):
107 |     """Set all dropout rates (and label smoothing) to zero"""
108 | for k, v in params.values().items():
109 | if 'dropout' in k:
110 | setattr(params, k, 0.0)
111 | # consider closing label smoothing
112 | if 'label_smoothing' in k:
113 | setattr(params, k, 0.0)
114 | return params
115 |
116 |
117 | def dict_update(d, u):
118 | """Recursive update dictionary"""
119 | for k, v in u.items():
120 | if isinstance(v, collections.Mapping):
121 | d[k] = dict_update(d.get(k, {}), v)
122 | else:
123 | d[k] = v
124 | return d
125 |
126 |
127 | def shape_list(x):
128 | # Copied from Tensor2Tensor
129 | """Return list of dims, statically where possible."""
130 | x = tf.convert_to_tensor(x)
131 |
132 | # If unknown rank, return dynamic shape
133 | if x.get_shape().dims is None:
134 | return tf.shape(x)
135 |
136 | static = x.get_shape().as_list()
137 | shape = tf.shape(x)
138 |
139 | ret = []
140 | for i in range(len(static)):
141 | dim = static[i]
142 | if dim is None:
143 | dim = shape[i]
144 | ret.append(dim)
145 | return ret
146 |
147 |
148 | def get_shape_invariants(tensor):
149 | # Copied from Tensor2Tensor
150 | """Returns the shape of the tensor but sets middle dims to None."""
151 | shape = tensor.shape.as_list()
152 | for i in range(1, len(shape) - 1):
153 | shape[i] = None
154 |
155 | return tf.TensorShape(shape)
156 |
157 |
158 | def merge_neighbor_dims(x, axis=0):
159 |     """Merge neighboring dimensions of x, starting at axis"""
160 | if len(x.get_shape().as_list()) < axis + 2:
161 | return x
162 |
163 | shape = shape_list(x)
164 | shape[axis] *= shape[axis+1]
165 | shape.pop(axis+1)
166 | return tf.reshape(x, shape)
167 |
168 |
169 | def unmerge_neighbor_dims(x, depth, axis=0):
170 |     """Inverse of merge_neighbor_dims: split axis into [depth, width]"""
171 | if len(x.get_shape().as_list()) < axis + 1:
172 | return x
173 |
174 | shape = shape_list(x)
175 | width = shape[axis] // depth
176 | new_shape = shape[:axis] + [depth, width] + shape[axis+1:]
177 | return tf.reshape(x, new_shape)
178 |
179 |
180 | def expand_tile_dims(x, depth, axis=1):
181 |     """Expand x at axis and tile it depth times"""
182 | x = tf.expand_dims(x, axis=axis)
183 | tile_dims = [1] * x.shape.ndims
184 | tile_dims[axis] = depth
185 |
186 | return tf.tile(x, tile_dims)
187 |
188 |
189 | def gumbel_noise(shape, eps=None):
190 |     """Generate Gumbel noise with the given shape"""
191 | if eps is None:
192 | eps = dtype.epsilon()
193 |
194 | u = tf.random_uniform(shape, minval=0, maxval=1)
195 | return -tf.log(-tf.log(u + eps) + eps)
196 |
197 |
198 | def log_prob_from_logits(logits):
199 |     """Log-probability from un-normalized logits"""
200 | return logits - tf.reduce_logsumexp(logits, axis=-1, keepdims=True)
201 |
202 |
203 | def batch_coordinates(batch_size, beam_size):
204 | """Batch coordinate indices under beam_size"""
205 | batch_pos = tf.range(batch_size * beam_size) // beam_size
206 | batch_pos = tf.reshape(batch_pos, [batch_size, beam_size])
207 |
208 | return batch_pos
209 |
210 |
211 | def variable_printer():
212 | """Print parameters"""
213 | all_weights = {v.name: v for v in tf.trainable_variables()}
214 | total_size = 0
215 |
216 | for v_name in sorted(list(all_weights)):
217 | v = all_weights[v_name]
218 | tf.logging.info("%s\tshape %s", v.name[:-2].ljust(80),
219 | str(v.shape).ljust(20))
220 | v_size = np.prod(np.array(v.shape.as_list())).tolist()
221 | total_size += v_size
222 | tf.logging.info("Total trainable variables size: %d", total_size)
223 |
224 |
225 | def uniform_splits(total_size, num_shards):
226 |     """Split total_size into num_shards roughly uniform parts"""
227 | size_per_shards = total_size // num_shards
228 | splits = [size_per_shards] * (num_shards - 1) + \
229 | [total_size - (num_shards - 1) * size_per_shards]
230 |
231 | return splits
232 |
233 |
234 | def fetch_valid_ref_files(path):
235 |     """Extract valid reference files following the MT naming convention (.ref0, .ref1, ...)"""
236 | path = os.path.abspath(path)
237 | if tf.gfile.Exists(path):
238 | return [path]
239 |
240 | if not tf.gfile.Exists(path + ".ref0"):
241 | tf.logging.warn("Invalid Reference Format {}".format(path))
242 | return None
243 |
244 | num = 0
245 | files = []
246 | while True:
247 | file_path = path + ".ref%s" % num
248 | if tf.gfile.Exists(file_path):
249 | files.append(file_path)
250 | else:
251 | break
252 | num += 1
253 | return files
254 |
255 |
256 | def get_session(gpus):
257 |     """Configure a session for the given GPUs"""
258 |
259 | sess_config = tf.ConfigProto(allow_soft_placement=True)
260 | sess_config.gpu_options.allow_growth = True
261 | if len(gpus) > 0:
262 | device_str = ",".join([str(i) for i in gpus])
263 | sess_config.gpu_options.visible_device_list = device_str
264 | sess = tf.Session(config=sess_config)
265 |
266 | return sess
267 |
268 |
269 | def flatten_list(values):
270 | """Flatten a list"""
271 | return [v for value in values for v in value]
272 |
273 |
274 | def remove_invalid_seq(sequence, mask):
275 | """Pick valid sequence elements wrt mask"""
276 | # sequence: [batch, sequence]
277 | # mask: [batch, sequence]
278 | boolean_mask = tf.reduce_sum(mask, axis=0)
279 |
280 |     # make sure that there is at least one element in the mask
281 | first_one = tf.one_hot(0, tf.shape(boolean_mask)[0],
282 | dtype=tf.as_dtype(dtype.floatx()))
283 | boolean_mask = tf.cast(boolean_mask + first_one, tf.bool)
284 |
285 | filtered_seq = tf.boolean_mask(sequence, boolean_mask, axis=1)
286 | filtered_mask = tf.boolean_mask(mask, boolean_mask, axis=1)
287 | return filtered_seq, filtered_mask
288 |
289 |
290 | def time_str(t=None):
291 |     """Format a timestamp as a human-readable string"""
292 | if t is None:
293 | t = time.time()
294 | ts = time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime(t))
295 | return ts
296 |
297 |
298 | def dynamic_load_module(module, prefix=None):
299 | """Load submodules inside a module, mainly used for model loading, not robust!!!"""
300 | # loading all models under directory `models` dynamically
301 | if not isinstance(module, str):
302 | module = module.__path__
303 | for importer, modname, ispkg in pkgutil.iter_modules(module):
304 | if prefix is None:
305 | __import__(modname)
306 | else:
307 | __import__("{}.{}".format(prefix, modname))
308 |
--------------------------------------------------------------------------------
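As a worked example of label_smooth above: with vocab_size=4 and factor=0.1, the gold token receives p = 1 - 0.1 = 0.9 and each of the other vocab_size - 1 = 3 tokens receives q = 0.1 / 3 ≈ 0.0333, so a gold id of 2 maps to the target distribution [0.033, 0.033, 0.9, 0.033]. A tiny sketch of the corresponding call (the function returns a tensor plus the normalizing constant):

    t, normalizing = label_smooth(tf.constant([2]), vocab_size=4, factor=0.1)
    # t ≈ [[0.0333, 0.0333, 0.9, 0.0333]]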
/vocab.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import argparse
8 |
9 |
10 | class Vocab(object):
11 | def __init__(self, vocab_file=None):
12 | self.word2id = {}
13 | self.id2word = {}
14 | self.word2count = {}
15 |
16 |         self.pad_sym = "<pad>"
17 |         self.eos_sym = "<eos>"
18 |         self.unk_sym = "<unk>"
19 |
20 | self.insert(self.pad_sym)
21 | self.insert(self.unk_sym)
22 | self.insert(self.eos_sym)
23 |
24 | if vocab_file is not None:
25 | self.load_vocab(vocab_file)
26 |
27 | def insert(self, token):
28 | if token not in self.word2id:
29 | index = len(self.word2id)
30 | self.word2id[token] = index
31 | self.id2word[index] = token
32 |
33 | self.word2count[token] = 0
34 | self.word2count[token] += 1
35 |
36 | def size(self):
37 | return len(self.word2id)
38 |
39 | def load_vocab(self, vocab_file):
40 | with open(vocab_file, 'r') as reader:
41 | for token in reader:
42 | self.insert(token.strip())
43 |
44 | def get_token(self, id):
45 | if id in self.id2word:
46 | return self.id2word[id]
47 | return self.unk_sym
48 |
49 | def get_id(self, token):
50 | if token in self.word2id:
51 | return self.word2id[token]
52 | return self.word2id[self.unk_sym]
53 |
54 | def sort_vocab(self):
55 | sorted_word2count = sorted(
56 | self.word2count.items(), key=lambda x: - x[1])
57 | self.word2id, self.id2word = {}, {}
58 | self.insert(self.pad_sym)
59 | self.insert(self.unk_sym)
60 | self.insert(self.eos_sym)
61 | for word, _ in sorted_word2count:
62 | self.insert(word)
63 |
64 | def save_vocab(self, vocab_file, size=1e6):
65 | with open(vocab_file, 'w') as writer:
66 | for id in range(min(self.size(), int(size))):
67 | writer.write(self.id2word[id] + "\n")
68 |
69 | def to_id(self, tokens, append_eos=True):
70 | if not append_eos:
71 | return [self.get_id(token) for token in tokens]
72 | else:
73 | return [self.get_id(token) for token in tokens + [self.eos_sym]]
74 |
75 | def to_tokens(self, ids):
76 | return [self.get_token(id) for id in ids]
77 |
78 | def eos(self):
79 | return self.get_id(self.eos_sym)
80 |
81 | def pad(self):
82 | return self.get_id(self.pad_sym)
83 |
84 |
85 | if __name__ == "__main__":
86 | parser = argparse.ArgumentParser('Vocabulary Preparison')
87 | parser.add_argument('--size', type=int, default=1e6, help='maximum vocabulary size')
88 | parser.add_argument('input', type=str, help='the input file path')
89 | parser.add_argument('output', type=str, help='the output file name')
90 |
91 | args = parser.parse_args()
92 |
93 | vocab = Vocab()
94 | with open(args.input, 'r') as reader:
95 | for line in reader:
96 | for token in line.strip().split():
97 | vocab.insert(token)
98 |
99 | vocab.sort_vocab()
100 | vocab.save_vocab(args.output, args.size)
101 |
102 | print("Loading {} tokens from {}".format(vocab.size(), args.input))
103 |
--------------------------------------------------------------------------------
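A usage sketch for the vocabulary module (file names hypothetical):

    # 1) build a vocabulary file from a tokenized training corpus:
    #      python vocab.py --size 32000 train.tok vocab.txt
    # 2) load it and convert between tokens and ids in the training code:
    vocab = Vocab("vocab.txt")
    ids = vocab.to_id("a simple example".split())   # appends the eos id by default
    tokens = vocab.to_tokens(ids)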