├── .gitignore ├── LICENSE ├── README.md ├── data.py ├── evalu.py ├── example ├── README.md ├── prepare.sh ├── test.sh └── train.sh ├── func.py ├── lrs ├── __init__.py ├── cosinelr.py ├── epochlr.py ├── lr.py └── noamlr.py ├── main.py ├── models ├── __init__.py ├── model.py └── transformer.py ├── modules ├── __init__.py ├── initializer.py └── speech.py ├── overview.png ├── run.py ├── scripts ├── checkpoint_averaging.py ├── chrF.py ├── multi-bleu-detok.perl ├── multi-bleu.perl └── shuffle_corpus.py ├── search.py ├── utils ├── __init__.py ├── cycle.py ├── dtype.py ├── metric.py ├── parallel.py ├── queuer.py ├── recorder.py ├── saver.py └── util.py └── vocab.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2018, Biao Zhang 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Revisiting End-to-End Speech-to-Text Translation From Scratch 2 | 3 | 4 | [**Paper**]() | 5 | [**Highlights**](#paper-highlights) | 6 | [**Overview**](#model-visualization) | 7 | [**Model**](#pretrained-models) | 8 | [**Training&Eval**](#training-and-evaluation) | 9 | [**Citation**](#citation) 10 | [**Updates**](#updates) 11 | 12 | This repository contains source code, models, and also instructions for our ICML paper. 13 | 14 | >Note, by ST from scratch, we refer to the setup where ST models are trained on speech-translation pairs 15 | alone without using transcripts or any type of pretraining. 16 | 17 | >By pretraining, we mainly refer to ASR/MT pretraining using the triplet training data. 
18 | 
19 | ## Updates
20 | 
21 | * [2023/02/21] Add support for CoLaCTC, using pseudo labels for regularization
22 | * [2023/02/21] Add support for flexible CTC labels, such as using the transcript as labels
23 | 
24 | ## Paper Highlights
25 | 
26 | We explore how far the quality of end-to-end speech translation trained on speech-translation pairs alone and from
27 | scratch can be improved.
28 | 
29 | - Techniques that are helpful for ST from scratch
30 |   * deep encoder with post-layernorm structure (12 encoder + 6 decoder layers)
31 |   * wide feed-forward layer (4096)
32 |   * CTC regularization on top of the encoder with translation as labels
33 |   * parameterized distance penalty (new proposal)
34 |   * neural acoustic modeling (new proposal)
35 |   * beam search hyperparameter tuning
36 |   * smaller vocabulary size
37 | 
38 | - We find that:
39 |   * The quality gap between ST with and without pretraining is overestimated in the literature
40 |   * By adapting ST towards training from scratch, we can match and even outperform previous studies that adopt pretraining
41 |   * Pretraining still matters in 1) extremely low-resource setups and 2) when large-scale external resources are available
42 | 
43 | ## Model Visualization
44 | 
45 | ![Overview of our proposal](overview.png)
46 | 
47 | Apart from the parameterized distance penalty, we propose to jointly apply the MLE and CTC objectives for training, **even though we use the translation as CTC labels.**
48 | 
49 | ## Pretrained Models
50 | 
51 | 
52 | | Model | BLEU on MuST-C En-De |
53 | |-----------|-------------------------|
54 | | Fairseq (pretrain-finetune) | 22.7 |
55 | | NeurST (pretrain-finetune) | 22.8 |
56 | | ESPnet (pretrain-finetune) | 22.9 |
57 | | this work (ST from scratch) | [22.7](https://data.statmt.org/bzhang/icml2022_revisiting/) |
58 | 
59 | 
60 | 
61 | ## Requirements
62 | 
63 | The source code is based on an older TensorFlow version.
64 | 
65 | - python==3.6
66 | - tensorflow==1.15+
67 | 
68 | 
69 | ## Training and Evaluation
70 | 
71 | Please check out the [example](./example) for reference.
72 | 
73 | * [preprocessing](./example/prepare.sh)
74 | * [training](./example/train.sh)
75 | * [decoding](./example/test.sh)
76 | 
77 | ## Citation
78 | 
79 | If you draw any inspiration from our study, please consider citing our paper:
80 | ```
81 | @inproceedings{
82 | zhang2022revisiting,
83 | title={Revisiting End-to-End Speech-to-Text Translation From Scratch},
84 | author={Biao Zhang and Barry Haddow and Rico Sennrich},
85 | booktitle={International Conference on Machine Learning},
86 | year={2022},
87 | }
88 | ```
89 | 
-------------------------------------------------------------------------------- 
/data.py:
-------------------------------------------------------------------------------- 
1 | # coding: utf-8
2 | 
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 | 
7 | import os
8 | import yaml
9 | import numpy as np
10 | import librosa
11 | from utils.util import batch_indexer, token_indexer
12 | 
13 | 
14 | def audio_encode(wav_path, offset=0.0, duration=None, sample_rate=16000):
15 |     """
16 |     Encode an audio file into a float array given the offset and duration.
17 |     We assume a 16 kHz sample rate; audio at other rates is reloaded with resampling to `sample_rate`. 
18 | """ 19 | # load data, sr=None enforce to use the native sample rate 20 | data, rate = librosa.load(wav_path, sr=None, offset=offset, duration=duration) 21 | if sample_rate is not None and rate != sample_rate: 22 | data, rate = librosa.load(wav_path, sr=sample_rate, offset=offset, duration=duration) 23 | assert len(data.shape) == 1 and rate == sample_rate, (data.shape, rate) 24 | 25 | if data.dtype not in [np.float32, np.float64]: 26 | data = data.astype(np.float32) / np.iinfo(data.dtype).max 27 | return data.astype(np.float32) 28 | 29 | 30 | def get_rough_length(audio_infor, p): 31 | duration = audio_infor['duration'] # in seconds 32 | # total signals 33 | num_signal = int(duration * p.audio_sample_rate) 34 | # windows properties 35 | frame_step = int(p.audio_frame_step * p.audio_sample_rate / 1e3) 36 | # total frame 37 | num_frame = (num_signal + frame_step - 1) // frame_step 38 | return num_frame 39 | 40 | 41 | class Dataset(object): 42 | def __init__(self, 43 | params, 44 | src_file, # audio/speech file 45 | tgt_file, # translation file 46 | src_vocab, # source vocabulary used for ctc file 47 | tgt_vocab, # translation vocabulary file 48 | ctc_file='', # either translation or transcript file 49 | batch_or_token='batch', 50 | data_leak_ratio=0.5, 51 | src_audio_path=''): 52 | self.source = src_file 53 | self.target = tgt_file 54 | self.src_vocab = src_vocab # Note source vocabulary here is meaningless 55 | self.tgt_vocab = tgt_vocab 56 | self.batch_or_token = batch_or_token 57 | self.data_leak_ratio = data_leak_ratio 58 | 59 | self.p = params 60 | self.sr = params.audio_sample_rate 61 | self.src_audio_path = src_audio_path 62 | 63 | # if no regularization file provided, use the translations directly 64 | # this could be useful for inference: where ctc file is not used at all. 
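        # (the CTC reference may also point to a transcript file, giving ASR-style CTC labels; see the README Updates)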
65 | self.ctcref = ctc_file if ctc_file != '' else tgt_file 66 | 67 | self.max_frame_len = params.max_frame_len 68 | self.max_text_len = params.max_text_len 69 | 70 | self.leak_buffer = [] 71 | 72 | # loading dataset 73 | def load_data(self, is_train=False): 74 | sources = self.source.strip().split(";") 75 | targets = self.target.strip().split(";") 76 | ctcrefs = self.ctcref.strip().split(";") 77 | 78 | for source, target, ctcref in zip(sources, targets, ctcrefs): 79 | with open(source, 'r', encoding='utf-8') as src_reader, \ 80 | open(target, 'r', encoding='utf-8') as tgt_reader, \ 81 | open(ctcref, 'r', encoding='utf-8') as ctc_reader: 82 | 83 | while True: 84 | src_line = src_reader.readline() 85 | tgt_line = tgt_reader.readline() 86 | ctc_line = ctc_reader.readline() 87 | 88 | if tgt_line == "" or src_line == "" or ctc_line == "": 89 | break 90 | 91 | src_line = src_line.strip() 92 | tgt_line = tgt_line.strip() 93 | ctc_line = ctc_line.strip() 94 | 95 | if is_train and (tgt_line == "" or src_line == "" or ctc_line == ""): 96 | continue 97 | 98 | yield ( 99 | yaml.safe_load(src_line)[0], 100 | self.tgt_vocab.to_id(tgt_line.split()[:self.max_text_len]), 101 | self.src_vocab.to_id(ctc_line.split()[:self.max_text_len]), 102 | ) 103 | 104 | def to_matrix(self, batch): 105 | batch_size = len(batch) 106 | 107 | # handle source audios 108 | sources = [] 109 | frames = [] 110 | for sample in batch: 111 | audio_infor = sample[1] 112 | frames.append(get_rough_length(audio_infor, self.p)) 113 | 114 | sources.append(audio_encode( 115 | os.path.join(self.src_audio_path, audio_infor['wav']), 116 | audio_infor['offset'], 117 | audio_infor['duration'], 118 | sample_rate=self.sr)) 119 | 120 | src_lens = [len(sample) for sample in sources] 121 | tgt_lens = [len(sample[2]) for sample in batch] 122 | ctc_lens = [len(sample[3]) for sample in batch] 123 | 124 | src_len = min(self.max_frame_len, max(src_lens)) 125 | tgt_len = min(self.max_text_len, max(tgt_lens)) 126 | ctc_len = min(self.max_text_len, max(ctc_lens)) 127 | 128 | # (x, s, t) => (data_index, audio, translation) 129 | s = np.zeros([batch_size, src_len], dtype=np.float32) 130 | t = np.zeros([batch_size, tgt_len], dtype=np.int32) 131 | x = [] 132 | for eidx, sample in enumerate(batch): 133 | x.append(sample[0]) 134 | src_ids, tgt_ids = sources[eidx], sample[2] 135 | 136 | s[eidx, :min(src_len, len(src_ids))] = src_ids[:src_len] 137 | t[eidx, :min(tgt_len, len(tgt_ids))] = tgt_ids[:tgt_len] 138 | 139 | # construct sparse label sequence, for ctc training 140 | seq_indexes = [] 141 | seq_values = [] 142 | for n, sample in enumerate(batch): 143 | # change to ctc_ids and ctc_len 144 | sequence = sample[3][:ctc_len] 145 | 146 | seq_indexes.extend(zip([n] * len(sequence), range(len(sequence)))) 147 | # apply CoLaCTC (MoD) 148 | if self.p.cola_ctc_L < 0: 149 | seq_values.extend(sequence) 150 | else: 151 | # i.e. 
a very simple mod operation 152 | seq_values.extend([v % self.p.cola_ctc_L for v in sequence]) 153 | 154 | seq_indexes = np.asarray(seq_indexes, dtype=np.int64) 155 | seq_values = np.asarray(seq_values, dtype=np.int32) 156 | seq_shape = np.asarray([batch_size, ctc_len], dtype=np.int64) 157 | 158 | return x, s, t, (seq_indexes, seq_values, seq_shape), frames 159 | 160 | def processor(self, batch): 161 | x, s, t, spar, f = self.to_matrix(batch) 162 | return { 163 | 'src': s, 164 | 'tgt': t, 165 | 'frames': f, 166 | 'spar': spar, 167 | 'index': x, 168 | 'raw': batch, 169 | } 170 | 171 | def batcher(self, size, buffer_size=1000, shuffle=True, train=True): 172 | def _handle_buffer(_buffer): 173 | sorted_buffer = sorted( 174 | _buffer, key=lambda xx: max(get_rough_length(xx[1], self.p), len(xx[2]))) 175 | 176 | if self.batch_or_token == 'batch': 177 | buffer_index = batch_indexer(len(sorted_buffer), size) 178 | else: 179 | buffer_index = token_indexer( 180 | [[get_rough_length(sample[1], self.p), len(sample[2])] 181 | for sample in sorted_buffer], size) 182 | 183 | index_over_index = batch_indexer(len(buffer_index), 1) 184 | if shuffle: np.random.shuffle(index_over_index) 185 | 186 | for ioi in index_over_index: 187 | index = buffer_index[ioi[0]] 188 | batch = [sorted_buffer[ii] for ii in index] 189 | yield batch 190 | 191 | buffer = self.leak_buffer 192 | self.leak_buffer = [] 193 | for i, (src_ids, tgt_ids, ctc_ids) in enumerate(self.load_data(train)): 194 | buffer.append((i, src_ids, tgt_ids, ctc_ids)) 195 | if len(buffer) >= buffer_size: 196 | for data in _handle_buffer(buffer): 197 | # check whether the data is tailed 198 | batch_size = len(data) if self.batch_or_token == 'batch' \ 199 | else max(sum([len(sample[2]) for sample in data]), 200 | sum([get_rough_length(sample[1], self.p) for sample in data])) 201 | if batch_size < size * self.data_leak_ratio: 202 | self.leak_buffer += data 203 | else: 204 | yield data 205 | buffer = self.leak_buffer 206 | self.leak_buffer = [] 207 | 208 | # deal with data in the buffer 209 | if len(buffer) > 0: 210 | for data in _handle_buffer(buffer): 211 | # check whether the data is tailed 212 | batch_size = len(data) if self.batch_or_token == 'batch' \ 213 | else max(sum([len(sample[2]) for sample in data]), 214 | sum([get_rough_length(sample[1], self.p) for sample in data])) 215 | if train and batch_size < size * self.data_leak_ratio: 216 | self.leak_buffer += data 217 | else: 218 | yield data 219 | -------------------------------------------------------------------------------- /evalu.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import time 8 | import numpy as np 9 | import tensorflow as tf 10 | 11 | from utils import queuer, util, metric 12 | 13 | 14 | def decode_target_token(id_seq, vocab): 15 | """Convert sequence ids into tokens""" 16 | valid_id_seq = [] 17 | for tok_id in id_seq: 18 | if tok_id == vocab.eos() \ 19 | or tok_id == vocab.pad(): 20 | break 21 | valid_id_seq.append(tok_id) 22 | return vocab.to_tokens(valid_id_seq) 23 | 24 | 25 | def decode_hypothesis(seqs, scores, params, mask=None): 26 | """Generate decoded sequence from seqs""" 27 | if mask is None: 28 | mask = [1.] 
* len(seqs) 29 | 30 | hypoes = [] 31 | marks = [] 32 | for _seqs, _scores, _m in zip(seqs, scores, mask): 33 | if _m < 1.: continue 34 | 35 | for seq, score in zip(_seqs, _scores): 36 | # Temporarily, Use top-1 decoding 37 | best_seq = seq[0] 38 | best_score = score[0] 39 | 40 | hypo = decode_target_token(best_seq, params.tgt_vocab) 41 | mark = best_score 42 | 43 | hypoes.append(hypo) 44 | marks.append(mark) 45 | 46 | return hypoes, marks 47 | 48 | 49 | def decoding(session, features, out_seqs, out_scores, dataset, params): 50 | """Performing decoding with exising information""" 51 | translations = [] 52 | scores = [] 53 | indices = [] 54 | 55 | eval_queue = queuer.EnQueuer( 56 | dataset.batcher(params.eval_batch_size, 57 | buffer_size=params.buffer_size, 58 | shuffle=False, 59 | train=False), 60 | dataset.processor, 61 | worker_processes_num=params.process_num, 62 | input_queue_size=params.input_queue_size, 63 | output_queue_size=params.output_queue_size, 64 | ) 65 | 66 | def _predict_one_batch(_data_on_gpu): 67 | feed_dicts = {} 68 | 69 | _step_indices = [] 70 | for fidx, shard_data in enumerate(_data_on_gpu): 71 | # define feed_dict 72 | _feed_dict = { 73 | features[fidx]["source"]: shard_data['src'], 74 | } 75 | feed_dicts.update(_feed_dict) 76 | 77 | # collect data indices 78 | _step_indices.extend(shard_data['index']) 79 | 80 | # pick up valid outputs 81 | data_size = len(_data_on_gpu) 82 | valid_out_seqs = out_seqs[:data_size] 83 | valid_out_scores = out_scores[:data_size] 84 | 85 | _decode_seqs, _decode_scores = session.run( 86 | [valid_out_seqs, valid_out_scores], feed_dict=feed_dicts) 87 | 88 | _step_translations, _step_scores = decode_hypothesis( 89 | _decode_seqs, _decode_scores, params 90 | ) 91 | 92 | return _step_translations, _step_scores, _step_indices 93 | 94 | very_begin_time = time.time() 95 | data_on_gpu = [] 96 | for bidx, data in enumerate(eval_queue): 97 | if bidx == 0: 98 | # remove the data reading time 99 | very_begin_time = time.time() 100 | 101 | data_on_gpu.append(data) 102 | # use multiple gpus, and data samples is not enough 103 | if len(params.gpus) > 0 and len(data_on_gpu) < len(params.gpus): 104 | continue 105 | 106 | start_time = time.time() 107 | step_outputs = _predict_one_batch(data_on_gpu) 108 | data_on_gpu = [] 109 | 110 | translations.extend(step_outputs[0]) 111 | scores.extend(step_outputs[1]) 112 | indices.extend(step_outputs[2]) 113 | 114 | tf.logging.info( 115 | "Decoding Batch {} using {:.3f} s, translating {} " 116 | "sentences using {:.3f} s in total".format( 117 | bidx, time.time() - start_time, 118 | len(translations), time.time() - very_begin_time 119 | ) 120 | ) 121 | 122 | if len(data_on_gpu) > 0: 123 | 124 | start_time = time.time() 125 | step_outputs = _predict_one_batch(data_on_gpu) 126 | 127 | translations.extend(step_outputs[0]) 128 | scores.extend(step_outputs[1]) 129 | indices.extend(step_outputs[2]) 130 | 131 | tf.logging.info( 132 | "Decoding Batch {} using {:.3f} s, translating {} " 133 | "sentences using {:.3f} s in total".format( 134 | 'final', time.time() - start_time, 135 | len(translations), time.time() - very_begin_time 136 | ) 137 | ) 138 | 139 | return translations, scores, indices 140 | 141 | 142 | def scoring(session, features, out_scores, dataset, params): 143 | """Performing decoding with exising information""" 144 | scores = [] 145 | indices = [] 146 | 147 | eval_queue = queuer.EnQueuer( 148 | dataset.batcher(params.eval_batch_size, 149 | buffer_size=params.buffer_size, 150 | shuffle=False, 151 | train=False), 152 
| dataset.processor, 153 | worker_processes_num=params.process_num, 154 | input_queue_size=params.input_queue_size, 155 | output_queue_size=params.output_queue_size, 156 | ) 157 | 158 | total_entropy = 0. 159 | total_tokens = 0. 160 | 161 | def _predict_one_batch(_data_on_gpu): 162 | feed_dicts = {} 163 | 164 | _step_indices = [] 165 | for fidx, shard_data in enumerate(_data_on_gpu): 166 | # define feed_dict 167 | _feed_dict = { 168 | features[fidx]["source"]: shard_data['src'], 169 | features[fidx]["target"]: shard_data['tgt'], 170 | } 171 | feed_dicts.update(_feed_dict) 172 | 173 | # collect data indices 174 | _step_indices.extend(shard_data['index']) 175 | 176 | # pick up valid outputs 177 | data_size = len(_data_on_gpu) 178 | valid_out_scores = out_scores[:data_size] 179 | 180 | _decode_scores = session.run( 181 | valid_out_scores, feed_dict=feed_dicts) 182 | 183 | _batch_entropy = sum([s * float((d > 0).sum()) 184 | for shard_data, shard_scores in zip(_data_on_gpu, _decode_scores) 185 | for d, s in zip(shard_data['tgt'], shard_scores.tolist())]) 186 | _batch_tokens = sum([(shard_data['tgt'] > 0).sum() for shard_data in _data_on_gpu]) 187 | 188 | _decode_scores = [s for _scores in _decode_scores for s in _scores] 189 | 190 | return _decode_scores, _step_indices, _batch_entropy, _batch_tokens 191 | 192 | very_begin_time = time.time() 193 | data_on_gpu = [] 194 | for bidx, data in enumerate(eval_queue): 195 | if bidx == 0: 196 | # remove the data reading time 197 | very_begin_time = time.time() 198 | 199 | data_on_gpu.append(data) 200 | # use multiple gpus, and data samples is not enough 201 | if len(params.gpus) > 0 and len(data_on_gpu) < len(params.gpus): 202 | continue 203 | 204 | start_time = time.time() 205 | step_outputs = _predict_one_batch(data_on_gpu) 206 | data_on_gpu = [] 207 | 208 | scores.extend(step_outputs[0]) 209 | indices.extend(step_outputs[1]) 210 | 211 | total_entropy += step_outputs[2] 212 | total_tokens += step_outputs[3] 213 | 214 | tf.logging.info( 215 | "Decoding Batch {} using {:.3f} s, translating {} " 216 | "sentences using {:.3f} s in total".format( 217 | bidx, time.time() - start_time, 218 | len(scores), time.time() - very_begin_time 219 | ) 220 | ) 221 | 222 | if len(data_on_gpu) > 0: 223 | 224 | start_time = time.time() 225 | step_outputs = _predict_one_batch(data_on_gpu) 226 | 227 | scores.extend(step_outputs[0]) 228 | indices.extend(step_outputs[1]) 229 | 230 | total_entropy += step_outputs[2] 231 | total_tokens += step_outputs[3] 232 | 233 | tf.logging.info( 234 | "Decoding Batch {} using {:.3f} s, translating {} " 235 | "sentences using {:.3f} s in total".format( 236 | 'final', time.time() - start_time, 237 | len(scores), time.time() - very_begin_time 238 | ) 239 | ) 240 | 241 | scores = [data[1] for data in 242 | sorted(zip(indices, scores), key=lambda x: x[0])] 243 | 244 | ppl = np.exp(total_entropy / total_tokens) 245 | 246 | return scores, ppl 247 | 248 | 249 | def eval_metric(trans, target_file, indices=None): 250 | """BLEU Evaluate """ 251 | target_valid_files = util.fetch_valid_ref_files(target_file) 252 | if target_valid_files is None: 253 | return 0.0 254 | 255 | if indices is not None: 256 | trans = [data[1] for data in sorted(zip(indices, trans), key=lambda x: x[0])] 257 | 258 | references = [] 259 | for ref_file in target_valid_files: 260 | cur_refs = tf.gfile.Open(ref_file).readlines() 261 | cur_refs = [line.strip().split() for line in cur_refs] 262 | references.append(cur_refs) 263 | 264 | references = list(zip(*references)) 265 | 266 | 
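    # after the zip/transpose above, each element of `references` is a tuple holding one tokenized reference per reference file for the same sentence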
return metric.bleu(trans, references)
267 | 
268 | 
269 | def dump_tanslation(tranes, output, indices=None):
270 |     """save translation"""
271 |     if indices is not None:
272 |         tranes = [data[1] for data in
273 |                   sorted(zip(indices, tranes), key=lambda x: x[0])]
274 |     with tf.gfile.Open(output, 'w') as writer:
275 |         for hypo in tranes:
276 |             if isinstance(hypo, list):
277 |                 writer.write(' '.join(hypo) + "\n")
278 |             else:
279 |                 writer.write(str(hypo) + "\n")
280 |     tf.logging.info("Saving translations into {}".format(output))
281 | 
-------------------------------------------------------------------------------- 
/example/README.md:
-------------------------------------------------------------------------------- 
1 | # Walk-through Example
2 | 
3 | This file walks through the rough procedure for training an end-to-end SLT Transformer model on the MuST-C dataset.
4 | 
5 | ### Step 1. Download the MuST-C Dataset
6 | 
7 | Take En->De as an example.
8 | 
9 | * You can go to [the official website](https://ict.fbk.eu/must-c/) to download the dataset.
10 | 
11 | * You can use the following
12 | [Google Drive Address](https://drive.google.com/open?id=1Mf2il_VelDIJMSio0bq7I8M9fSs-X4Ie) for downloading.
13 | 
14 | 
15 | Untar the dataset.
16 | 
17 | ### Step 2. Download this code base
18 | 
19 | ```
20 | git clone https://github.com/bzhangGo/st_from_scratch.git
21 | ```
22 | Suppose the cloned code path is `st_from_scratch`; we refer to this code base as `${code}` below.
23 | 
24 | ### Step 3. Preprocess the speech dataset
25 | 
26 | You need to preprocess the English and German text files (tokenization, truecasing, subword BPE).
27 | Audio is loaded dynamically during training.
28 | 
29 | 1) Preprocess the text files
30 | ```
31 | en_de=/path/to/untarred/en-de/
32 | ln -s ${en_de} en-de
33 | ln -s en-de/data/dev/txt/dev.en .
34 | ln -s en-de/data/dev/txt/dev.de .
35 | ln -s en-de/data/tst-COMMON/txt/tst-COMMON.en test.en
36 | ln -s en-de/data/tst-COMMON/txt/tst-COMMON.de test.de
37 | ln -s en-de/data/train/txt/train.en .
38 | ln -s en-de/data/train/txt/train.de .
39 | 
40 | # tokenize, truecase and apply BPE
41 | # you need to download mosesdecoder and subword-nmt, and reset their paths in the following script
42 | ./prepare.sh
43 | 
44 | # prepare vocabulary
45 | python ${code}/vocab.py train.bpe.en vocab.zero.en
46 | python ${code}/vocab.py train.bpe.de vocab.zero.de
47 | ```
48 | The resulting files are:
49 | 
50 | - (train source, train target): `train.bpe.en, train.bpe.de`
51 | - (dev source, dev target): `dev.bpe.en, dev.bpe.de`
52 | - (test source, test target): `test.bpe.en, test.reftok.de`
53 | 
54 | Note the test reference file `test.reftok.de`: it is only tokenized, without punctuation normalization or truecasing.
55 | 
56 | ### Step 4. Train your model
57 | 
58 | See the provided script `train.sh` for reference. Training takes about 4-5 days with one GPU, or less with more GPUs.
59 | 
60 | ### Step 5. Decoding
61 | 
62 | See the provided script `test.sh` for reference. 
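As an optional sanity check of the preprocessed data, the snippet below (a minimal sketch, not part of the code base; the file names follow the layout assumed in `train.sh`, and it uses `pyyaml` and `librosa`, the same dependencies `data.py` relies on) loads the first MuST-C segment and its BPE target in the same way `data.py` reads one sample:

```
# check_sample.py -- hypothetical helper, mirrors how data.py reads one training sample
import os
import yaml
import librosa

wav_dir = "en-de/data/train/wav/"               # src_train_path in train.sh
yaml_file = "en-de/data/train/txt/train.yaml"   # src_train_file in train.sh
tgt_file = "train.bpe.de"                       # tgt_train_file in train.sh

with open(yaml_file, encoding="utf-8") as f_src, open(tgt_file, encoding="utf-8") as f_tgt:
    src_line, tgt_line = f_src.readline(), f_tgt.readline()

# each line of the yaml file describes one segment: wav name, offset and duration in seconds
segment = yaml.safe_load(src_line)[0]
signal, rate = librosa.load(os.path.join(wav_dir, segment["wav"]), sr=16000,
                            offset=segment["offset"], duration=segment["duration"])

print("audio samples:", len(signal), "at", rate, "Hz")
print("first target line (BPE):", tgt_line.strip())
```

If the printed audio length and target sentence look consistent, the yaml, wav and BPE files are aligned as the training code expects.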
63 | 64 | -------------------------------------------------------------------------------- /example/prepare.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -v 2 | 3 | # suffix of source language files 4 | SRC=en 5 | 6 | # suffix of target language files 7 | TRG=de 8 | 9 | # number of merge operations, using a smaller bpe number 10 | bpe_operations=8000 11 | 12 | # path to moses decoder: https://github.com/moses-smt/mosesdecoder 13 | mosesdecoder=path-to-mosesdecoder 14 | 15 | # path to subword segmentation scripts: https://github.com/rsennrich/subword-nmt 16 | subword_nmt=path-to-subwordnmt 17 | 18 | # tokenize 19 | # should use sacreBLEU for final evaluation 20 | for prefix in train dev test 21 | do 22 | cat $prefix.$SRC \ 23 | | $mosesdecoder/scripts/tokenizer/normalize-punctuation.perl -l $SRC \ 24 | | $mosesdecoder/scripts/tokenizer/tokenizer.perl -a -l $SRC > $prefix.tok.$SRC 25 | 26 | test -f $prefix.$TRG || continue 27 | 28 | cat $prefix.$TRG \ 29 | | $mosesdecoder/scripts/tokenizer/normalize-punctuation.perl -l $TRG \ 30 | | $mosesdecoder/scripts/tokenizer/tokenizer.perl -a -l $TRG > $prefix.tok.$TRG 31 | done 32 | 33 | # note this "*.reftok.*" file should be used as tokenized reference 34 | # by reference, we shouldn't apply punctuation normalization, or truecasing. 35 | for prefix in dev test 36 | do 37 | cat $prefix.$TRG \ 38 | | $mosesdecoder/scripts/tokenizer/tokenizer.perl -a -l $TRG > $prefix.reftok.$TRG 39 | done 40 | 41 | 42 | # train truecaser 43 | $mosesdecoder/scripts/recaser/train-truecaser.perl -corpus train.tok.$SRC -model tc.$SRC 44 | $mosesdecoder/scripts/recaser/train-truecaser.perl -corpus train.tok.$TRG -model tc.$TRG 45 | 46 | # apply truecaser (cleaned training corpus) 47 | for prefix in train dev test 48 | do 49 | $mosesdecoder/scripts/recaser/truecase.perl -model tc.$SRC < $prefix.tok.$SRC > $prefix.tc.$SRC 50 | test -f $prefix.tok.$TRG || continue 51 | $mosesdecoder/scripts/recaser/truecase.perl -model tc.$TRG < $prefix.tok.$TRG > $prefix.tc.$TRG 52 | done 53 | 54 | # train BPE 55 | cat train.tc.$SRC train.tc.$TRG | $subword_nmt/learn_bpe.py -s $bpe_operations > $SRC$TRG.bpe 56 | 57 | # apply BPE 58 | for prefix in train dev test 59 | do 60 | $subword_nmt/apply_bpe.py -c $SRC$TRG.bpe < $prefix.tc.$SRC > $prefix.bpe.$SRC 61 | test -f $prefix.tc.$TRG || continue 62 | $subword_nmt/apply_bpe.py -c $SRC$TRG.bpe < $prefix.tc.$TRG > $prefix.bpe.$TRG 63 | done 64 | -------------------------------------------------------------------------------- /example/test.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | export CUDA_VISIBLE_DEVICES=0 4 | 5 | data=path-to-preprocessed-ende-dataset/ 6 | code=path-to-zero-codebase/ 7 | moses=path-to-mosesdecoder/ 8 | 9 | # average best 10 checkpoints 10 | python3 ${code}/scripts/checkpoint_averaging.py --path ../train/best --output avg --checkpoints 10 --gpu 0 11 | 12 | python3 ${code}/run.py --mode test --parameters=hidden_size=256,embed_size=256,filter_size=4096,\ 13 | dropout=0.1,label_smooth=0.1,attention_dropout=0.1,relu_dropout=0.2,residual_dropout=0.2,\ 14 | max_text_len=256,max_frame_len=480000,batch_size=80,eval_batch_size=35,\ 15 | token_size=20000,batch_or_token='token',\ 16 | initializer="uniform_unit_scaling",initializer_gain=0.5,beam_size=8,decode_alpha=1.4,\ 17 | model_name="transformer",scope_name="transformer",buffer_size=5000,data_leak_ratio=0.1,\ 18 | input_queue_size=1000,output_queue_size=1000,\ 19 | deep_transformer_init=True,\ 20 | audio_num_mel_bins=40,audio_add_delta_deltas=True,pdp_r=512,\ 21 | sinusoid_posenc=True,max_poslen=20480,ctc_enable=True,ctc_alpha=0.3,audio_dither=0.0,\ 22 | enc_localize="pdp",dec_localize="none",encdec_localize="none",\ 23 | clip_grad_norm=0.0,\ 24 | num_heads=4,\ 25 | process_num=4,\ 26 | lrate=1.0,\ 27 | estop_patience=100,\ 28 | num_encoder_layer=12,\ 29 | num_decoder_layer=6,\ 30 | warmup_steps=4000,\ 31 | lrate_strategy="noam",\ 32 | epoches=5000,\ 33 | update_cycle=25,\ 34 | gpus=[0],\ 35 | disp_freq=1,\ 36 | eval_freq=1000,\ 37 | save_freq=2500,\ 38 | sample_freq=1000,\ 39 | checkpoints=10,\ 40 | best_checkpoints=10,\ 41 | max_training_steps=50000,\ 42 | beta1=0.9,\ 43 | beta2=0.98,\ 44 | random_seed=1234,\ 45 | src_vocab_file="$data/vocab.zero.en",\ 46 | tgt_vocab_file="$data/vocab.zero.de",\ 47 | src_train_path="$data/en-de/data/train/wav/",\ 48 | src_train_file="$data/en-de/data/train/txt/train.yaml",\ 49 | tgt_train_file="$data/train.bpe.de",\ 50 | src_dev_path="$data/en-de/data/dev/wav/",\ 51 | src_dev_file="$data/en-de/data/dev/txt/dev.yaml",\ 52 | tgt_dev_file="$data/dev.bpe.de",\ 53 | src_test_path="$data/en-de/data/tst-COMMON/wav/",\ 54 | src_test_file="$data/en-de/data/tst-COMMON/txt/tst-COMMON.yaml",\ 55 | tgt_test_file="$data/test.bpe.de",\ 56 | output_dir="avg",\ 57 | test_output="trans.txt",\ 58 | 59 | # post processing 60 | sed -r 's/ \@(\S*?)\@ /\1/g' < trans.txt | 61 | sed -r 's/\@\@ //g' | 62 | sed "s/<s>//" | 63 | ${moses}/scripts/recaser/detruecase.perl > trans.tok.txt 64 | 65 | # evaluation 66 | ${moses}/scripts/generic/multi-bleu.perl $data/test.reftok.de < trans.tok.txt > test.bleu 67 | 68 | # note to perform sacrebleu, you need `${moses}/scripts/tokenizer/detokenizer.perl -l de` to get the detokenized outputs -------------------------------------------------------------------------------- /example/train.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | export CUDA_VISIBLE_DEVICES=0 4 | 5 | data=path-to-preprocessed-ende-dataset/ 6 | code=path-to-codebase/ 7 | 8 | # the effective batch size: token_size * len(gpus) * update_cycle 9 | # - token_size: the token size number for one GPU. In speech translation case, the length of source input is very long, so it's not accurate. 
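#     e.g. with the settings below (token_size=20000, gpus=[0], update_cycle=25), one parameter update covers roughly 20000 * 1 * 25 = 500k target tokens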
10 | # - len(gpus): you could use multiple GPUS, modify the CUDA_VISIBLE_DEVICES and gpus=[0] 11 | # such as: export CUDA_VISIBLE_DEVICES=4,5,6,7, gpus=[0,1,2,3] to use four gpus 12 | # - update_cycle: accumulated gradient steps 13 | # You could use multiple gpus and a smaller update_cycle to accelerate training 14 | 15 | # Models are saved every `save_freq` steps, and evaluated every `eval_freq` with a maximum training steps `max_training_steps` 16 | 17 | # Other settings: Transformer encoder layers (num_encoder_layer, 12), decoder layers (num_decoder_layer, 6), model_size (hidden_size and embed_size, 256, filter_size 4096) 18 | # Adam: learning rate (lrate_strategy, noam and warmup steps 4000) 19 | 20 | python3 ${code}/run.py --mode train --parameters=hidden_size=256,embed_size=256,filter_size=4096,\ 21 | dropout=0.1,label_smooth=0.1,attention_dropout=0.1,relu_dropout=0.2,residual_dropout=0.2,\ 22 | max_text_len=256,max_frame_len=480000,eval_batch_size=5,\ 23 | token_size=20000,batch_or_token='token',\ 24 | initializer="uniform_unit_scaling",initializer_gain=0.5,\ 25 | model_name="transformer",scope_name="transformer",buffer_size=5000,data_leak_ratio=0.1,\ 26 | input_queue_size=1000,output_queue_size=1000,\ 27 | deep_transformer_init=True,\ 28 | audio_num_mel_bins=40,audio_add_delta_deltas=True,pdp_r=512,\ 29 | sinusoid_posenc=True,max_poslen=20480,ctc_enable=True,ctc_alpha=0.3,audio_dither=0.0,\ 30 | enc_localize="pdp",dec_localize="none",encdec_localize="none",\ 31 | clip_grad_norm=0.0,\ 32 | num_heads=4,\ 33 | process_num=4,\ 34 | lrate=1.0,\ 35 | estop_patience=100,\ 36 | num_encoder_layer=12,\ 37 | num_decoder_layer=6,\ 38 | warmup_steps=4000,\ 39 | lrate_strategy="noam",\ 40 | epoches=5000,\ 41 | update_cycle=25,\ 42 | gpus=[0],\ 43 | disp_freq=1,\ 44 | eval_freq=1000,\ 45 | save_freq=2500,\ 46 | sample_freq=1000,\ 47 | checkpoints=10,\ 48 | best_checkpoints=10,\ 49 | max_training_steps=50000,\ 50 | beta1=0.9,\ 51 | beta2=0.98,\ 52 | random_seed=1234,\ 53 | src_vocab_file="$data/vocab.zero.en",\ 54 | tgt_vocab_file="$data/vocab.zero.de",\ 55 | src_train_path="$data/en-de/data/train/wav/",\ 56 | src_train_file="$data/en-de/data/train/txt/train.yaml",\ 57 | tgt_train_file="$data/train.bpe.de",\ 58 | src_dev_path="$data/en-de/data/dev/wav/",\ 59 | src_dev_file="$data/en-de/data/dev/txt/dev.yaml",\ 60 | tgt_dev_file="$data/dev.bpe.de",\ 61 | src_test_path="$data/en-de/data/tst-COMMON/wav/",\ 62 | src_test_file="$data/en-de/data/tst-COMMON/txt/tst-COMMON.yaml",\ 63 | tgt_test_file="$data/test.bpe.de",\ 64 | output_dir="train",\ 65 | test_output="",\ 66 | 67 | 68 | # depth-scaled initialization 69 | # initializer="uniform_unit_scaling",initializer_gain=0.5,\ 70 | # deep_transformer_init=True,\ 71 | 72 | # 40-dimensional log-mel filterbanks as features, using delta-deltas, R=512 73 | # audio_num_mel_bins=40,audio_add_delta_deltas=True,pdp_r=512,\ 74 | # enc_localize="pdp",dec_localize="none",encdec_localize="none",\ 75 | 76 | # adopt sinusoidal positional encoding, using CTC-regularization with coefficient 0.3 77 | # sinusoid_posenc=True,max_poslen=20480,ctc_enable=True,ctc_alpha=0.3,audio_dither=0.0,\ 78 | 79 | # train_path: the wav path, train_file: the audio yaml file specific to must-c 80 | # src_train_path="$data/en-de/data/train/wav/",\ 81 | # src_train_file="$data/en-de/data/train/txt/train.yaml",\ 82 | # tgt_train_file="$data/train.bpe.de",\ 83 | -------------------------------------------------------------------------------- /func.py: 
-------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import math 8 | import tensorflow as tf 9 | 10 | from utils import util, dtype 11 | 12 | 13 | def linear(x, dim, bias=True, ln=False, 14 | weight_initializer=None, 15 | bias_initializer=tf.zeros_initializer(), 16 | scope=None, custom_getter=None): 17 | """ 18 | basic linear or feed forward layer 19 | :param x: input tensor or list 20 | :param dim: output dimension or list 21 | :param bias: whether use bias term 22 | :param ln: whether use layer normalization 23 | :param weight_initializer: you can set it if you want 24 | :param bias_initializer: you can set it if you want 25 | :param scope 26 | :return: 27 | """ 28 | with tf.variable_scope(scope or "linear", values=[x], 29 | dtype=tf.as_dtype(dtype.floatx()), 30 | custom_getter=custom_getter): 31 | if not isinstance(x, (list, tuple)): 32 | x = [x] 33 | if not isinstance(dim, (list, tuple)): 34 | dim = [dim] 35 | 36 | if not ln: 37 | # by default, we concatenate inputs 38 | x = [tf.concat(x, -1)] 39 | 40 | outputs = [] 41 | for oidx, osize in enumerate(dim): 42 | 43 | results = [] 44 | for iidx, ix in enumerate(x): 45 | x_shp = util.shape_list(ix) 46 | xsize = x_shp[-1] 47 | 48 | W = tf.get_variable("W_{}_{}".format(oidx, iidx), [xsize, osize], initializer=weight_initializer) 49 | o = tf.matmul(tf.reshape(ix, [-1, xsize]), W) 50 | 51 | if ln: 52 | o = layer_norm(o, scope="ln_{}_{}".format(oidx, iidx)) 53 | results.append(o) 54 | 55 | o = tf.add_n(results) 56 | 57 | if bias: 58 | b = tf.get_variable("b_{}".format(oidx), [osize], initializer=bias_initializer) 59 | o = tf.nn.bias_add(o, b) 60 | x_shp = util.shape_list(x[0])[:-1] 61 | o = tf.reshape(o, tf.concat([x_shp, [osize]], 0)) 62 | 63 | outputs.append(o) 64 | 65 | return outputs[0] if len(outputs) == 1 else outputs 66 | 67 | 68 | def split_heads(inputs, num_heads, name=None): 69 | """ Split heads 70 | :param inputs: A tensor with shape [batch, length, channels] 71 | :param num_heads: An integer 72 | :param name: An optional string 73 | :returns: A tensor with shape [batch, heads, length, channels / heads] 74 | """ 75 | 76 | with tf.name_scope(name or "split_heads"): 77 | x = inputs 78 | n = num_heads 79 | old_shape = x.get_shape().dims 80 | 81 | last = old_shape[-1] 82 | new_shape = old_shape[:-1] + [n] + [last // n if last else None] 83 | ret = tf.reshape(x, tf.concat([tf.shape(x)[:-1], [n, -1]], 0)) 84 | ret.set_shape(new_shape) 85 | return tf.transpose(ret, [0, 2, 1, 3]) 86 | 87 | 88 | def combine_heads(inputs, name=None): 89 | """ Combine heads 90 | :param inputs: A tensor with shape [batch, heads, length, channels] 91 | :param name: An optional string 92 | :returns: A tensor with shape [batch, length, heads * channels] 93 | """ 94 | 95 | with tf.name_scope(name or "combine_heads"): 96 | x = inputs 97 | x = tf.transpose(x, [0, 2, 1, 3]) 98 | old_shape = x.get_shape().dims 99 | a, b = old_shape[-2:] 100 | new_shape = old_shape[:-2] + [a * b if a and b else None] 101 | x = tf.reshape(x, tf.concat([tf.shape(x)[:-2], [-1]], 0)) 102 | x.set_shape(new_shape) 103 | 104 | return x 105 | 106 | 107 | def dot_attention(query, memory, mem_mask, hidden_size, 108 | ln=False, num_heads=1, cache=None, dropout=None, 109 | pdp_r=16, out_map=True, scope=None, 110 | decode_step=None, localize=None): 111 | """ 112 | dotted attention model 113 | :param query: [batch_size, 
qey_len, dim] 114 | :param memory: [batch_size, seq_len, mem_dim] or None 115 | :param mem_mask: [batch_size, seq_len] 116 | :param hidden_size: attention space dimension 117 | :param ln: whether use layer normalization 118 | :param num_heads: attention head number 119 | :param dropout: attention dropout, default disable 120 | :param out_map: output additional mapping 121 | :param cache: cache-based decoding 122 | :param pdp_r: maximum position considered for pdp (parameterized distance penalty) 123 | :param decode_step: the time step of current decoding, 0-based 124 | :param localize: localization method for self-attention, including None, log, and pdp 125 | :param scope: 126 | :return: a value matrix, [batch_size, qey_len, mem_dim] 127 | """ 128 | with tf.variable_scope(scope or "dot_attention", reuse=tf.AUTO_REUSE, 129 | dtype=tf.as_dtype(dtype.floatx())): 130 | if memory is None: 131 | # suppose self-attention from queries alone 132 | h = linear(query, hidden_size * 3, ln=ln, scope="qkv_map") 133 | q, k, v = tf.split(h, 3, -1) 134 | 135 | if cache is not None: 136 | k = tf.concat([cache['k'], k], axis=1) 137 | v = tf.concat([cache['v'], v], axis=1) 138 | cache = { 139 | 'k': k, 140 | 'v': v, 141 | } 142 | else: 143 | q = linear(query, hidden_size, ln=ln, scope="q_map") 144 | if cache is not None and ('mk' in cache and 'mv' in cache): 145 | k, v = cache['mk'], cache['mv'] 146 | else: 147 | k = linear(memory, hidden_size, ln=ln, scope="k_map") 148 | v = linear(memory, hidden_size, ln=ln, scope="v_map") 149 | 150 | if cache is not None: 151 | cache['mk'] = k 152 | cache['mv'] = v 153 | 154 | q = split_heads(q, num_heads) 155 | k = split_heads(k, num_heads) 156 | v = split_heads(v, num_heads) 157 | 158 | q *= (hidden_size // num_heads) ** (-0.5) 159 | 160 | q_shp = util.shape_list(q) 161 | k_shp = util.shape_list(k) 162 | 163 | q_len = q_shp[2] if decode_step is None else decode_step + 1 164 | r_lst = None if decode_step is None else 1 165 | 166 | # q * k => attention weights 167 | logits = tf.matmul(q, k, transpose_b=True) 168 | 169 | if mem_mask is not None: 170 | logits += mem_mask 171 | 172 | # consider localization 173 | if localize is not None and localize != "none": 174 | k_len = k_shp[2] 175 | 176 | q_rng = tf.range(q_len) 177 | k_rng = tf.range(k_len) 178 | 179 | # shape: len_Q x len_K 180 | dist = tf.expand_dims(q_rng, 1) - tf.expand_dims(k_rng, 0) 181 | 182 | if localize == "log": 183 | dist = tf.abs(dist) + 1 184 | log_dist = tf.log(tf.to_float(dist)) 185 | if r_lst is not None: 186 | log_dist = log_dist[-r_lst:] 187 | logits -= tf.expand_dims(tf.expand_dims(log_dist, 0), 0) 188 | # implementation for the proposed parameterized penalty distance 189 | elif localize == "pdp": 190 | log_dist = tf.log(dtype.tf_to_float(tf.abs(dist) + 1)) 191 | if r_lst is not None: 192 | log_dist = log_dist[-r_lst:] 193 | 194 | # consider one more position for `zero` 195 | vocab_size = pdp_r + 1 196 | depth = num_heads 197 | 198 | # only consider absolute relative distance 199 | padding = vocab_size - 1 200 | mask = tf.to_int32(tf.less(tf.abs(dist), vocab_size)) 201 | dist = mask * tf.abs(dist) + (1 - mask) * tf.ones_like(dist)*padding 202 | 203 | if r_lst is not None: 204 | dist = dist[-r_lst:] 205 | 206 | pos_embedding = tf.get_variable("embeddings", [vocab_size, depth], initializer=tf.ones_initializer()) 207 | # len_Q x len_K x num_heads 208 | dist_emb = tf.gather(pos_embedding, dist) 209 | dist_emb = tf.transpose(dist_emb, [2, 0, 1]) 210 | logits += tf.expand_dims(dist_emb, 0) * (- 
tf.expand_dims(tf.expand_dims(log_dist, 0), 0)) 211 | else: 212 | raise NotImplementedError("invalid localization function {}".format(localize)) 213 | 214 | weights = tf.nn.softmax(logits) 215 | 216 | dweights = util.valid_apply_dropout(weights, dropout) 217 | 218 | # weights * v => attention vectors 219 | o = tf.matmul(dweights, v) 220 | 221 | o = combine_heads(o) 222 | 223 | if out_map: 224 | o = linear(o, hidden_size, ln=ln, scope="o_map") 225 | 226 | results = { 227 | 'weights': weights, 228 | 'output': o, 229 | 'cache': cache 230 | } 231 | 232 | return results 233 | 234 | 235 | def layer_norm(x, eps=None, scope=None, custom_getter=None): 236 | """Layer normalization layer""" 237 | if eps is None: 238 | eps = dtype.epsilon() 239 | with tf.variable_scope(scope or "layer_norm", 240 | dtype=tf.as_dtype(dtype.floatx()), 241 | custom_getter=custom_getter): 242 | layer_size = util.shape_list(x)[-1] 243 | 244 | scale = tf.get_variable("scale", [layer_size], initializer=tf.ones_initializer()) 245 | offset = tf.get_variable("offset", [layer_size], initializer=tf.zeros_initializer()) 246 | 247 | mean = tf.reduce_mean(x, -1, keep_dims=True) 248 | var = tf.reduce_mean((x - mean) ** 2, -1, keep_dims=True) 249 | 250 | return scale * (x - mean) * tf.rsqrt(var + eps) + offset 251 | 252 | 253 | def rms_norm(x, eps=None, scope=None): 254 | """RMS-based Layer normalization layer""" 255 | if eps is None: 256 | eps = dtype.epsilon() 257 | with tf.variable_scope(scope or "rms_norm", 258 | dtype=tf.as_dtype(dtype.floatx())): 259 | layer_size = util.shape_list(x)[-1] 260 | 261 | scale = tf.get_variable("scale", [layer_size], initializer=tf.ones_initializer()) 262 | 263 | ms = tf.reduce_mean(x ** 2, -1, keep_dims=True) 264 | 265 | return scale * x * tf.rsqrt(ms + eps) 266 | 267 | 268 | def residual_fn(x, y, dropout=None): 269 | """Residual Connection""" 270 | y = util.valid_apply_dropout(y, dropout) 271 | return x + y 272 | 273 | 274 | def ffn_layer(x, d, d_o, dropout=None, scope=None): 275 | """FFN layer in Transformer""" 276 | with tf.variable_scope(scope or "ffn_layer", 277 | dtype=tf.as_dtype(dtype.floatx())): 278 | hidden = linear(x, d, scope="enlarge") 279 | hidden = tf.nn.relu(hidden) 280 | 281 | hidden = util.valid_apply_dropout(hidden, dropout) 282 | 283 | output = linear(hidden, d_o, scope="output") 284 | 285 | return output 286 | 287 | 288 | def add_timing_signal(x, min_timescale=1.0, max_timescale=1.0e4, 289 | time=None, name=None): 290 | """Transformer Positional Embedding""" 291 | 292 | with tf.name_scope(name, default_name="add_timing_signal", values=[x]): 293 | length = tf.shape(x)[1] 294 | channels = tf.shape(x)[2] 295 | if time is None: 296 | position = dtype.tf_to_float(tf.range(length)) 297 | else: 298 | # decoding position embedding 299 | position = tf.expand_dims(time, 0) 300 | num_timescales = channels // 2 301 | 302 | log_timescale_increment = ( 303 | math.log(float(max_timescale) / float(min_timescale)) / 304 | (dtype.tf_to_float(num_timescales) - 1) 305 | ) 306 | inv_timescales = min_timescale * tf.exp( 307 | dtype.tf_to_float(tf.range(num_timescales)) * -log_timescale_increment 308 | ) 309 | 310 | scaled_time = (tf.expand_dims(position, 1) * 311 | tf.expand_dims(inv_timescales, 0)) 312 | signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1) 313 | signal = tf.pad(signal, [[0, 0], [0, tf.mod(channels, 2)]]) 314 | signal = tf.reshape(signal, [1, length, channels]) 315 | 316 | return x + signal 317 | 318 | 319 | def attention_bias(inputs, mode, inf=None, name=None): 320 
| """ A bias tensor used in attention mechanism""" 321 | 322 | if inf is None: 323 | inf = - dtype.inf() 324 | 325 | with tf.name_scope(name, default_name="attention_bias", values=[inputs]): 326 | if mode == "causal": 327 | length = inputs 328 | lower_triangle = tf.matrix_band_part( 329 | tf.ones([length, length]), -1, 0 330 | ) 331 | ret = dtype.tf_to_float(inf * (1.0 - lower_triangle)) 332 | return tf.reshape(ret, [1, 1, length, length]) 333 | elif mode == "masking": 334 | mask = inputs 335 | ret = (1.0 - mask) * inf 336 | return tf.expand_dims(tf.expand_dims(ret, 1), 1) 337 | elif mode == "aan": 338 | length = tf.shape(inputs)[1] 339 | diagonal = tf.eye(length) 340 | cum_factor = tf.expand_dims(tf.cumsum(diagonal, axis=0), 0) 341 | mask = tf.expand_dims(inputs, 1) * tf.expand_dims(inputs, 2) 342 | mask *= dtype.tf_to_float(cum_factor) 343 | weight = tf.nn.softmax(mask + (1.0 - mask) * inf) 344 | weight *= mask 345 | return weight 346 | else: 347 | raise ValueError("Unknown mode %s" % mode) 348 | -------------------------------------------------------------------------------- /lrs/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from lrs import noamlr, epochlr, cosinelr 4 | 5 | 6 | def get_lr(params): 7 | 8 | strategy = params.lrate_strategy.lower() 9 | 10 | if strategy == "noam": 11 | return noamlr.NoamDecayLr( 12 | params.lrate, 13 | params.min_lrate, 14 | params.max_lrate, 15 | params.warmup_steps, 16 | params.hidden_size 17 | ) 18 | elif strategy == "epoch": 19 | return epochlr.EpochDecayLr( 20 | params.lrate, 21 | params.min_lrate, 22 | params.max_lrate, 23 | params.lrate_decay, 24 | ) 25 | elif strategy == "cosine": 26 | return cosinelr.CosineDecayLr( 27 | params.lrate, 28 | params.min_lrate, 29 | params.max_lrate, 30 | params.warmup_steps, 31 | params.lrate_decay, 32 | t_mult=params.cosine_factor, 33 | update_period=params.cosine_period 34 | ) 35 | else: 36 | raise NotImplementedError( 37 | "{} is not supported".format(strategy)) 38 | -------------------------------------------------------------------------------- /lrs/cosinelr.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import math 8 | 9 | from lrs import lr 10 | 11 | 12 | class CosineDecayLr(lr.Lr): 13 | """Decay the learning rate during each training step, follows FairSeq""" 14 | def __init__(self, 15 | init_lr, # initial learning rate => warmup_init_lr 16 | min_lr, # minimum learning rate 17 | max_lr, # maximum learning rate 18 | warmup_steps, # warmup step => warmup_updates 19 | decay, # learning rate shrink factor for annealing 20 | t_mult=1, # factor to grow the length of each period 21 | update_period=5000, # initial number of updates per period 22 | name="cosine_decay_lr" # model name, no use 23 | ): 24 | super(CosineDecayLr, self).__init__(init_lr, min_lr, max_lr, name=name) 25 | 26 | self.warmup_steps = warmup_steps 27 | 28 | self.warmup_init_lr = init_lr 29 | self.warmup_end_lr = max_lr 30 | self.t_mult = t_mult 31 | self.period = update_period 32 | 33 | if self.warmup_steps > 0: 34 | self.lr_step = (self.warmup_end_lr - self.warmup_init_lr) / self.warmup_steps 35 | else: 36 | self.lr_step = 1. 
37 | 38 | self.decay = decay 39 | 40 | # initial learning rate 41 | self.lrate = init_lr 42 | 43 | def step(self, step): 44 | if step < self.warmup_steps: 45 | self.lrate = self.warmup_init_lr + step * self.lr_step 46 | else: 47 | curr_updates = step - self.warmup_steps 48 | if self.t_mult != 1: 49 | i = math.floor(math.log(1 - curr_updates / self.period * (1 - self.t_mult), self.t_mult)) 50 | t_i = self.t_mult ** i * self.period 51 | t_curr = curr_updates - (1 - self.t_mult ** i) / (1 - self.t_mult) * self.period 52 | else: 53 | i = math.floor(curr_updates / self.period) 54 | t_i = self.period 55 | t_curr = curr_updates - (self.period * i) 56 | 57 | lr_shrink = self.decay ** i 58 | min_lr = self.min_lrate * lr_shrink 59 | max_lr = self.max_lrate * lr_shrink 60 | 61 | self.lrate = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / t_i)) 62 | 63 | return self.lrate 64 | -------------------------------------------------------------------------------- /lrs/epochlr.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | 8 | from lrs import lr 9 | 10 | 11 | class EpochDecayLr(lr.Lr): 12 | """Decay the learning rate after each epoch""" 13 | def __init__(self, 14 | init_lr, 15 | min_lr, # minimum learning rate 16 | max_lr, # maximum learning rate 17 | decay=0.5, # learning rate decay rate 18 | name="epoch_decay_lr" 19 | ): 20 | super(EpochDecayLr, self).__init__(init_lr, min_lr, max_lr, name=name) 21 | 22 | self.decay = decay 23 | 24 | def after_epoch(self, eidx=None): 25 | if eidx is None: 26 | self.lrate = self.init_lrate * self.decay 27 | else: 28 | self.lrate = self.init_lrate * self.decay ** int(eidx) 29 | -------------------------------------------------------------------------------- /lrs/lr.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | 8 | # This is an abstract class that deals with 9 | # different learning rate decay strategy 10 | # Generally, we decay the learning rate with GPU computation 11 | # However, in this paper, we simply decay the learning rate 12 | # at CPU level, and feed the decayed lr into GPU for 13 | # optimization 14 | class Lr(object): 15 | def __init__(self, 16 | init_lrate, # initial learning rate 17 | min_lrate, # minimum learning rate 18 | max_lrate, # maximum learning rate 19 | name="lr", # learning rate name, no use 20 | ): 21 | self.name = name 22 | self.init_lrate = init_lrate # just record the init learning rate 23 | self.lrate = init_lrate # active learning rate, change with training 24 | self.min_lrate = min_lrate 25 | self.max_lrate = max_lrate 26 | 27 | assert self.max_lrate > self.min_lrate, "Minimum learning rate " \ 28 | "should less than maximum learning rate" 29 | 30 | # suppose the eidx starts from 1 31 | def before_epoch(self, eidx=None): 32 | pass 33 | 34 | def after_epoch(self, eidx=None): 35 | pass 36 | 37 | def step(self, step): 38 | pass 39 | 40 | def after_eval(self, eval_score): 41 | pass 42 | 43 | def get_lr(self): 44 | """Return the learning rate whenever you want""" 45 | return max(min(self.lrate, self.max_lrate), self.min_lrate) 46 | -------------------------------------------------------------------------------- /lrs/noamlr.py: 
-------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import numpy as np 8 | 9 | from lrs import lr 10 | 11 | 12 | class NoamDecayLr(lr.Lr): 13 | """Decay the learning rate during each training step, follows Transformer""" 14 | def __init__(self, 15 | init_lr, # initial learning rate 16 | min_lr, # minimum learning rate 17 | max_lr, # maximum learning rate 18 | warmup_steps, # warmup step 19 | hidden_size, # model hidden size 20 | name="noam_decay_lr" # model name, no use 21 | ): 22 | super(NoamDecayLr, self).__init__(init_lr, min_lr, max_lr, name=name) 23 | 24 | self.warmup_steps = warmup_steps 25 | self.hidden_size = hidden_size 26 | 27 | def step(self, step): 28 | step = float(step) 29 | warmup_steps = float(self.warmup_steps) 30 | 31 | multiplier = float(self.hidden_size) ** -0.5 32 | decay = multiplier * np.minimum((step + 1) * (warmup_steps ** -1.5), 33 | (step + 1) ** -0.5) 34 | self.lrate = self.init_lrate * decay 35 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import os 8 | import time 9 | import copy 10 | import numpy as np 11 | import tensorflow as tf 12 | 13 | import evalu 14 | import lrs 15 | from data import Dataset 16 | from models import model 17 | from search import beam_search 18 | from utils import parallel, cycle, util, queuer, saver, dtype 19 | from modules import initializer 20 | 21 | 22 | def tower_train_graph(train_features, optimizer, graph, params): 23 | # define multi-gpu training graph 24 | def _tower_train_graph(features): 25 | train_output = graph.train_fn( 26 | features, params, initializer=initializer.get_initializer(params.initializer, params.initializer_gain)) 27 | 28 | tower_gradients = optimizer.compute_gradients( 29 | train_output["loss"] * tf.cast(params.loss_scale, tf.float32), 30 | colocate_gradients_with_ops=True) 31 | tower_gradients = [(g / tf.cast(params.loss_scale, tf.float32), v) for g, v in tower_gradients] 32 | 33 | return { 34 | "loss": train_output["loss"], 35 | "gradient": tower_gradients 36 | } 37 | 38 | # feed model to multiple gpus 39 | tower_outputs = parallel.parallel_model( 40 | _tower_train_graph, train_features, 41 | params.gpus, use_cpu=(len(params.gpus) == 0)) 42 | 43 | loss = tf.add_n(tower_outputs['loss']) / len(tower_outputs['loss']) 44 | gradients = parallel.average_gradients(tower_outputs['gradient']) 45 | 46 | return loss, gradients 47 | 48 | 49 | def tower_infer_graph(eval_features, graph, params): 50 | # define multi-gpu inferring graph 51 | def _tower_infer_graph(features): 52 | encoding_fn, decoding_fn = graph.infer_fn(params) 53 | beam_output = beam_search(features, encoding_fn, decoding_fn, params) 54 | 55 | return beam_output 56 | 57 | # feed model to multiple gpus 58 | eval_outputs = parallel.parallel_model( 59 | _tower_infer_graph, eval_features, 60 | params.gpus, use_cpu=(len(params.gpus) == 0)) 61 | eval_seqs, eval_scores = eval_outputs['seq'], eval_outputs['score'] 62 | 63 | return eval_seqs, eval_scores 64 | 65 | 66 | def tower_score_graph(eval_features, graph, params): 67 | # define multi-gpu inferring graph 68 | def _tower_infer_graph(features): 69 | 
scores = graph.score_fn(features, params) 70 | return scores 71 | 72 | # feed model to multiple gpus 73 | eval_outputs = parallel.parallel_model( 74 | _tower_infer_graph, eval_features, 75 | params.gpus, use_cpu=(len(params.gpus) == 0)) 76 | eval_scores = eval_outputs['score'] 77 | 78 | return eval_scores 79 | 80 | 81 | def train(params): 82 | # status measure 83 | if params.recorder.estop or \ 84 | params.recorder.epoch > params.epoches or \ 85 | params.recorder.step > params.max_training_steps: 86 | tf.logging.info("Stop condition reached, you have finished training your model.") 87 | return 0. 88 | 89 | # loading dataset 90 | tf.logging.info("Begin Loading Training and Dev Dataset") 91 | start_time = time.time() 92 | train_dataset = Dataset(params, params.src_train_file, params.tgt_train_file, 93 | params.src_vocab, params.tgt_vocab, 94 | ctc_file=params.ctc_train_file, 95 | batch_or_token=params.batch_or_token, 96 | data_leak_ratio=params.data_leak_ratio, 97 | src_audio_path=params.src_train_path) 98 | dev_dataset = Dataset(params, params.src_dev_file, params.tgt_dev_file, 99 | params.src_vocab, params.src_vocab, 100 | batch_or_token='batch', 101 | data_leak_ratio=params.data_leak_ratio, 102 | src_audio_path=params.src_dev_path) 103 | tf.logging.info( 104 | "End Loading dataset, within {} seconds".format(time.time() - start_time)) 105 | 106 | # Build Graph 107 | with tf.Graph().as_default(): 108 | lr = tf.placeholder(tf.as_dtype(dtype.floatx()), [], "learn_rate") 109 | 110 | # shift automatically sliced multi-gpu process into `zero` manner :) 111 | features = [] 112 | for fidx in range(max(len(params.gpus), 1)): 113 | feature = { 114 | "source": tf.placeholder(tf.float32, [None, None], "source"), 115 | "target": tf.placeholder(tf.int32, [None, None], "target"), 116 | "label": tf.sparse_placeholder(tf.int32, name="label"), 117 | } 118 | features.append(feature) 119 | 120 | # session info 121 | sess = util.get_session(params.gpus) 122 | 123 | tf.logging.info("Begining Building Training Graph") 124 | start_time = time.time() 125 | 126 | # create global step 127 | global_step = tf.train.get_or_create_global_step() 128 | 129 | # set up optimizer 130 | optimizer = tf.train.AdamOptimizer(lr, 131 | beta1=params.beta1, 132 | beta2=params.beta2, 133 | epsilon=params.epsilon) 134 | 135 | # get graph 136 | graph = model.get_model(params.model_name) 137 | 138 | # set up training graph 139 | loss, gradients = tower_train_graph(features, optimizer, graph, params) 140 | 141 | # apply pseudo cyclic parallel operation 142 | vle, ops = cycle.create_train_op({"loss": loss}, gradients, 143 | optimizer, global_step, params) 144 | 145 | tf.logging.info("End Building Training Graph, within {} seconds".format(time.time() - start_time)) 146 | 147 | tf.logging.info("Begin Building Inferring Graph") 148 | start_time = time.time() 149 | 150 | # set up infer graph 151 | eval_seqs, eval_scores = tower_infer_graph(features, graph, params) 152 | 153 | tf.logging.info("End Building Inferring Graph, within {} seconds".format(time.time() - start_time)) 154 | 155 | # initialize the model 156 | sess.run(tf.global_variables_initializer()) 157 | 158 | # log parameters 159 | util.variable_printer() 160 | 161 | # create saver 162 | train_saver = saver.Saver( 163 | checkpoints=params.checkpoints, 164 | output_dir=params.output_dir, 165 | best_checkpoints=params.best_checkpoints, 166 | ) 167 | 168 | tf.logging.info("Training") 169 | cycle_counter = 0 170 | data_on_gpu = [] 171 | cum_tokens = [] 172 | cum_frames = [] 173 | 174 | 
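    # parameter restoration: optionally warm-start from an ASR-pretrained checkpoint (params.asr_pretrain) before restoring any existing ST checkpoint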
# restore parameters 175 | tf.logging.info("Trying restore ASR existing parameters") 176 | train_saver.restore( 177 | sess, path=params.asr_pretrain, filter_variables=params.filter_variables) 178 | 179 | tf.logging.info("Trying restore existing parameters") 180 | train_saver.restore(sess) 181 | 182 | # setup learning rate 183 | params.lrate = params.recorder.lrate 184 | adapt_lr = lrs.get_lr(params) 185 | 186 | start_time = time.time() 187 | start_epoch = params.recorder.epoch 188 | for epoch in range(start_epoch, params.epoches + 1): 189 | 190 | params.recorder.epoch = epoch 191 | 192 | tf.logging.info("Training the model for epoch {}".format(epoch)) 193 | size = params.batch_size if params.batch_or_token == 'batch' \ 194 | else params.token_size 195 | 196 | train_queue = queuer.EnQueuer( 197 | train_dataset.batcher(size, 198 | buffer_size=params.buffer_size, 199 | shuffle=params.shuffle_batch, 200 | train=True), 201 | train_dataset.processor, 202 | worker_processes_num=params.process_num, 203 | input_queue_size=params.input_queue_size, 204 | output_queue_size=params.output_queue_size, 205 | ) 206 | 207 | adapt_lr.before_epoch(eidx=epoch) 208 | 209 | for lidx, data in enumerate(train_queue): 210 | 211 | if params.train_continue: 212 | if lidx <= params.recorder.lidx: 213 | segments = params.recorder.lidx // 5 214 | if params.recorder.lidx < 5 or lidx % segments == 0: 215 | tf.logging.info( 216 | "{} Passing {}-th index according to record" 217 | "".format(util.time_str(time.time()), lidx)) 218 | 219 | continue 220 | 221 | params.recorder.lidx = lidx 222 | 223 | data_on_gpu.append(data) 224 | # use multiple gpus, and data samples is not enough 225 | # make sure the data is fully added 226 | # The actual batch size: batch_size * num_gpus * update_cycle 227 | if len(params.gpus) > 0 and len(data_on_gpu) < len(params.gpus): 228 | continue 229 | 230 | # increase the counter by 1 231 | cycle_counter += 1 232 | 233 | if cycle_counter == 1: 234 | # calculate adaptive learning rate 235 | adapt_lr.step(params.recorder.step) 236 | 237 | # clear internal states 238 | sess.run(ops["zero_op"]) 239 | 240 | # data feeding to gpu placeholders 241 | feed_dicts = {} 242 | for fidx, shard_data in enumerate(data_on_gpu): 243 | # define feed_dict 244 | feed_dict = { 245 | features[fidx]["source"]: shard_data["src"], 246 | features[fidx]["target"]: shard_data["tgt"], 247 | features[fidx]["label"]: shard_data["spar"], 248 | lr: adapt_lr.get_lr(), 249 | } 250 | feed_dicts.update(feed_dict) 251 | 252 | # collect target tokens 253 | cum_tokens.append(np.sum(shard_data['tgt'] > 0)) 254 | cum_frames.append(sum(shard_data['frames'])) 255 | 256 | # reset data points on gpus 257 | data_on_gpu = [] 258 | 259 | # internal accumulative gradient collection 260 | if cycle_counter < params.update_cycle: 261 | sess.run(ops["collect_op"], feed_dict=feed_dicts) 262 | 263 | # at the final step, update model parameters 264 | if cycle_counter == params.update_cycle: 265 | cycle_counter = 0 266 | 267 | # directly update parameters, usually this works well 268 | if not params.safe_nan: 269 | _, loss, gnorm, pnorm, gstep = sess.run( 270 | [ops["train_op"], vle["loss"], vle["gradient_norm"], vle["parameter_norm"], 271 | global_step], feed_dict=feed_dicts) 272 | 273 | if np.isnan(loss) or np.isinf(loss) or np.isnan(gnorm) or np.isinf(gnorm): 274 | tf.logging.error("Nan or Inf raised! 
Loss {} GNorm {}.".format(loss, gnorm)) 275 | params.recorder.estop = True 276 | break 277 | else: 278 | # Note, applying safe nan can help train the big model, but sacrifice speed 279 | loss, gnorm, pnorm, gstep = sess.run( 280 | [vle["loss"], vle["gradient_norm"], vle["parameter_norm"], global_step], 281 | feed_dict=feed_dicts) 282 | 283 | if np.isnan(loss) or np.isinf(loss) or np.isnan(gnorm) or np.isinf(gnorm) \ 284 | or gnorm > params.gnorm_upper_bound: 285 | tf.logging.error( 286 | "Nan or Inf raised, GStep {} is passed! Loss {} GNorm {}.".format(gstep, loss, gnorm)) 287 | continue 288 | 289 | sess.run(ops["train_op"], feed_dict=feed_dicts) 290 | 291 | if gstep % params.disp_freq == 0: 292 | end_time = time.time() 293 | tf.logging.info( 294 | "{} Epoch {}, GStep {}~{}, LStep {}~{}, " 295 | "Loss {:.3f}, GNorm {:.3f}, PNorm {:.3f}, Lr {:.5f}, " 296 | "Src {}, Tgt {}, Tokens {}, Frames {}, UD {:.3f} s".format( 297 | util.time_str(end_time), epoch, 298 | gstep - params.disp_freq + 1, gstep, 299 | lidx - params.disp_freq + 1, lidx, 300 | loss, gnorm, pnorm, 301 | adapt_lr.get_lr(), data['src'].shape, data['tgt'].shape, 302 | np.sum(cum_tokens), np.sum(cum_frames), end_time - start_time) 303 | ) 304 | start_time = time.time() 305 | cum_tokens = [] 306 | cum_frames = [] 307 | 308 | # trigger model saver 309 | if gstep > 0 and gstep % params.save_freq == 0: 310 | train_saver.save(sess, gstep) 311 | params.recorder.save_to_json(os.path.join(params.output_dir, "record.json")) 312 | 313 | # trigger model evaluation 314 | if gstep > 0 and gstep % params.eval_freq == 0: 315 | if params.ema_decay > 0.: 316 | sess.run(ops['ema_backup_op']) 317 | sess.run(ops['ema_assign_op']) 318 | 319 | tf.logging.info("Start Evaluating") 320 | eval_start_time = time.time() 321 | tranes, scores, indices = evalu.decoding( 322 | sess, features, eval_seqs, 323 | eval_scores, dev_dataset, params) 324 | bleu = evalu.eval_metric(tranes, params.tgt_dev_file, indices=indices) 325 | eval_end_time = time.time() 326 | tf.logging.info("End Evaluating") 327 | 328 | if params.ema_decay > 0.: 329 | sess.run(ops['ema_restore_op']) 330 | 331 | tf.logging.info( 332 | "{} GStep {}, Scores {}, BLEU {}, Duration {:.3f} s".format( 333 | util.time_str(eval_end_time), gstep, np.mean(scores), 334 | bleu, eval_end_time - eval_start_time) 335 | ) 336 | 337 | # save eval translation 338 | evalu.dump_tanslation( 339 | tranes, 340 | os.path.join(params.output_dir, 341 | "eval-{}.trans.txt".format(gstep)), 342 | indices=indices) 343 | 344 | # save parameters 345 | train_saver.save(sess, gstep, bleu) 346 | 347 | # check for early stopping 348 | valid_scores = [v[1] for v in params.recorder.valid_script_scores] 349 | if len(valid_scores) == 0 or bleu > np.max(valid_scores): 350 | params.recorder.bad_counter = 0 351 | else: 352 | params.recorder.bad_counter += 1 353 | 354 | if params.recorder.bad_counter > params.estop_patience: 355 | params.recorder.estop = True 356 | break 357 | 358 | params.recorder.history_scores.append( 359 | (int(gstep), float(np.mean(scores))) 360 | ) 361 | params.recorder.valid_script_scores.append( 362 | (int(gstep), float(bleu)) 363 | ) 364 | params.recorder.save_to_json( 365 | os.path.join(params.output_dir, "record.json")) 366 | 367 | # handle the learning rate decay in a typical manner 368 | adapt_lr.after_eval(float(bleu)) 369 | 370 | # trigger temporary sampling 371 | if gstep > 0 and gstep % params.sample_freq == 0: 372 | tf.logging.info("Start Sampling") 373 | decode_seqs, decode_scores = sess.run( 374 | 
[eval_seqs[:1], eval_scores[:1]], feed_dict={features[0]["source"]: data["src"][:5]}) 375 | tranes, scores = evalu.decode_hypothesis(decode_seqs, decode_scores, params) 376 | 377 | for sidx in range(min(5, len(scores))): 378 | sample_target = evalu.decode_target_token(data['tgt'][sidx], params.tgt_vocab) 379 | tf.logging.info("{}-th Target: {}".format(sidx, ' '.join(sample_target))) 380 | sample_trans = tranes[sidx] 381 | tf.logging.info("{}-th Translation: {}".format(sidx, ' '.join(sample_trans))) 382 | 383 | tf.logging.info("End Sampling") 384 | 385 | # trigger stopping 386 | if gstep >= params.max_training_steps: 387 | # stop running by setting EStop signal 388 | params.recorder.estop = True 389 | break 390 | 391 | # should be equal to global_step 392 | params.recorder.step = int(gstep) 393 | 394 | if params.recorder.estop: 395 | tf.logging.info("Early Stopped!") 396 | break 397 | 398 | # reset to 0 399 | params.recorder.lidx = -1 400 | 401 | adapt_lr.after_epoch(eidx=epoch) 402 | 403 | # Final Evaluation 404 | tf.logging.info("Start Final Evaluating") 405 | if params.ema_decay > 0.: 406 | sess.run(ops['ema_backup_op']) 407 | sess.run(ops['ema_assign_op']) 408 | 409 | gstep = int(params.recorder.step + 1) 410 | eval_start_time = time.time() 411 | tranes, scores, indices = evalu.decoding(sess, features, eval_seqs, eval_scores, dev_dataset, params) 412 | bleu = evalu.eval_metric(tranes, params.tgt_dev_file, indices=indices) 413 | eval_end_time = time.time() 414 | tf.logging.info("End Evaluating") 415 | 416 | if params.ema_decay > 0.: 417 | sess.run(ops['ema_restore_op']) 418 | 419 | tf.logging.info( 420 | "{} GStep {}, Scores {}, BLEU {}, Duration {:.3f} s".format( 421 | util.time_str(eval_end_time), gstep, np.mean(scores), bleu, eval_end_time - eval_start_time) 422 | ) 423 | 424 | # save eval translation 425 | evalu.dump_tanslation( 426 | tranes, 427 | os.path.join(params.output_dir, 428 | "eval-{}.trans.txt".format(gstep)), 429 | indices=indices) 430 | 431 | tf.logging.info("Your training is finished :)") 432 | 433 | return train_saver.best_score 434 | 435 | 436 | def evaluate(params): 437 | # loading dataset 438 | tf.logging.info("Begin Loading Test Dataset") 439 | start_time = time.time() 440 | test_dataset = Dataset(params, params.src_test_file, params.tgt_test_file, 441 | params.src_vocab, params.src_vocab, 442 | batch_or_token='batch', 443 | data_leak_ratio=params.data_leak_ratio, 444 | src_audio_path=params.src_test_path) 445 | tf.logging.info( 446 | "End Loading dataset, within {} seconds".format(time.time() - start_time)) 447 | 448 | # Build Graph 449 | with tf.Graph().as_default(): 450 | features = [] 451 | for fidx in range(max(len(params.gpus), 1)): 452 | feature = { 453 | "source": tf.placeholder(tf.float32, [None, None], "source"), 454 | } 455 | features.append(feature) 456 | 457 | # session info 458 | sess = util.get_session(params.gpus) 459 | 460 | tf.logging.info("Begining Building Evaluation Graph") 461 | start_time = time.time() 462 | 463 | # get graph 464 | graph = model.get_model(params.model_name) 465 | 466 | # set up infer graph 467 | eval_seqs, eval_scores = tower_infer_graph(features, graph, params) 468 | 469 | tf.logging.info("End Building Inferring Graph, within {} seconds".format(time.time() - start_time)) 470 | 471 | # set up ema 472 | if params.ema_decay > 0.: 473 | # recover from EMA 474 | ema = tf.train.ExponentialMovingAverage(decay=params.ema_decay) 475 | ema.apply(tf.trainable_variables()) 476 | ema_assign_op = tf.group(*(tf.assign(var, 
ema.average(var).read_value()) 477 | for var in tf.trainable_variables())) 478 | else: 479 | ema_assign_op = tf.no_op() 480 | 481 | # initialize the model 482 | sess.run(tf.global_variables_initializer()) 483 | 484 | # log parameters 485 | util.variable_printer() 486 | 487 | # create saver 488 | eval_saver = saver.Saver(checkpoints=params.checkpoints, output_dir=params.output_dir) 489 | 490 | # restore parameters 491 | tf.logging.info("Trying restore existing parameters") 492 | eval_saver.restore(sess, params.output_dir) 493 | sess.run(ema_assign_op) 494 | 495 | tf.logging.info("Starting Evaluating") 496 | eval_start_time = time.time() 497 | tranes, scores, indices = evalu.decoding(sess, features, eval_seqs, eval_scores, test_dataset, params) 498 | bleu = evalu.eval_metric(tranes, params.tgt_test_file, indices=indices) 499 | eval_end_time = time.time() 500 | 501 | tf.logging.info( 502 | "{} Scores {}, BLEU {}, Duration {}s".format( 503 | util.time_str(eval_end_time), np.mean(scores), bleu, eval_end_time - eval_start_time) 504 | ) 505 | 506 | # save translation 507 | evalu.dump_tanslation(tranes, params.test_output, indices=indices) 508 | 509 | return bleu 510 | 511 | 512 | def scorer(params): 513 | # loading dataset 514 | tf.logging.info("Begin Loading Test Dataset") 515 | start_time = time.time() 516 | test_dataset = Dataset(params, params.src_test_file, params.tgt_test_file, 517 | params.src_vocab, params.tgt_vocab, 518 | batch_or_token='batch', 519 | data_leak_ratio=params.data_leak_ratio, 520 | src_audio_path=params.src_test_path) 521 | tf.logging.info( 522 | "End Loading dataset, within {} seconds".format(time.time() - start_time)) 523 | 524 | # Build Graph 525 | with tf.Graph().as_default(): 526 | features = [] 527 | for fidx in range(max(len(params.gpus), 1)): 528 | feature = { 529 | "source": tf.placeholder(tf.float32, [None, None], "source"), 530 | "target": tf.placeholder(tf.int32, [None, None], "target"), 531 | } 532 | features.append(feature) 533 | 534 | # session info 535 | sess = util.get_session(params.gpus) 536 | 537 | tf.logging.info("Begining Building Evaluation Graph") 538 | start_time = time.time() 539 | 540 | # get graph 541 | graph = model.get_model(params.model_name) 542 | 543 | # set up infer graph 544 | eval_scores = tower_score_graph(features, graph, params) 545 | 546 | tf.logging.info("End Building Inferring Graph, within {} seconds".format(time.time() - start_time)) 547 | 548 | # set up ema 549 | if params.ema_decay > 0.: 550 | # recover from EMA 551 | ema = tf.train.ExponentialMovingAverage(decay=params.ema_decay) 552 | ema.apply(tf.trainable_variables()) 553 | ema_assign_op = tf.group(*(tf.assign(var, ema.average(var).read_value()) 554 | for var in tf.trainable_variables())) 555 | else: 556 | ema_assign_op = tf.no_op() 557 | 558 | # initialize the model 559 | sess.run(tf.global_variables_initializer()) 560 | 561 | # log parameters 562 | util.variable_printer() 563 | 564 | # create saver 565 | eval_saver = saver.Saver(checkpoints=params.checkpoints, output_dir=params.output_dir) 566 | 567 | # restore parameters 568 | tf.logging.info("Trying restore existing parameters") 569 | eval_saver.restore(sess, params.output_dir) 570 | sess.run(ema_assign_op) 571 | 572 | tf.logging.info("Starting Evaluating") 573 | eval_start_time = time.time() 574 | scores, ppl = evalu.scoring(sess, features, eval_scores, test_dataset, params) 575 | eval_end_time = time.time() 576 | 577 | tf.logging.info( 578 | "{} Scores {}, PPL {}, Duration {}s".format( 579 | 
util.time_str(eval_end_time), np.mean(scores), ppl, eval_end_time - eval_start_time) 580 | ) 581 | 582 | # save translation 583 | evalu.dump_tanslation(scores, params.test_output) 584 | 585 | return np.mean(scores) 586 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import tensorflow as tf 8 | from collections import namedtuple 9 | 10 | # global models defined in Zero 11 | _total_models = {} 12 | 13 | 14 | class ModelWrapper(namedtuple("ModelTupleWrapper", 15 | ("train_fn", "score_fn", "infer_fn"))): 16 | pass 17 | 18 | 19 | # you need register your model by your self 20 | def model_register(model_name, train_fn, score_fn, infer_fn): 21 | model_name = model_name.lower() 22 | 23 | if model_name in _total_models: 24 | raise Exception("Conflict Model Name: {}".format(model_name)) 25 | 26 | tf.logging.info("Registering model: {}".format(model_name)) 27 | 28 | _total_models[model_name] = ModelWrapper( 29 | train_fn=train_fn, 30 | score_fn=score_fn, 31 | infer_fn=infer_fn, 32 | ) 33 | 34 | 35 | def get_model(model_name): 36 | model_name = model_name.lower() 37 | 38 | if model_name in _total_models: 39 | return _total_models[model_name] 40 | 41 | raise Exception("No supported model {}".format(model_name)) 42 | -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import copy 8 | import tensorflow as tf 9 | 10 | import func 11 | from models import model 12 | from modules import speech 13 | from utils import util, dtype 14 | 15 | 16 | def stacking(inputs, scale=3, mask=None): 17 | """ 18 | Reshapes the given outputs, i.e. reduces the 19 | time resolution by 3. 20 | Similar to "Listen Attend Spell". 21 | https://arxiv.org/pdf/1508.01211.pdf 22 | """ 23 | # [batch_size, max_time, num_units] 24 | batch_size, max_time, num_units = util.shape_list(inputs) 25 | 26 | if mask is not None: 27 | inputs *= tf.expand_dims(mask, -1) 28 | 29 | num_pad = tf.cast(tf.ceil(tf.divide(max_time, scale)) * scale, tf.int32) - max_time 30 | 31 | pads = [[0, 0], [0, num_pad], [0, 0]] 32 | inputs = tf.pad(inputs, pads) 33 | 34 | if mask is not None: 35 | pads = [[0, 0], [0, num_pad]] 36 | mask = tf.pad(mask, pads) 37 | 38 | concat_inputs = tf.reshape(inputs, (batch_size, -1, num_units * scale)) 39 | if mask is not None: 40 | concat_mask = tf.reshape(mask, (batch_size, -1, scale)) 41 | concat_mask = 1. 
- tf.to_float(tf.less(tf.reduce_sum(concat_mask, -1), scale)) 42 | 43 | return concat_inputs, concat_mask 44 | else: 45 | return concat_inputs 46 | 47 | 48 | def encoder(source, params): 49 | hidden_size = params.hidden_size 50 | 51 | # extract logmel features 52 | source, mask, wavframes = speech.extract_logmel_features(source, params) 53 | target = source 54 | 55 | if params.use_nafm: 56 | x = wavframes 57 | with tf.variable_scope("feed_forward"): 58 | y = func.ffn_layer( 59 | x, 60 | params.filter_size, 61 | util.shape_list(wavframes)[-1], 62 | dropout=params.relu_dropout, 63 | ) 64 | 65 | x = func.residual_fn(x, y, dropout=params.residual_dropout) 66 | x = func.layer_norm(x) 67 | 68 | with tf.variable_scope("feed_forward2"): 69 | y = func.ffn_layer( 70 | x, 71 | params.filter_size, 72 | util.shape_list(wavframes)[-1], 73 | dropout=params.relu_dropout, 74 | ) 75 | 76 | x = func.residual_fn(x, y, dropout=params.residual_dropout) 77 | x = func.layer_norm(x) 78 | 79 | source = func.linear(x, util.shape_list(target)[-1], scope="pretrain") 80 | _source, _mask = source, mask 81 | 82 | # tried different settings for scale, turns out 3 is good 83 | source, mask = stacking(source, scale=3, mask=mask) 84 | if not params.sinusoid_posenc: 85 | source = source[:, :params.max_poslen] 86 | mask = mask[:, :params.max_poslen] 87 | source, mask = dtype.tf_to_float(source), dtype.tf_to_float(mask) 88 | 89 | # map from raw feature space to Transformer dimension 90 | inputs = func.linear(source, params.embed_size, scope="emb_mapper") 91 | 92 | # transformer is sensitive to the position encoding, 93 | if params.sinusoid_posenc: 94 | inputs = func.add_timing_signal(inputs) 95 | else: 96 | pos_emb = tf.get_variable("pos_embedding", [params.max_poslen, params.embed_size]) 97 | 98 | ishp = util.shape_list(inputs) 99 | inputs += tf.expand_dims(pos_emb[:ishp[1]], 0) 100 | 101 | # this normalization layer deeply stabilize the gradient and optimization issue 102 | inputs = func.layer_norm(inputs) 103 | inputs = util.valid_apply_dropout(inputs, params.dropout) 104 | 105 | with tf.variable_scope("encoder"): 106 | x = inputs 107 | for layer in range(params.num_encoder_layer): 108 | if params.deep_transformer_init: 109 | layer_initializer = tf.variance_scaling_initializer( 110 | params.initializer_gain * (layer + 1) ** -0.5, 111 | mode="fan_avg", 112 | distribution="uniform") 113 | else: 114 | layer_initializer = None 115 | with tf.variable_scope("layer_{}".format(layer), initializer=layer_initializer): 116 | with tf.variable_scope("self_attention"): 117 | # suggest: encoder_localize-> pdp, decoder->none 118 | y = func.dot_attention( 119 | x, 120 | None, 121 | func.attention_bias(mask, "masking"), 122 | hidden_size, 123 | num_heads=params.num_heads, 124 | dropout=params.attention_dropout, 125 | localize=params.enc_localize, 126 | pdp_r=params.pdp_r, 127 | ) 128 | 129 | y = y['output'] 130 | x = func.residual_fn(x, y, dropout=params.residual_dropout) 131 | x = func.layer_norm(x) 132 | 133 | with tf.variable_scope("feed_forward"): 134 | y = func.ffn_layer( 135 | x, 136 | params.filter_size, 137 | hidden_size, 138 | dropout=params.relu_dropout, 139 | ) 140 | 141 | x = func.residual_fn(x, y, dropout=params.residual_dropout) 142 | x = func.layer_norm(x) 143 | 144 | source_encodes = x 145 | x_shp = util.shape_list(x) 146 | 147 | states = { 148 | "encodes": source_encodes, 149 | "decoder_initializer": { 150 | "layer_{}".format(l): { 151 | "k": dtype.tf_to_float(tf.zeros([x_shp[0], 0, hidden_size])), 152 | "v": 
dtype.tf_to_float(tf.zeros([x_shp[0], 0, hidden_size])), 153 | } 154 | for l in range(params.num_decoder_layer) 155 | }, 156 | "mask": mask 157 | } 158 | 159 | if params.use_nafm: 160 | states['_target'] = target 161 | states['_source'] = _source 162 | states['_mask'] = _mask 163 | return states 164 | 165 | 166 | def decoder(target, state, params, labels=None): 167 | mask = dtype.tf_to_float(tf.cast(target, tf.bool)) 168 | hidden_size = params.hidden_size 169 | initializer = tf.random_normal_initializer(0.0, hidden_size ** -0.5) 170 | 171 | is_training = ('decoder' not in state) 172 | 173 | if is_training: 174 | target, mask = util.remove_invalid_seq(target, mask) 175 | 176 | embed_name = "embedding" if params.shared_source_target_embedding \ 177 | else "tgt_embedding" 178 | tgt_emb = tf.get_variable(embed_name, 179 | [params.tgt_vocab.size(), params.embed_size], 180 | initializer=initializer) 181 | tgt_bias = tf.get_variable("bias", [params.embed_size]) 182 | 183 | inputs = tf.gather(tgt_emb, target) * (hidden_size ** 0.5) 184 | inputs = tf.nn.bias_add(inputs, tgt_bias) 185 | 186 | # shift 187 | if is_training: 188 | inputs = tf.pad(inputs, [[0, 0], [1, 0], [0, 0]]) 189 | inputs = inputs[:, :-1, :] 190 | inputs = func.add_timing_signal(inputs) 191 | else: 192 | inputs = tf.cond(tf.reduce_all(tf.equal(target, params.tgt_vocab.pad())), 193 | lambda: tf.zeros_like(inputs), 194 | lambda: inputs) 195 | mask = tf.ones_like(mask) 196 | inputs = func.add_timing_signal(inputs, time=dtype.tf_to_float(state['time'])) 197 | 198 | inputs = util.valid_apply_dropout(inputs, params.dropout) 199 | 200 | with tf.variable_scope("decoder"): 201 | x = inputs 202 | for layer in range(params.num_decoder_layer): 203 | if params.deep_transformer_init: 204 | layer_initializer = tf.variance_scaling_initializer( 205 | params.initializer_gain * (layer + 1) ** -0.5, 206 | mode="fan_avg", 207 | distribution="uniform") 208 | else: 209 | layer_initializer = None 210 | with tf.variable_scope("layer_{}".format(layer), initializer=layer_initializer): 211 | with tf.variable_scope("self_attention"): 212 | y = func.dot_attention( 213 | x, 214 | None, 215 | func.attention_bias(tf.shape(mask)[1], "causal"), 216 | hidden_size, 217 | num_heads=params.num_heads, 218 | dropout=params.attention_dropout, 219 | cache=None if is_training else state['decoder']['state']['layer_{}'.format(layer)], 220 | localize=params.dec_localize, 221 | pdp_r=params.pdp_r, 222 | decode_step=None if is_training else state['time'], 223 | ) 224 | if not is_training: 225 | # k, v 226 | state['decoder']['state']['layer_{}'.format(layer)] \ 227 | .update(y['cache']) 228 | 229 | y = y['output'] 230 | x = func.residual_fn(x, y, dropout=params.residual_dropout) 231 | x = func.layer_norm(x) 232 | 233 | with tf.variable_scope("cross_attention"): 234 | y = func.dot_attention( 235 | x, 236 | state['encodes'], 237 | func.attention_bias(state['mask'], "masking"), 238 | hidden_size, 239 | num_heads=params.num_heads, 240 | dropout=params.attention_dropout, 241 | cache=None if is_training else state['decoder']['state']['layer_{}'.format(layer)], 242 | localize=params.encdec_localize, 243 | pdp_r=params.pdp_r, 244 | ) 245 | if not is_training: 246 | # mk, mv 247 | state['decoder']['state']['layer_{}'.format(layer)] \ 248 | .update(y['cache']) 249 | 250 | y = y['output'] 251 | x = func.residual_fn(x, y, dropout=params.residual_dropout) 252 | x = func.layer_norm(x) 253 | 254 | with tf.variable_scope("feed_forward"): 255 | y = func.ffn_layer( 256 | x, 257 | 
params.filter_size, 258 | hidden_size, 259 | dropout=params.relu_dropout, 260 | ) 261 | 262 | x = func.residual_fn(x, y, dropout=params.residual_dropout) 263 | x = func.layer_norm(x) 264 | feature = x 265 | if 'dev_decode' in state: 266 | feature = x[:, -1, :] 267 | 268 | embed_name = "tgt_embedding" if params.shared_target_softmax_embedding \ 269 | else "softmax_embedding" 270 | embed_name = "embedding" if params.shared_source_target_embedding \ 271 | else embed_name 272 | softmax_emb = tf.get_variable(embed_name, 273 | [params.tgt_vocab.size(), params.embed_size], 274 | initializer=initializer) 275 | feature = tf.reshape(feature, [-1, params.embed_size]) 276 | logits = tf.matmul(feature, softmax_emb, False, True) 277 | 278 | logits = tf.cast(logits, tf.float32) 279 | 280 | soft_label, normalizer = util.label_smooth( 281 | target, 282 | util.shape_list(logits)[-1], 283 | factor=params.label_smooth) 284 | centropy = tf.nn.softmax_cross_entropy_with_logits_v2( 285 | logits=logits, 286 | labels=soft_label 287 | ) 288 | centropy -= normalizer 289 | centropy = tf.reshape(centropy, tf.shape(target)) 290 | 291 | mask = tf.cast(mask, tf.float32) 292 | per_sample_loss = tf.reduce_sum(centropy * mask, -1) / tf.reduce_sum(mask, -1) 293 | loss = tf.reduce_mean(per_sample_loss) 294 | 295 | if is_training and params.ctc_enable: 296 | assert labels is not None 297 | 298 | # batch x seq x dim 299 | encoding = state['encodes'] 300 | # CTC projection: adding one more symbol for blank 301 | ctc_label_size = params.src_vocab.size() + 1 302 | # Supporting CoLaCTC 303 | if params.cola_ctc_L > 0: 304 | ctc_label_size = params.cola_ctc_L + 1 305 | 306 | enc_logits = func.linear(encoding, ctc_label_size, scope="ctc_mapper") 307 | # seq dimension transpose 308 | enc_logits = tf.transpose(enc_logits, (1, 0, 2)) 309 | 310 | enc_logits = tf.to_float(enc_logits) 311 | 312 | with tf.name_scope('loss'): 313 | ctc_loss = tf.nn.ctc_loss(labels, enc_logits, tf.cast(tf.reduce_sum(state['mask'], -1), tf.int32), 314 | ignore_longer_outputs_than_inputs=True, 315 | preprocess_collapse_repeated=params.ctc_repeated) 316 | ctc_loss /= tf.reduce_sum(mask, -1) 317 | ctc_loss = tf.reduce_mean(ctc_loss) 318 | 319 | loss = params.ctc_alpha * ctc_loss + (1. 
- params.ctc_alpha) * loss 320 | 321 | if is_training and params.use_nafm: 322 | mse_loss = tf.reduce_sum((state["_source"] - state["_target"]) ** 2, -1) 323 | per_sample_loss = tf.reduce_sum(mse_loss * state["_mask"], -1) / tf.reduce_sum(state["_mask"], -1) 324 | mse_loss = tf.reduce_mean(per_sample_loss) 325 | loss = loss + params.nafm_alpha * mse_loss 326 | 327 | # this guard is mainly used to deal with zero-sized batches, such as shape [0, 1] 328 | loss = tf.cond(tf.equal(tf.shape(target)[0], 0), 329 | lambda: tf.constant(0, tf.float32), 330 | lambda: loss) 331 | 332 | return loss, logits, state, per_sample_loss 333 | 334 | 335 | def train_fn(features, params, initializer=None): 336 | with tf.variable_scope(params.scope_name or "model", 337 | initializer=initializer, 338 | reuse=tf.AUTO_REUSE, 339 | dtype=tf.as_dtype(dtype.floatx()), 340 | custom_getter=dtype.float32_variable_storage_getter): 341 | state = encoder(features['source'], params) 342 | loss, logits, state, _ = decoder(features['target'], state, params, 343 | labels=features['label'] if params.ctc_enable else None) 344 | 345 | return { 346 | "loss": loss 347 | } 348 | 349 | 350 | def score_fn(features, params, initializer=None): 351 | params = copy.copy(params) 352 | params = util.closing_dropout(params) 353 | params.label_smooth = 0.0 354 | params.audio_dither = 0.0 355 | with tf.variable_scope(params.scope_name or "model", 356 | initializer=initializer, 357 | reuse=tf.AUTO_REUSE, 358 | dtype=tf.as_dtype(dtype.floatx()), 359 | custom_getter=dtype.float32_variable_storage_getter): 360 | state = encoder(features['source'], params) 361 | _, _, _, scores = decoder(features['target'], state, params) 362 | 363 | return { 364 | "score": scores 365 | } 366 | 367 | 368 | def infer_fn(params): 369 | params = copy.copy(params) 370 | params = util.closing_dropout(params) 371 | # NOTE: audio dithering must be disabled at inference time for deterministic features
372 | params.audio_dither=0.0 373 | 374 | def encoding_fn(source): 375 | with tf.variable_scope(params.scope_name or "model", 376 | reuse=tf.AUTO_REUSE, 377 | dtype=tf.as_dtype(dtype.floatx()), 378 | custom_getter=dtype.float32_variable_storage_getter): 379 | state = encoder(source, params) 380 | state["decoder"] = { 381 | "state": state["decoder_initializer"] 382 | } 383 | return state 384 | 385 | def decoding_fn(target, state, time): 386 | with tf.variable_scope(params.scope_name or "model", 387 | reuse=tf.AUTO_REUSE, 388 | dtype=tf.as_dtype(dtype.floatx()), 389 | custom_getter=dtype.float32_variable_storage_getter): 390 | if params.search_mode == "cache": 391 | state['time'] = time 392 | step_loss, step_logits, step_state, _ = decoder( 393 | target, state, params) 394 | del state['time'] 395 | else: 396 | estate = encoder(state, params) 397 | estate['dev_decode'] = True 398 | _, step_logits, _, _ = decoder(target, estate, params) 399 | step_state = state 400 | 401 | return step_logits, step_state 402 | 403 | return encoding_fn, decoding_fn 404 | 405 | 406 | # register the model, with a unique name 407 | model.model_register("transformer", train_fn, score_fn, infer_fn) 408 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | -------------------------------------------------------------------------------- /modules/initializer.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import tensorflow as tf 8 | from utils import dtype 9 | 10 | 11 | def get_initializer(initializer, initializer_gain): 12 | tfdtype = tf.as_dtype(dtype.floatx()) 13 | 14 | if initializer == "uniform": 15 | max_val = initializer_gain 16 | return tf.random_uniform_initializer(-max_val, max_val, dtype=tfdtype) 17 | elif initializer == "normal": 18 | return tf.random_normal_initializer(0.0, initializer_gain, dtype=tfdtype) 19 | elif initializer == "normal_unit_scaling": 20 | return tf.variance_scaling_initializer(initializer_gain, 21 | mode="fan_avg", 22 | distribution="normal", 23 | dtype=tfdtype) 24 | elif initializer == "uniform_unit_scaling": 25 | return tf.variance_scaling_initializer(initializer_gain, 26 | mode="fan_avg", 27 | distribution="uniform", 28 | dtype=tfdtype) 29 | else: 30 | tf.logging.warn("Unrecognized initializer: %s" % initializer) 31 | tf.logging.warn("Return to default initializer: glorot_uniform_initializer") 32 | return tf.glorot_uniform_initializer(dtype=tfdtype) 33 | 34 | 35 | def scale_initializer(scale, initializer): 36 | """Rescale the value given by initializer""" 37 | tfdtype = tf.as_dtype(dtype.floatx()) 38 | 39 | def _initializer(shape, dtype=tfdtype, partition_info=None): 40 | value = initializer(shape, dtype=dtype, partition_info=partition_info) 41 | value *= scale 42 | 43 | return value 44 | 45 | return _initializer 46 | -------------------------------------------------------------------------------- /modules/speech.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import functools 8 | import numpy as np 9 | import scipy.signal 10 | import tensorflow as tf 11 | 12 
| from utils import util 13 | 14 | 15 | def add_delta_deltas(filterbanks, name=None): 16 | """Compute time first and second-order derivative channels. 17 | Args: 18 | filterbanks: float32 tensor with shape [batch_size, len, num_bins, 1] 19 | name: scope name 20 | Returns: 21 | float32 tensor with shape [batch_size, len, num_bins, 3] 22 | """ 23 | delta_filter = np.array([2, 1, 0, -1, -2]) 24 | delta_delta_filter = scipy.signal.convolve(delta_filter, delta_filter, "full") 25 | 26 | delta_filter_stack = np.array( 27 | [[0] * 4 + [1] + [0] * 4, [0] * 2 + list(delta_filter) + [0] * 2, 28 | list(delta_delta_filter)], 29 | dtype=np.float32).T[:, None, None, :] 30 | 31 | delta_filter_stack /= np.sqrt( 32 | np.sum(delta_filter_stack ** 2, axis=0, keepdims=True)) 33 | 34 | filterbanks = tf.nn.conv2d( 35 | filterbanks, tf.cast(tf.constant(delta_filter_stack), tf.float32), [1, 1, 1, 1], "SAME", data_format="NHWC", 36 | name=name) 37 | return filterbanks 38 | 39 | 40 | def compute_mel_filterbank_features( 41 | waveforms, 42 | sample_rate=16000, dither=1.0 / np.iinfo(np.int16).max, preemphasis=0.97, 43 | frame_length=25, frame_step=10, fft_length=None, 44 | window_fn=functools.partial(tf.signal.hann_window, periodic=True), 45 | lower_edge_hertz=80.0, upper_edge_hertz=7600.0, num_mel_bins=80, 46 | log_noise_floor=1e-3, apply_mask=True): 47 | """implement mel-filterbank extraction using tf ops. 48 | args: 49 | waveforms: float32 tensor with shape [batch_size, max_len] 50 | sample_rate: sampling rate of the waveform 51 | dither: stddev of gaussian noise added to waveform to prevent quantization 52 | artefacts 53 | preemphasis: waveform high-pass filtering constant 54 | frame_length: frame length in ms 55 | frame_step: frame_step in ms 56 | fft_length: number of fft bins 57 | window_fn: windowing function 58 | lower_edge_hertz: lowest frequency of the filterbank 59 | upper_edge_hertz: highest frequency of the filterbank 60 | num_mel_bins: filterbank size 61 | log_noise_floor: clip small values to prevent numeric overflow in log 62 | apply_mask: when working on a batch of samples, set padding frames to zero 63 | returns: 64 | filterbanks: a float32 tensor with shape [batch_size, len, num_bins, 1] 65 | masks: masks to indicate padded positions [batch_size, len] 66 | """ 67 | # `stfts` is a complex64 tensor representing the short-time fourier 68 | # transform of each signal in `signals`. its shape is 69 | # [batch_size, ?, fft_unique_bins] 70 | # where fft_unique_bins = fft_length // 2 + 1 71 | 72 | # find the wave length: the largest index for which the value is !=0 73 | # note that waveforms samples that are exactly 0.0 are quite common, so 74 | # simply doing sum(waveforms != 0, axis=-1) will not work correctly. 
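    # e.g. for a padded waveform [0.3, 0.0, -0.2, 0.0, 0.0], tf.range gives [0, 1, 2, 3, 4],
    # the non-zero mask gives [1, 0, 1, 0, 0], their product is [0, 0, 2, 0, 0], and
    # reduce_max(...) + 1 = 3 recovers the true length, whereas counting non-zeros would give 2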
75 | # [batch_size]: padding is ok to indicate meaningless points 76 | wav_lens = tf.reduce_max( 77 | tf.expand_dims(tf.range(tf.shape(waveforms)[1]), 0) * 78 | tf.to_int32(tf.not_equal(waveforms, 0.0)), 79 | axis=-1) + 1 80 | # adding small noise to the speech for robust modeling 81 | if dither > 0: 82 | waveforms += tf.random_normal(tf.shape(waveforms), stddev=dither) 83 | # time difference, a normal operation to pre-process speech 84 | if preemphasis > 0: 85 | waveforms = waveforms[:, 1:] - preemphasis * waveforms[:, :-1] 86 | wav_lens -= 1 87 | # frame_length: number of samples in one frame 88 | frame_length = int(frame_length * sample_rate / 1e3) 89 | # frame_step: step size => number of frames = sample number // frame_step 90 | frame_step = int(frame_step * sample_rate / 1e3) 91 | if fft_length is None: 92 | fft_length = int(2 ** (np.ceil(np.log2(frame_length)))) 93 | 94 | # convert a sequence of audio signals into [num_frames, frame_length] 95 | # and then apply sfft operation 96 | # [batch_size, num_frames, fft_unique_bins] 97 | stfts = tf.signal.stft( 98 | waveforms, 99 | frame_length=frame_length, 100 | frame_step=frame_step, 101 | fft_length=fft_length, 102 | window_fn=window_fn, 103 | pad_end=True) 104 | # [batch_size, num_frames, frame_length] 105 | frames = tf.signal.frame( 106 | waveforms, 107 | frame_length=frame_length, 108 | frame_step=frame_step, 109 | pad_end=True) 110 | 111 | # num_frames: [batch_size] for each sample 112 | stft_lens = (wav_lens + (frame_step - 1)) // frame_step 113 | # [batch_size, num_frames]: 1 => valid, 0 => invalid 114 | masks = tf.to_float(tf.less_equal( 115 | tf.expand_dims(tf.range(tf.shape(stfts)[1]), 0), 116 | tf.expand_dims(stft_lens, 1))) 117 | 118 | # an energy spectrogram is the magnitude of the complex-valued stft. 119 | # a float32 tensor of shape [batch_size, ?, 257]. 120 | magnitude_spectrograms = tf.abs(stfts) 121 | 122 | # warp the linear-scale, magnitude spectrograms into the mel-scale. 123 | num_spectrogram_bins = magnitude_spectrograms.shape[-1].value 124 | linear_to_mel_weight_matrix = ( 125 | tf.signal.linear_to_mel_weight_matrix( 126 | num_mel_bins, num_spectrogram_bins, sample_rate, lower_edge_hertz, 127 | upper_edge_hertz)) 128 | mel_spectrograms = tf.tensordot( 129 | magnitude_spectrograms, linear_to_mel_weight_matrix, 1) 130 | # note: shape inference for tensordot does not currently handle this case. 
131 | mel_spectrograms.set_shape(magnitude_spectrograms.shape[:-1].concatenate( 132 | linear_to_mel_weight_matrix.shape[-1:])) 133 | 134 | log_mel_sgram = tf.log(tf.maximum(log_noise_floor, mel_spectrograms)) 135 | 136 | if apply_mask: 137 | log_mel_sgram *= tf.expand_dims(masks, -1) 138 | frames *= tf.expand_dims(masks, -1) 139 | 140 | return tf.expand_dims(log_mel_sgram, -1, name="mel_sgrams"), masks, frames 141 | 142 | 143 | def extract_logmel_features(wav, hparams): 144 | """ extract logmel features from raw wav file 145 | 146 | args: 147 | wav: [batch, wavlength], 148 | hparams: hyper-parameters 149 | returns: 150 | features: [batch, num_frames, features] 151 | mask: [batch, num_frames] 152 | """ 153 | p = hparams 154 | d = p.audio_num_mel_bins 155 | mel_fbanks, masks, frames = compute_mel_filterbank_features( 156 | wav, 157 | sample_rate=p.audio_sample_rate, 158 | dither=p.audio_dither, 159 | preemphasis=p.audio_preemphasis, 160 | frame_length=p.audio_frame_length, 161 | frame_step=p.audio_frame_step, 162 | lower_edge_hertz=p.audio_lower_edge_hertz, 163 | upper_edge_hertz=p.audio_upper_edge_hertz, 164 | num_mel_bins=p.audio_num_mel_bins, 165 | apply_mask=True) 166 | if p.audio_add_delta_deltas: 167 | d *= 3 168 | mel_fbanks = add_delta_deltas(mel_fbanks) 169 | 170 | mfshp = util.shape_list(mel_fbanks) 171 | mel_fbanks = tf.reshape(mel_fbanks, [mfshp[0], mfshp[1], d]) 172 | masking = tf.expand_dims(masks, -1) 173 | 174 | # this replaces cmvn estimation on data 175 | var_epsilon = 1e-08 176 | mean = tf.reduce_sum(mel_fbanks * masking, keepdims=True, axis=1) / \ 177 | (tf.reduce_sum(masking, keepdims=True, axis=1) + var_epsilon) 178 | sqr_diff = tf.squared_difference(mel_fbanks, mean) 179 | variance = tf.reduce_sum(sqr_diff * masking, keepdims=True, axis=1) / \ 180 | (tf.reduce_sum(masking, keepdims=True, axis=1) + var_epsilon) 181 | 182 | mel_fbanks = (mel_fbanks - mean) * tf.rsqrt(variance + var_epsilon) 183 | 184 | return mel_fbanks, masks, frames 185 | 186 | -------------------------------------------------------------------------------- /overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bzhangGo/st_from_scratch/5b05e5f3c7b9955c24d3c91a9067af4c18554a36/overview.png -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import time 8 | import os 9 | import random 10 | import socket 11 | 12 | import numpy as np 13 | import tensorflow as tf 14 | import tensorflow.contrib as tc 15 | 16 | import models 17 | import main as graph 18 | from vocab import Vocab 19 | from utils.recorder import Recorder 20 | from utils import dtype, util 21 | 22 | 23 | logger = tf.get_logger() 24 | logger.propagate = False 25 | 26 | 27 | # define global initial parameters 28 | global_params = tc.training.HParams( 29 | # whether share source and target word embedding 30 | shared_source_target_embedding=False, 31 | # whether share target and softmax word embedding 32 | shared_target_softmax_embedding=True, 33 | 34 | # decoding maximum length: source length + decode_length 35 | decode_length=50, 36 | # beam size 37 | beam_size=4, 38 | # length penalty during beam search 39 | decode_alpha=0.6, 40 | decode_beta=1./6., 41 | # noise beam search with gumbel 42 | 
enable_noise_beam_search=False, 43 | # beam search temperature, sharp or flat prediction 44 | beam_search_temperature=1.0, 45 | # return top elements, not used 46 | top_beams=1, 47 | # which version of beam search to use 48 | # cache or dev 49 | search_mode="cache", 50 | 51 | # distance considered for PDP 52 | pdp_r=512, 53 | 54 | # speech feature settings 55 | # with the defaults below, 80-dimensional log-mel features are extracted; 56 | # after adding delta and delta-delta channels, the feature dimension grows to 240 57 | audio_sample_rate=16000, 58 | audio_preemphasis=0.97, 59 | # note: disable dithering (set it to 0) after training, i.e. for scoring and inference 60 | audio_dither=1.0 / np.iinfo(np.int16).max, 61 | audio_frame_length=25.0, 62 | audio_frame_step=10.0, 63 | audio_lower_edge_hertz=20.0, 64 | audio_upper_edge_hertz=8000.0, 65 | audio_num_mel_bins=80, 66 | audio_add_delta_deltas=True, 67 | 68 | # ASR pretrained model path 69 | asr_pretrain="", 70 | # whether to filter variables when initializing from the ASR model, e.g. to avoid restoring the global step 71 | filter_variables=False, 72 | 73 | # lrate decay 74 | # number of shards 75 | nstable=4, 76 | # warmup steps: the point at which the learning rate stops increasing 77 | warmup_steps=4000, 78 | # select strategy: noam, gnmt+, epoch, score and vanilla 79 | lrate_strategy="noam", 80 | # learning rate decay rate 81 | lrate_decay=0.5, 82 | # cosine learning rate schedule period 83 | cosine_period=5000, 84 | # cosine factor 85 | cosine_factor=1, 86 | 87 | # early stopping 88 | estop_patience=100, 89 | 90 | # initialization 91 | # type of initializer 92 | initializer="uniform", 93 | # initializer range control 94 | initializer_gain=0.08, 95 | 96 | # parameters for rnnsearch 97 | # encoder and decoder hidden size 98 | hidden_size=1000, 99 | # source and target embedding size 100 | embed_size=620, 101 | # dropout value 102 | dropout=0.1, 103 | relu_dropout=0.1, 104 | residual_dropout=0.1, 105 | # label smoothing value 106 | label_smooth=0.1, 107 | # model name 108 | model_name="transformer", 109 | # scope name 110 | scope_name="transformer", 111 | # filter size for transformer 112 | filter_size=2048, 113 | # attention dropout 114 | attention_dropout=0.1, 115 | # the number of encoder layers, valid for deep nmt 116 | num_encoder_layer=6, 117 | # the number of decoder layers, valid for deep nmt 118 | num_decoder_layer=6, 119 | # the number of attention heads 120 | num_heads=8, 121 | 122 | # sample rate * N / 100 123 | max_frame_len=100, 124 | max_text_len=100, 125 | # constant batch size at 'batch' mode for batch-based batching 126 | batch_size=80, 127 | # constant token size at 'token' mode for token-based batching 128 | token_size=3000, 129 | # token or batch-based data iterator 130 | batch_or_token='token', 131 | # batch size for decoding, i.e.
number of source sentences decoded at the same time 132 | eval_batch_size=32, 133 | # whether to shuffle batches during training 134 | shuffle_batch=True, 135 | # data leak buffer threshold 136 | data_leak_ratio=0.5, 137 | 138 | # whether to use multiprocessing for data reading (number of worker processes) 139 | process_num=1, 140 | # buffer size controls the number of sentences read in at one time 141 | buffer_size=100, 142 | # queue sizes used in the multi-process data reading pipeline 143 | input_queue_size=100, 144 | output_queue_size=100, 145 | 146 | # source vocabulary 147 | src_vocab_file="", 148 | # target vocabulary 149 | tgt_vocab_file="", 150 | # source train file 151 | src_train_path="", 152 | src_train_file="", 153 | # target train file 154 | tgt_train_file="", 155 | # ctc train file 156 | ctc_train_file="", 157 | # source development file 158 | src_dev_path="", 159 | src_dev_file="", 160 | # target development file 161 | tgt_dev_file="", 162 | # source test file 163 | src_test_path="", 164 | src_test_file="", 165 | # target test file 166 | tgt_test_file="", 167 | # output directory 168 | output_dir="", 169 | # output during testing 170 | test_output="", 171 | 172 | # adam optimizer hyperparameters 173 | beta1=0.9, 174 | beta2=0.999, 175 | epsilon=1e-9, 176 | # gradient clipping value 177 | clip_grad_norm=5.0, 178 | # the gradient norm upper bound, to avoid weirdly large gradient norms; only works in safe nan mode 179 | gnorm_upper_bound=1e20, 180 | # initial learning rate 181 | lrate=1e-5, 182 | # minimum learning rate 183 | min_lrate=0.0, 184 | # maximum learning rate 185 | max_lrate=1.0, 186 | 187 | # maximum epochs 188 | epoches=10, 189 | # the effective batch size is: batch/token size * update_cycle * num_gpus 190 | # sequential update cycle 191 | update_cycle=1, 192 | # the number of gpus 193 | gpus=[0], 194 | 195 | # enable safe handling of nan 196 | safe_nan=False, 197 | # exponential moving average for stability, disabled by default 198 | ema_decay=-1., 199 | 200 | # enable training deep transformer 201 | deep_transformer_init=False, 202 | 203 | # print information every disp_freq training steps 204 | disp_freq=100, 205 | # evaluate on the development file every eval_freq steps 206 | eval_freq=10000, 207 | # save the model parameters every save_freq steps 208 | save_freq=5000, 209 | # print sample translations every sample_freq steps 210 | sample_freq=1000, 211 | # saved checkpoint number 212 | checkpoints=5, 213 | best_checkpoints=1, 214 | # the maximum number of training steps; the program stops once epoches or max_training_steps is reached 215 | max_training_steps=1000, 216 | 217 | # random seed control; note that full determinism is hard to guarantee in tensorflow 218 | random_seed=1234, 219 | # whether or not to continue training from an existing checkpoint 220 | train_continue=True, 221 | 222 | # provide interface to modify the default datatype 223 | default_dtype="float32", 224 | dtype_epsilon=1e-8, 225 | dtype_inf=1e8, 226 | loss_scale=1.0, 227 | 228 | # speech-specific settings 229 | sinusoid_posenc=True, 230 | max_poslen=2048, 231 | ctc_repeated=False, 232 | ctc_enable=False, 233 | ctc_alpha=0.3, # ctc loss factor 234 | enc_localize="log", 235 | dec_localize="none", 236 | encdec_localize="none", 237 | 238 | # cola ctc settings 239 | # -1: disable CoLaCTC; in our paper we set it to 256.
240 | cola_ctc_L=-1, 241 | 242 | # neural acoustic feature modeling 243 | use_nafm=False, 244 | nafm_alpha=0.05, 245 | 246 | ) 247 | 248 | flags = tf.flags 249 | flags.DEFINE_string("config", "", "Additional Mergable Parameters") 250 | flags.DEFINE_string("parameters", "", "Command Line Refinable Parameters") 251 | flags.DEFINE_string("name", "model", "Description of the training process for distinguishing") 252 | flags.DEFINE_string("mode", "train", "train or test or ensemble") 253 | 254 | 255 | # saving model configuration 256 | def save_parameters(params, output_dir): 257 | if not tf.gfile.Exists(output_dir): 258 | tf.gfile.MkDir(output_dir) 259 | 260 | param_name = os.path.join(output_dir, "param.json") 261 | with tf.gfile.Open(param_name, "w") as writer: 262 | tf.logging.info("Saving parameters into {}" 263 | .format(param_name)) 264 | writer.write(params.to_json()) 265 | 266 | 267 | # load model configuration 268 | def load_parameters(params, output_dir): 269 | param_name = os.path.join(output_dir, "param.json") 270 | param_name = os.path.abspath(param_name) 271 | 272 | if tf.gfile.Exists(param_name): 273 | tf.logging.info("Loading parameters from {}" 274 | .format(param_name)) 275 | with tf.gfile.Open(param_name, 'r') as reader: 276 | json_str = reader.readline() 277 | params.parse_json(json_str) 278 | return params 279 | 280 | 281 | # build training process recorder 282 | def setup_recorder(params): 283 | recorder = Recorder() 284 | # This is for early stopping, currently I did not use it 285 | recorder.bad_counter = 0 # start from 0 286 | recorder.estop = False 287 | 288 | recorder.lidx = -1 # local data index 289 | recorder.step = 0 # global step, start from 0 290 | recorder.epoch = 1 # epoch number, start from 1 291 | recorder.lrate = params.lrate # running learning rate 292 | recorder.history_scores = [] 293 | recorder.valid_script_scores = [] 294 | 295 | # trying to load saved recorder 296 | record_path = os.path.join(params.output_dir, "record.json") 297 | record_path = os.path.abspath(record_path) 298 | if tf.gfile.Exists(record_path): 299 | recorder.load_from_json(record_path) 300 | 301 | params.add_hparam('recorder', recorder) 302 | return params 303 | 304 | 305 | # print model configuration 306 | def print_parameters(params): 307 | tf.logging.info("The Used Configuration:") 308 | for k, v in params.values().items(): 309 | tf.logging.info("%s\t%s", k.ljust(20), str(v).ljust(20)) 310 | tf.logging.info("") 311 | 312 | 313 | def main(_): 314 | # set up logger 315 | tf.logging.set_verbosity(tf.logging.INFO) 316 | 317 | tf.logging.info("Welcome Using Zero :)") 318 | 319 | pid = os.getpid() 320 | tf.logging.info("Your pid is {0} and use the following command to force kill your running:\n" 321 | "'pkill -9 -P {0}; kill -9 {0}'".format(pid)) 322 | # On clusters, this could tell which machine you are running 323 | tf.logging.info("Your running machine name is {}".format(socket.gethostname())) 324 | 325 | # load registered models 326 | util.dynamic_load_module(models, prefix="models") 327 | 328 | params = global_params 329 | 330 | # try loading parameters 331 | # priority: command line > saver > default 332 | params.parse(flags.FLAGS.parameters) 333 | if os.path.exists(flags.FLAGS.config): 334 | params.override_from_dict(eval(open(flags.FLAGS.config).read())) 335 | params = load_parameters(params, params.output_dir) 336 | # override 337 | if os.path.exists(flags.FLAGS.config): 338 | params.override_from_dict(eval(open(flags.FLAGS.config).read())) 339 | 
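    # command-line `--parameters` are re-applied last so that they keep the highest
    # priority over both the config file and any saved param.json (see the priority comment above)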
params.parse(flags.FLAGS.parameters) 340 | 341 | # set up random seed 342 | random.seed(params.random_seed) 343 | np.random.seed(params.random_seed) 344 | tf.set_random_seed(params.random_seed) 345 | 346 | # loading vocabulary 347 | tf.logging.info("Begin Loading Vocabulary") 348 | start_time = time.time() 349 | params.src_vocab = Vocab(params.src_vocab_file) 350 | params.tgt_vocab = Vocab(params.tgt_vocab_file) 351 | tf.logging.info("End Loading Vocabulary, Source Vocab Size {}, " 352 | "Target Vocab Size {}, within {} seconds" 353 | .format(params.src_vocab.size(), params.tgt_vocab.size(), 354 | time.time() - start_time)) 355 | 356 | # print parameters 357 | print_parameters(params) 358 | 359 | # set up the default datatype 360 | dtype.set_floatx(params.default_dtype) 361 | dtype.set_epsilon(params.dtype_epsilon) 362 | dtype.set_inf(params.dtype_inf) 363 | 364 | mode = flags.FLAGS.mode 365 | if mode == "train": 366 | # save parameters 367 | save_parameters(params, params.output_dir) 368 | 369 | # load the recorder 370 | params = setup_recorder(params) 371 | 372 | graph.train(params) 373 | elif mode == "test": 374 | graph.evaluate(params) 375 | elif mode == "score": 376 | graph.scorer(params) 377 | else: 378 | tf.logging.error("Invalid mode: {}".format(mode)) 379 | 380 | 381 | if __name__ == '__main__': 382 | 383 | tf.app.run() 384 | -------------------------------------------------------------------------------- /scripts/checkpoint_averaging.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import argparse 8 | import operator 9 | import os 10 | 11 | import numpy as np 12 | import tensorflow as tf 13 | 14 | 15 | def parseargs(): 16 | msg = "Average checkpoints" 17 | usage = "average.py [] [-h | --help]" 18 | parser = argparse.ArgumentParser(description=msg, usage=usage) 19 | 20 | parser.add_argument("--path", type=str, required=True, 21 | help="checkpoint dir") 22 | parser.add_argument("--checkpoints", type=int, required=True, 23 | help="number of checkpoints to use") 24 | parser.add_argument("--output", type=str, help="output path") 25 | parser.add_argument("--gpu", type=int, default=0, 26 | help="the default gpu device index") 27 | 28 | return parser.parse_args() 29 | 30 | 31 | def get_checkpoints(path): 32 | if not tf.gfile.Exists(os.path.join(path, "checkpoint")): 33 | raise ValueError("Cannot find checkpoints in %s" % path) 34 | 35 | checkpoint_names = [] 36 | 37 | with tf.gfile.GFile(os.path.join(path, "checkpoint")) as fd: 38 | # Skip the first line 39 | fd.readline() 40 | for line in fd: 41 | name = line.strip().split(":")[-1].strip()[1:-1] 42 | key = int(name.split("-")[-1]) 43 | checkpoint_names.append((key, os.path.join(path, name))) 44 | 45 | sorted_names = sorted(checkpoint_names, key=operator.itemgetter(0), 46 | reverse=True) 47 | 48 | return [item[-1] for item in sorted_names] 49 | 50 | 51 | def checkpoint_exists(path): 52 | return (tf.gfile.Exists(path) or tf.gfile.Exists(path + ".meta") or 53 | tf.gfile.Exists(path + ".index")) 54 | 55 | 56 | def main(_): 57 | tf.logging.set_verbosity(tf.logging.INFO) 58 | checkpoints = get_checkpoints(FLAGS.path) 59 | checkpoints = checkpoints[:FLAGS.checkpoints] 60 | 61 | if not checkpoints: 62 | raise ValueError("No checkpoints provided for averaging.") 63 | 64 | checkpoints = [c for c in checkpoints if checkpoint_exists(c)] 65 | 66 | if not 
checkpoints: 67 | raise ValueError( 68 | "None of the provided checkpoints exist. %s" % FLAGS.checkpoints 69 | ) 70 | 71 | var_list = tf.contrib.framework.list_variables(checkpoints[0]) 72 | var_values, var_dtypes = {}, {} 73 | 74 | for (name, shape) in var_list: 75 | if not name.startswith("global_step"): 76 | var_values[name] = np.zeros(shape) 77 | 78 | for checkpoint in checkpoints: 79 | reader = tf.contrib.framework.load_checkpoint(checkpoint) 80 | for name in var_values: 81 | tensor = reader.get_tensor(name) 82 | var_dtypes[name] = tensor.dtype 83 | var_values[name] += tensor 84 | tf.logging.info("Read from checkpoint %s", checkpoint) 85 | 86 | # Average checkpoints 87 | for name in var_values: 88 | var_values[name] /= len(checkpoints) 89 | 90 | tf_vars = [ 91 | tf.get_variable(name, shape=var_values[name].shape, 92 | dtype=var_dtypes[name]) for name in var_values 93 | ] 94 | placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars] 95 | assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)] 96 | global_step = tf.Variable(0, name="global_step", trainable=False, 97 | dtype=tf.int64) 98 | saver = tf.train.Saver(tf.global_variables()) 99 | 100 | sess_config = tf.ConfigProto(allow_soft_placement=True) 101 | sess_config.gpu_options.allow_growth = True 102 | sess_config.gpu_options.visible_device_list = "%s" % FLAGS.gpu 103 | 104 | with tf.Session(config=sess_config) as sess: 105 | sess.run(tf.global_variables_initializer()) 106 | for p, assign_op, (name, value) in zip(placeholders, assign_ops, 107 | var_values.items()): 108 | sess.run(assign_op, {p: value}) 109 | saved_name = os.path.join(FLAGS.output, "average") 110 | saver.save(sess, saved_name, global_step=global_step) 111 | 112 | tf.logging.info("Averaged checkpoints saved in %s", saved_name) 113 | 114 | params_pattern = os.path.join(FLAGS.path, "*.json") 115 | params_files = tf.gfile.Glob(params_pattern) 116 | 117 | for name in params_files: 118 | new_name = name.replace(FLAGS.path.rstrip("/"), 119 | FLAGS.output.rstrip("/")) 120 | tf.gfile.Copy(name, new_name, overwrite=True) 121 | 122 | 123 | if __name__ == "__main__": 124 | FLAGS = parseargs() 125 | tf.app.run() 126 | -------------------------------------------------------------------------------- /scripts/chrF.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Author: Rico Sennrich 4 | 5 | """Compute chrF3 for machine translation evaluation 6 | 7 | Reference: 8 | Maja Popović (2015). chrF: character n-gram F-score for automatic MT evaluation. In Proceedings of the Tenth Workshop on Statistical Machine Translationn, pages 392–395, Lisbon, Portugal. 
9 | """ 10 | 11 | from __future__ import print_function, unicode_literals, division 12 | 13 | import sys 14 | import codecs 15 | import io 16 | import argparse 17 | 18 | from collections import defaultdict 19 | 20 | # hack for python2/3 compatibility 21 | from io import open 22 | argparse.open = open 23 | 24 | def create_parser(): 25 | parser = argparse.ArgumentParser( 26 | formatter_class=argparse.RawDescriptionHelpFormatter, 27 | description="learn BPE-based word segmentation") 28 | 29 | parser.add_argument( 30 | '--ref', '-r', type=argparse.FileType('r'), required=True, 31 | metavar='PATH', 32 | help="Reference file") 33 | parser.add_argument( 34 | '--hyp', type=argparse.FileType('r'), metavar='PATH', 35 | default=sys.stdin, 36 | help="Hypothesis file (default: stdin).") 37 | parser.add_argument( 38 | '--beta', '-b', type=float, default=3, 39 | metavar='FLOAT', 40 | help="beta parameter (default: '%(default)s')") 41 | parser.add_argument( 42 | '--ngram', '-n', type=int, default=6, 43 | metavar='INT', 44 | help="ngram order (default: '%(default)s')") 45 | parser.add_argument( 46 | '--space', '-s', action='store_true', 47 | help="take spaces into account (default: '%(default)s')") 48 | parser.add_argument( 49 | '--precision', action='store_true', 50 | help="report precision (default: '%(default)s')") 51 | parser.add_argument( 52 | '--recall', action='store_true', 53 | help="report recall (default: '%(default)s')") 54 | 55 | return parser 56 | 57 | def extract_ngrams(words, max_length=4, spaces=False): 58 | 59 | if not spaces: 60 | words = ''.join(words.split()) 61 | else: 62 | words = words.strip() 63 | 64 | results = defaultdict(lambda: defaultdict(int)) 65 | for length in range(max_length): 66 | for start_pos in range(len(words)): 67 | end_pos = start_pos + length + 1 68 | if end_pos <= len(words): 69 | results[length][tuple(words[start_pos: end_pos])] += 1 70 | return results 71 | 72 | 73 | def get_correct(ngrams_ref, ngrams_test, correct, total): 74 | 75 | for rank in ngrams_test: 76 | for chain in ngrams_test[rank]: 77 | total[rank] += ngrams_test[rank][chain] 78 | if chain in ngrams_ref[rank]: 79 | correct[rank] += min(ngrams_test[rank][chain], ngrams_ref[rank][chain]) 80 | 81 | return correct, total 82 | 83 | 84 | def f1(correct, total_hyp, total_ref, max_length, beta=3, smooth=0): 85 | 86 | precision = 0 87 | recall = 0 88 | 89 | for i in range(max_length): 90 | if total_hyp[i] + smooth and total_ref[i] + smooth: 91 | precision += (correct[i] + smooth) / (total_hyp[i] + smooth) 92 | recall += (correct[i] + smooth) / (total_ref[i] + smooth) 93 | 94 | precision /= max_length 95 | recall /= max_length 96 | 97 | return (1 + beta**2) * (precision*recall) / ((beta**2 * precision) + recall), precision, recall 98 | 99 | def main(args): 100 | 101 | correct = [0]*args.ngram 102 | total = [0]*args.ngram 103 | total_ref = [0]*args.ngram 104 | for line in args.ref: 105 | line2 = args.hyp.readline() 106 | 107 | ngrams_ref = extract_ngrams(line, max_length=args.ngram, spaces=args.space) 108 | ngrams_test = extract_ngrams(line2, max_length=args.ngram, spaces=args.space) 109 | 110 | get_correct(ngrams_ref, ngrams_test, correct, total) 111 | 112 | for rank in ngrams_ref: 113 | for chain in ngrams_ref[rank]: 114 | total_ref[rank] += ngrams_ref[rank][chain] 115 | 116 | chrf, precision, recall = f1(correct, total, total_ref, args.ngram, args.beta) 117 | 118 | print('chrF3: {0:.4f}'.format(chrf)) 119 | if args.precision: 120 | print('chrPrec: {0:.4f}'.format(precision)) 121 | if args.recall: 122 | 
print('chrRec: {0:.4f}'.format(recall)) 123 | 124 | if __name__ == '__main__': 125 | 126 | # python 2/3 compatibility 127 | if sys.version_info < (3, 0): 128 | sys.stderr = codecs.getwriter('UTF-8')(sys.stderr) 129 | sys.stdout = codecs.getwriter('UTF-8')(sys.stdout) 130 | sys.stdin = codecs.getreader('UTF-8')(sys.stdin) 131 | else: 132 | sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8') 133 | sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8') 134 | sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', write_through=True, line_buffering=True) 135 | 136 | parser = create_parser() 137 | args = parser.parse_args() 138 | 139 | main(args) 140 | -------------------------------------------------------------------------------- /scripts/multi-bleu-detok.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 5 | 6 | # This file uses the internal tokenization of mteval-v13a.pl, 7 | # giving the exact same (case-sensitive) results on untokenized text. 8 | # Using this script with detokenized output and untokenized references is 9 | # preferrable over multi-bleu.perl, since scores aren't affected by tokenization differences. 10 | # 11 | # like multi-bleu.perl , it supports plain text input and multiple references. 12 | 13 | # $Id$ 14 | use warnings; 15 | use strict; 16 | 17 | binmode(STDIN, ":utf8"); 18 | use open ':encoding(UTF-8)'; 19 | 20 | my $lowercase = 0; 21 | if ($ARGV[0] eq "-lc") { 22 | $lowercase = 1; 23 | shift; 24 | } 25 | 26 | my $stem = $ARGV[0]; 27 | if (!defined $stem) { 28 | print STDERR "usage: multi-bleu-detok.pl [-lc] reference < hypothesis\n"; 29 | print STDERR "Reads the references from reference or reference0, reference1, ...\n"; 30 | exit(1); 31 | } 32 | 33 | $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0"; 34 | 35 | my @REF; 36 | my $ref=0; 37 | while(-e "$stem$ref") { 38 | &add_to_ref("$stem$ref",\@REF); 39 | $ref++; 40 | } 41 | &add_to_ref($stem,\@REF) if -e $stem; 42 | die("ERROR: could not find reference file $stem") unless scalar @REF; 43 | 44 | # add additional references explicitly specified on the command line 45 | shift; 46 | foreach my $stem (@ARGV) { 47 | &add_to_ref($stem,\@REF) if -e $stem; 48 | } 49 | 50 | 51 | 52 | sub add_to_ref { 53 | my ($file,$REF) = @_; 54 | my $s=0; 55 | if ($file =~ /.gz$/) { 56 | open(REF,"gzip -dc $file|") or die "Can't read $file"; 57 | } else { 58 | open(REF,$file) or die "Can't read $file"; 59 | } 60 | while() { 61 | chop; 62 | $_ = tokenization($_); 63 | push @{$$REF[$s++]}, $_; 64 | } 65 | close(REF); 66 | } 67 | 68 | my(@CORRECT,@TOTAL,$length_translation,$length_reference); 69 | my $s=0; 70 | while() { 71 | chop; 72 | $_ = lc if $lowercase; 73 | $_ = tokenization($_); 74 | my @WORD = split; 75 | my %REF_NGRAM = (); 76 | my $length_translation_this_sentence = scalar(@WORD); 77 | my ($closest_diff,$closest_length) = (9999,9999); 78 | foreach my $reference (@{$REF[$s]}) { 79 | # print "$s $_ <=> $reference\n"; 80 | $reference = lc($reference) if $lowercase; 81 | my @WORD = split(' ',$reference); 82 | my $length = scalar(@WORD); 83 | my $diff = abs($length_translation_this_sentence-$length); 84 | if ($diff < $closest_diff) { 85 | $closest_diff = $diff; 86 | $closest_length = $length; 87 | # print STDERR "$s: closest diff 
".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n"; 88 | } elsif ($diff == $closest_diff) { 89 | $closest_length = $length if $length < $closest_length; 90 | # from two references with the same closeness to me 91 | # take the *shorter* into account, not the "first" one. 92 | } 93 | for(my $n=1;$n<=4;$n++) { 94 | my %REF_NGRAM_N = (); 95 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 96 | my $ngram = "$n"; 97 | for(my $w=0;$w<$n;$w++) { 98 | $ngram .= " ".$WORD[$start+$w]; 99 | } 100 | $REF_NGRAM_N{$ngram}++; 101 | } 102 | foreach my $ngram (keys %REF_NGRAM_N) { 103 | if (!defined($REF_NGRAM{$ngram}) || 104 | $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) { 105 | $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram}; 106 | # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}
\n"; 107 | } 108 | } 109 | } 110 | } 111 | $length_translation += $length_translation_this_sentence; 112 | $length_reference += $closest_length; 113 | for(my $n=1;$n<=4;$n++) { 114 | my %T_NGRAM = (); 115 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 116 | my $ngram = "$n"; 117 | for(my $w=0;$w<$n;$w++) { 118 | $ngram .= " ".$WORD[$start+$w]; 119 | } 120 | $T_NGRAM{$ngram}++; 121 | } 122 | foreach my $ngram (keys %T_NGRAM) { 123 | $ngram =~ /^(\d+) /; 124 | my $n = $1; 125 | # my $corr = 0; 126 | # print "$i e $ngram $T_NGRAM{$ngram}
\n"; 127 | $TOTAL[$n] += $T_NGRAM{$ngram}; 128 | if (defined($REF_NGRAM{$ngram})) { 129 | if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) { 130 | $CORRECT[$n] += $T_NGRAM{$ngram}; 131 | # $corr = $T_NGRAM{$ngram}; 132 | # print "$i e correct1 $T_NGRAM{$ngram}
\n"; 133 | } 134 | else { 135 | $CORRECT[$n] += $REF_NGRAM{$ngram}; 136 | # $corr = $REF_NGRAM{$ngram}; 137 | # print "$i e correct2 $REF_NGRAM{$ngram}
\n"; 138 | } 139 | } 140 | # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram}; 141 | # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n" 142 | } 143 | } 144 | $s++; 145 | } 146 | my $brevity_penalty = 1; 147 | my $bleu = 0; 148 | 149 | my @bleu=(); 150 | 151 | for(my $n=1;$n<=4;$n++) { 152 | if (defined ($TOTAL[$n])){ 153 | $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0; 154 | # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n"; 155 | }else{ 156 | $bleu[$n]=0; 157 | } 158 | } 159 | 160 | if ($length_reference==0){ 161 | printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n"; 162 | exit(1); 163 | } 164 | 165 | if ($length_translation<$length_reference) { 166 | $brevity_penalty = exp(1-$length_reference/$length_translation); 167 | } 168 | $bleu = $brevity_penalty * exp((my_log( $bleu[1] ) + 169 | my_log( $bleu[2] ) + 170 | my_log( $bleu[3] ) + 171 | my_log( $bleu[4] ) ) / 4) ; 172 | printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n", 173 | 100*$bleu, 174 | 100*$bleu[1], 175 | 100*$bleu[2], 176 | 100*$bleu[3], 177 | 100*$bleu[4], 178 | $brevity_penalty, 179 | $length_translation / $length_reference, 180 | $length_translation, 181 | $length_reference; 182 | 183 | sub my_log { 184 | return -9999999999 unless $_[0]; 185 | return log($_[0]); 186 | } 187 | 188 | 189 | 190 | sub tokenization 191 | { 192 | my ($norm_text) = @_; 193 | 194 | # language-independent part: 195 | $norm_text =~ s///g; # strip "skipped" tags 196 | $norm_text =~ s/-\n//g; # strip end-of-line hyphenation and join lines 197 | $norm_text =~ s/\n/ /g; # join lines 198 | $norm_text =~ s/"/"/g; # convert SGML tag for quote to " 199 | $norm_text =~ s/&/&/g; # convert SGML tag for ampersand to & 200 | $norm_text =~ s/</ 201 | $norm_text =~ s/>/>/g; # convert SGML tag for greater-than to < 202 | 203 | # language-dependent part (assuming Western languages): 204 | $norm_text = " $norm_text "; 205 | $norm_text =~ s/([\{-\~\[-\` -\&\(-\+\:-\@\/])/ $1 /g; # tokenize punctuation 206 | $norm_text =~ s/([^0-9])([\.,])/$1 $2 /g; # tokenize period and comma unless preceded by a digit 207 | $norm_text =~ s/([\.,])([^0-9])/ $1 $2/g; # tokenize period and comma unless followed by a digit 208 | $norm_text =~ s/([0-9])(-)/$1 $2 /g; # tokenize dash when preceded by a digit 209 | $norm_text =~ s/\s+/ /g; # one space only between words 210 | $norm_text =~ s/^\s+//; # no leading space 211 | $norm_text =~ s/\s+$//; # no trailing space 212 | 213 | return $norm_text; 214 | } 215 | -------------------------------------------------------------------------------- /scripts/multi-bleu.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 
5 | 6 | # $Id$ 7 | use warnings; 8 | use strict; 9 | 10 | my $lowercase = 0; 11 | if ($ARGV[0] eq "-lc") { 12 | $lowercase = 1; 13 | shift; 14 | } 15 | 16 | my $stem = $ARGV[0]; 17 | if (!defined $stem) { 18 | print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n"; 19 | print STDERR "Reads the references from reference or reference0, reference1, ...\n"; 20 | exit(1); 21 | } 22 | 23 | $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0"; 24 | 25 | my @REF; 26 | my $ref=0; 27 | while(-e "$stem$ref") { 28 | &add_to_ref("$stem$ref",\@REF); 29 | $ref++; 30 | } 31 | &add_to_ref($stem,\@REF) if -e $stem; 32 | die("ERROR: could not find reference file $stem") unless scalar @REF; 33 | 34 | # add additional references explicitly specified on the command line 35 | shift; 36 | foreach my $stem (@ARGV) { 37 | &add_to_ref($stem,\@REF) if -e $stem; 38 | } 39 | 40 | 41 | 42 | sub add_to_ref { 43 | my ($file,$REF) = @_; 44 | my $s=0; 45 | if ($file =~ /.gz$/) { 46 | open(REF,"gzip -dc $file|") or die "Can't read $file"; 47 | } else { 48 | open(REF,$file) or die "Can't read $file"; 49 | } 50 | while() { 51 | chomp; 52 | push @{$$REF[$s++]}, $_; 53 | } 54 | close(REF); 55 | } 56 | 57 | my(@CORRECT,@TOTAL,$length_translation,$length_reference); 58 | my $s=0; 59 | while() { 60 | chomp; 61 | $_ = lc if $lowercase; 62 | my @WORD = split; 63 | my %REF_NGRAM = (); 64 | my $length_translation_this_sentence = scalar(@WORD); 65 | my ($closest_diff,$closest_length) = (9999,9999); 66 | foreach my $reference (@{$REF[$s]}) { 67 | # print "$s $_ <=> $reference\n"; 68 | $reference = lc($reference) if $lowercase; 69 | my @WORD = split(' ',$reference); 70 | my $length = scalar(@WORD); 71 | my $diff = abs($length_translation_this_sentence-$length); 72 | if ($diff < $closest_diff) { 73 | $closest_diff = $diff; 74 | $closest_length = $length; 75 | # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n"; 76 | } elsif ($diff == $closest_diff) { 77 | $closest_length = $length if $length < $closest_length; 78 | # from two references with the same closeness to me 79 | # take the *shorter* into account, not the "first" one. 80 | } 81 | for(my $n=1;$n<=4;$n++) { 82 | my %REF_NGRAM_N = (); 83 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 84 | my $ngram = "$n"; 85 | for(my $w=0;$w<$n;$w++) { 86 | $ngram .= " ".$WORD[$start+$w]; 87 | } 88 | $REF_NGRAM_N{$ngram}++; 89 | } 90 | foreach my $ngram (keys %REF_NGRAM_N) { 91 | if (!defined($REF_NGRAM{$ngram}) || 92 | $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) { 93 | $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram}; 94 | # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}
\n"; 95 | } 96 | } 97 | } 98 | } 99 | $length_translation += $length_translation_this_sentence; 100 | $length_reference += $closest_length; 101 | for(my $n=1;$n<=4;$n++) { 102 | my %T_NGRAM = (); 103 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 104 | my $ngram = "$n"; 105 | for(my $w=0;$w<$n;$w++) { 106 | $ngram .= " ".$WORD[$start+$w]; 107 | } 108 | $T_NGRAM{$ngram}++; 109 | } 110 | foreach my $ngram (keys %T_NGRAM) { 111 | $ngram =~ /^(\d+) /; 112 | my $n = $1; 113 | # my $corr = 0; 114 | # print "$i e $ngram $T_NGRAM{$ngram}
\n"; 115 | $TOTAL[$n] += $T_NGRAM{$ngram}; 116 | if (defined($REF_NGRAM{$ngram})) { 117 | if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) { 118 | $CORRECT[$n] += $T_NGRAM{$ngram}; 119 | # $corr = $T_NGRAM{$ngram}; 120 | # print "$i e correct1 $T_NGRAM{$ngram}
\n"; 121 | } 122 | else { 123 | $CORRECT[$n] += $REF_NGRAM{$ngram}; 124 | # $corr = $REF_NGRAM{$ngram}; 125 | # print "$i e correct2 $REF_NGRAM{$ngram}
\n"; 126 | } 127 | } 128 | # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram}; 129 | # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n" 130 | } 131 | } 132 | $s++; 133 | } 134 | my $brevity_penalty = 1; 135 | my $bleu = 0; 136 | 137 | my @bleu=(); 138 | 139 | for(my $n=1;$n<=4;$n++) { 140 | if (defined ($TOTAL[$n])){ 141 | $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0; 142 | # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n"; 143 | }else{ 144 | $bleu[$n]=0; 145 | } 146 | } 147 | 148 | if ($length_reference==0){ 149 | printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n"; 150 | exit(1); 151 | } 152 | 153 | if ($length_translation<$length_reference) { 154 | $brevity_penalty = exp(1-$length_reference/$length_translation); 155 | } 156 | $bleu = $brevity_penalty * exp((my_log( $bleu[1] ) + 157 | my_log( $bleu[2] ) + 158 | my_log( $bleu[3] ) + 159 | my_log( $bleu[4] ) ) / 4) ; 160 | printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n", 161 | 100*$bleu, 162 | 100*$bleu[1], 163 | 100*$bleu[2], 164 | 100*$bleu[3], 165 | 100*$bleu[4], 166 | $brevity_penalty, 167 | $length_translation / $length_reference, 168 | $length_translation, 169 | $length_reference; 170 | 171 | 172 | print STDERR "It is not advisable to publish scores from multi-bleu.perl. The scores depend on your tokenizer, which is unlikely to be reproducible from your paper or consistent across research groups. Instead you should detokenize then use mteval-v14.pl, which has a standard tokenization. Scores from multi-bleu.perl can still be used for internal purposes when you have a consistent tokenizer.\n"; 173 | 174 | sub my_log { 175 | return -9999999999 unless $_[0]; 176 | return log($_[0]); 177 | } 178 | -------------------------------------------------------------------------------- /scripts/shuffle_corpus.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import argparse 8 | import numpy 9 | import h5py 10 | 11 | 12 | def parseargs(): 13 | parser = argparse.ArgumentParser(description="Shuffle corpus") 14 | 15 | parser.add_argument("--corpus", nargs="+", required=True, 16 | help="input corpora") 17 | parser.add_argument("--audio", type=str, default="none", 18 | help="audio corpora") 19 | parser.add_argument("--suffix", type=str, default="shuf", 20 | help="Suffix of output files") 21 | parser.add_argument("--seed", type=int, help="Random seed") 22 | 23 | return parser.parse_args() 24 | 25 | 26 | def main(args): 27 | name = args.corpus 28 | suffix = "." 
+ args.suffix 29 | stream = [open(item, "r") for item in name] 30 | data = [fd.readlines() for fd in stream] 31 | minlen = min([len(lines) for lines in data]) 32 | 33 | if args.seed: 34 | numpy.random.seed(args.seed) 35 | 36 | indices = numpy.arange(minlen) 37 | numpy.random.shuffle(indices) 38 | 39 | newstream = [open(item + suffix, "w") for item in name] 40 | 41 | if args.audio != "none": 42 | audiostream = h5py.File(args.audio + suffix + ".h5", 'w') 43 | audioreader = h5py.File(args.audio, 'r') 44 | 45 | for h, idx in enumerate(indices.tolist()): 46 | lines = [item[idx] for item in data] 47 | 48 | for line, fd in zip(lines, newstream): 49 | fd.write(line) 50 | 51 | if args.audio != "none": 52 | audio = audioreader["audio_{}".format(idx)][()] 53 | audiostream.create_dataset("audio_{}".format(h), data=audio) 54 | 55 | if args.audio != "none": 56 | audioreader.close() 57 | audiostream.close() 58 | 59 | for fdr, fdw in zip(stream, newstream): 60 | fdr.close() 61 | fdw.close() 62 | 63 | 64 | if __name__ == "__main__": 65 | parsed_args = parseargs() 66 | main(parsed_args) 67 | -------------------------------------------------------------------------------- /search.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import tensorflow as tf 8 | 9 | from utils import util, dtype 10 | from collections import namedtuple 11 | from tensorflow.python.util import nest 12 | 13 | 14 | class BeamSearchState(namedtuple("BeamSearchState", 15 | ("inputs", "state", "finish"))): 16 | pass 17 | 18 | 19 | def beam_search(features, encoding_fn, decoding_fn, params): 20 | decode_length = params.decode_length 21 | beam_size = params.beam_size 22 | alpha = params.decode_alpha 23 | beta = params.decode_beta 24 | eos_id = params.tgt_vocab.eos() 25 | pad_id = params.tgt_vocab.pad() 26 | 27 | batch_size = tf.shape(features["source"])[0] 28 | if params.search_mode == "cache": 29 | model_state = encoding_fn(features["source"]) 30 | else: 31 | model_state = features["source"] 32 | 33 | src_mask = model_state['mask'] 34 | source_length = tf.to_int32(tf.reduce_sum(src_mask, -1) * beta) 35 | max_target_length = source_length + decode_length 36 | 37 | model_state = nest.map_structure( 38 | lambda x: util.expand_tile_dims(x, beam_size, axis=1), 39 | model_state 40 | ) 41 | 42 | # in our mixed precision mode, we finally convert logits into tf.float32 43 | tfdtype = tf.float32 44 | 45 | # [batch, beam] 46 | init_log_probs = tf.constant([[0.] + [tfdtype.min] * (beam_size - 1)], dtype=tfdtype) 47 | init_log_probs = tf.tile(init_log_probs, [batch_size, 1]) 48 | init_scores = tf.zeros_like(init_log_probs) 49 | # [batch, beam, 1], begin-of-sequence 50 | init_seq = tf.fill([batch_size, beam_size, 1], params.tgt_vocab.pad()) 51 | init_finish_seq = tf.zeros_like(init_seq) 52 | # [batch, beam] 53 | init_finish_scores = tf.fill([batch_size, beam_size], tfdtype.min) 54 | init_finish_flags = tf.zeros([batch_size, beam_size], tf.bool) 55 | 56 | def cache_init(prev_seq, state): 57 | # used to initialize some caches 58 | # this is because pre-compute these caches is to hard, 59 | # so let's use one dummy run to obtain them. 
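# [Editor's note: illustrative comment, not part of the original source.] The dummy
# run below pushes a single target position through decoding_fn so that every cache
# entry is created with its final shape and dtype: batch and beam dims are first
# merged (merge_neighbor_dims), one decoding step is run at time 0, the dims are
# split back out, and util.dict_update folds the fresh caches into the encoder state.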
60 | flat_prev_seqs = util.merge_neighbor_dims(prev_seq, axis=0) 61 | flat_prev_state = nest.map_structure( 62 | lambda x: util.merge_neighbor_dims(x, axis=0), 63 | state 64 | ) 65 | _, step_state = decoding_fn( 66 | flat_prev_seqs[:, -1:], flat_prev_state, 0) 67 | 68 | new_state = nest.map_structure( 69 | lambda x: util.unmerge_neighbor_dims(x, batch_size, axis=0), 70 | step_state 71 | ) 72 | new_state = util.dict_update(new_state, state) 73 | 74 | return new_state 75 | 76 | if params.search_mode == "cache": 77 | model_state = cache_init(init_seq, model_state) 78 | 79 | bsstate = BeamSearchState( 80 | inputs=(init_seq, init_log_probs, init_scores), 81 | state=model_state, 82 | finish=(init_finish_seq, init_finish_scores, init_finish_flags) 83 | ) 84 | 85 | def _not_finished(time, bsstate): 86 | # if the maximum time step is reached, or 87 | # all samples in one batch satisfy that the worst finished sequence 88 | # score is not better than the best alive sequence score 89 | alive_log_probs = bsstate.inputs[1] 90 | finish_scores = bsstate.finish[1] 91 | finish_flags = bsstate.finish[2] 92 | 93 | # upper bound of length penality 94 | max_length_penality = tf.pow( 95 | (5. + tf.cast(max_target_length, tfdtype)) / 6., alpha) 96 | best_alive_score = alive_log_probs[:, 0] / max_length_penality 97 | 98 | # minimum score among finished sequences alone 99 | worst_finish_score = tf.reduce_min( 100 | finish_scores * tf.cast(finish_flags, tfdtype), 1) 101 | # deal with unfinished instances, which is set to `tf.float32.min` 102 | unfinish_mask = 1. - tf.cast(tf.reduce_any(finish_flags, 1), tfdtype) 103 | worst_finish_score += unfinish_mask * tfdtype.min 104 | 105 | # boundary 106 | bound_is_met = tf.reduce_all(tf.greater(worst_finish_score, 107 | best_alive_score)) 108 | 109 | # length constraint 110 | length_is_met = tf.reduce_any( 111 | tf.less(time, tf.cast(max_target_length, tf.int32))) 112 | 113 | return tf.logical_and(tf.logical_not(bound_is_met), length_is_met) 114 | 115 | def _step_fn(time, bsstate): 116 | """one expansion step of beam search process""" 117 | 118 | # 1. feed previous predictions, and get the next probabilities 119 | # generating beam * vocab_size predictions 120 | prev_seq, prev_log_probs, prev_scores = bsstate.inputs 121 | 122 | flat_prev_seqs = util.merge_neighbor_dims(prev_seq, axis=0) 123 | flat_prev_state = nest.map_structure( 124 | lambda x: util.merge_neighbor_dims(x, axis=0), 125 | bsstate.state 126 | ) 127 | 128 | # curr_logits: [batch * beam, vocab_size] 129 | if params.search_mode == "cache": 130 | decode_target = flat_prev_seqs[:, -1:] 131 | else: 132 | # introducing `dev` mode into search function 133 | # this mainly is for model developing, because when developing new models 134 | # perhaps your new model is very complex, with complex internal dependencies 135 | # at this time, maintaining the cache state is rather boring and usually make 136 | # mistakes. To this end, I add the dev mode, that the model only uses 137 | # source sentence and partial target sentence at the cost of slower decoding. 138 | # Definitely disabled if you want higher decoding efficiency. 
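# [Editor's note: illustrative comment, not part of the original source.] In dev
# mode the whole running hypothesis is re-fed at every step: the tf.pad call below
# drops the leading pad/BOS column of flat_prev_seqs and appends one extra position
# filled with the constant 1, so decode_target always holds `time + 1` positions and
# no incremental cache needs to be maintained.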
139 | decode_target = tf.pad( 140 | flat_prev_seqs[:, 1:], [[0, 0], [0, 1]], constant_values=1) 141 | step_logits, step_state = decoding_fn( 142 | decode_target, flat_prev_state, time) 143 | # add gumbel noise into the logits, simulate gumbel top-k sampling without replacement 144 | if params.enable_noise_beam_search: 145 | step_logits += util.gumbel_noise(util.shape_list(step_logits)) 146 | # apply temperature decoding 147 | step_logits /= params.beam_search_temperature 148 | step_log_probs = util.log_prob_from_logits(step_logits) 149 | vocab_size = util.shape_list(step_log_probs)[-1] 150 | 151 | # force decoding 152 | eos_mask = tf.cast(tf.equal(tf.range(vocab_size), eos_id), tfdtype) 153 | step_log_probs = tf.cond(dtype.tf_to_float(time) < dtype.tf_to_float(1.), 154 | lambda: step_log_probs + tf.expand_dims(eos_mask, 0) * - dtype.inf(), 155 | lambda: step_log_probs) 156 | 157 | # expand to [batch, beam, vocab_size] 158 | step_log_probs = util.unmerge_neighbor_dims(step_log_probs, 159 | batch_size, axis=0) 160 | step_state = nest.map_structure( 161 | lambda x: util.unmerge_neighbor_dims(x, batch_size, axis=0), 162 | step_state 163 | ) 164 | 165 | # 2. compute top-k scored next predictions 166 | # reducing beam * vocab_size to 2 * beam 167 | # [batch, beam, 1] + [batch, beam, vocab_size] 168 | curr_log_probs = tf.expand_dims(prev_log_probs, 2) + step_log_probs 169 | length_penality = tf.pow((5.0 + tf.cast(time + 1, tfdtype)) / 6., alpha) 170 | curr_scores = curr_log_probs / length_penality 171 | 172 | # [batch, beam * vocab_size] 173 | curr_flat_scores = util.merge_neighbor_dims(curr_scores, axis=1) 174 | # [batch, 2 * beam] 175 | topk_scores, topk_indices = tf.nn.top_k( 176 | curr_flat_scores, 2 * beam_size) 177 | 178 | # index manipulation, [batch, 2 * beam] 179 | curr_beam_indices = topk_indices // vocab_size 180 | curr_symbol_indices = topk_indices % vocab_size 181 | beam2_pos = util.batch_coordinates(batch_size, 2 * beam_size) 182 | curr_coordinates = tf.stack([beam2_pos, curr_beam_indices], axis=2) 183 | 184 | # extract candidate sequences 185 | # [batch, 2 * beam, time + 1] 186 | curr_seq = tf.gather_nd(prev_seq, curr_coordinates) 187 | curr_seq = tf.concat([curr_seq, 188 | tf.expand_dims(curr_symbol_indices, 2)], 2) 189 | 190 | # 3. handling alive sequences 191 | # reducing 2 * beam to beam 192 | curr_fin_flags = tf.logical_or( 193 | tf.equal(curr_symbol_indices, eos_id), 194 | # if time step exceeds the maximum decoding length, should stop 195 | tf.expand_dims( 196 | tf.greater_equal(time, tf.cast(max_target_length, tf.int32)), 1) 197 | ) 198 | alive_scores = topk_scores + tf.cast(curr_fin_flags, tfdtype) * tfdtype.min 199 | # [batch, 2 * beam] -> [batch, beam] 200 | alive_scores, alive_indices = tf.nn.top_k(alive_scores, beam_size) 201 | beam_pos = util.batch_coordinates(batch_size, beam_size) 202 | alive_coordinates = tf.stack([beam_pos, alive_indices], axis=2) 203 | alive_seq = tf.gather_nd(curr_seq, alive_coordinates) 204 | alive_beam_indices = tf.gather_nd(curr_beam_indices, alive_coordinates) 205 | beam_coordinates = tf.stack([beam_pos, alive_beam_indices], axis=2) 206 | alive_state = nest.map_structure( 207 | lambda x: tf.gather_nd(x, beam_coordinates), 208 | step_state 209 | ) 210 | alive_log_probs = alive_scores * length_penality 211 | 212 | # 4. 
handle finished sequences 213 | # reducing 3 * beam to beam 214 | prev_fin_seq, prev_fin_scores, prev_fin_flags = bsstate.finish 215 | # [batch, 2 * beam] 216 | curr_fin_scores = topk_scores + (1.0 - tf.cast(curr_fin_flags, tfdtype)) * tfdtype.min 217 | # [batch, 3 * beam] 218 | fin_flags = tf.concat([prev_fin_flags, curr_fin_flags], axis=1) 219 | fin_scores = tf.concat([prev_fin_scores, curr_fin_scores], axis=1) 220 | # [batch, beam] 221 | fin_scores, fin_indices = tf.nn.top_k(fin_scores, beam_size) 222 | fin_coordinates = tf.stack([beam_pos, fin_indices], axis=2) 223 | fin_flags = tf.gather_nd(fin_flags, fin_coordinates) 224 | pad_seq = tf.fill([batch_size, beam_size, 1], 225 | tf.constant(pad_id, tf.int32)) 226 | prev_fin_seq = tf.concat([prev_fin_seq, pad_seq], axis=2) 227 | fin_seq = tf.concat([prev_fin_seq, curr_seq], axis=1) 228 | fin_seq = tf.gather_nd(fin_seq, fin_coordinates) 229 | 230 | next_state = BeamSearchState( 231 | inputs=(alive_seq, alive_log_probs, alive_scores), 232 | state=alive_state, 233 | finish=(fin_seq, fin_scores, fin_flags) 234 | ) 235 | 236 | return time + 1, next_state 237 | 238 | time = tf.constant(0, tf.int32, name="time") 239 | shape_invariants = BeamSearchState( 240 | inputs=(tf.TensorShape([None, None, None]), 241 | tf.TensorShape([None, None]), 242 | tf.TensorShape([None, None])), 243 | state=nest.map_structure( 244 | lambda x: util.get_shape_invariants(x), 245 | bsstate.state 246 | ), 247 | finish=(tf.TensorShape([None, None, None]), 248 | tf.TensorShape([None, None]), 249 | tf.TensorShape([None, None])) 250 | ) 251 | outputs = tf.while_loop(_not_finished, _step_fn, [time, bsstate], 252 | shape_invariants=[tf.TensorShape([]), 253 | shape_invariants], 254 | parallel_iterations=32, 255 | back_prop=False) 256 | final_state = outputs[1] 257 | 258 | alive_seqs = final_state.inputs[0] 259 | init_scores = final_state.inputs[2] 260 | final_seqs = final_state.finish[0] 261 | final_scores = final_state.finish[1] 262 | final_flags = final_state.finish[2] 263 | 264 | alive_seqs.set_shape([None, beam_size, None]) 265 | final_seqs.set_shape([None, beam_size, None]) 266 | 267 | final_seqs = tf.where(tf.reduce_any(final_flags, 1), final_seqs, 268 | alive_seqs) 269 | final_scores = tf.where(tf.reduce_any(final_flags, 1), final_scores, 270 | init_scores) 271 | 272 | return { 273 | 'seq': final_seqs[:, :, 1:], 274 | 'score': final_scores 275 | } 276 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | -------------------------------------------------------------------------------- /utils/cycle.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import tensorflow as tf 8 | from utils import dtype 9 | 10 | 11 | def _zero_variables(variables, name=None): 12 | ops = [] 13 | 14 | for var in variables: 15 | with tf.device(var.device): 16 | op = var.assign(tf.zeros_like(var)) 17 | ops.append(op) 18 | 19 | return tf.group(*ops, name=name or "zero_variables") 20 | 21 | 22 | def _replicate_variables(variables, device=None, suffix="Replica"): 23 | new_vars = [] 24 | 25 | for var in variables: 26 | device = device or var.device 27 | with tf.device(device): 28 | name = var.op.name + "/{}".format(suffix) 29 | 
new_vars.append(tf.Variable(tf.zeros_like(var), 30 | name=name, trainable=False)) 31 | 32 | return new_vars 33 | 34 | 35 | def _collect_gradients(gradients, variables): 36 | ops = [] 37 | 38 | for grad, var in zip(gradients, variables): 39 | if isinstance(grad, tf.Tensor): 40 | ops.append(tf.assign_add(var, grad)) 41 | else: 42 | ops.append(tf.scatter_add(var, grad.indices, grad.values)) 43 | 44 | return tf.group(*ops, name="collect_gradients") 45 | 46 | 47 | def create_train_op(named_scalars, grads_and_vars, optimizer, global_step, params): 48 | tf.get_variable_scope().set_dtype(tf.as_dtype(dtype.floatx())) 49 | 50 | gradients = [item[0] for item in grads_and_vars] 51 | variables = [item[1] for item in grads_and_vars] 52 | 53 | if params.update_cycle == 1: 54 | zero_variables_op = tf.no_op("zero_variables") 55 | collect_op = tf.no_op("collect_op") 56 | else: 57 | named_vars = {} 58 | for name in named_scalars: 59 | named_var = tf.Variable(tf.zeros([], dtype=tf.float32), 60 | name="{}/CTrainOpReplica".format(name), 61 | trainable=False) 62 | named_vars[name] = named_var 63 | count_var = tf.Variable(tf.zeros([], dtype=tf.as_dtype(dtype.floatx())), 64 | name="count/CTrainOpReplica", 65 | trainable=False) 66 | slot_variables = _replicate_variables(variables, suffix='CTrainOpReplica') 67 | zero_variables_op = _zero_variables( 68 | slot_variables + [count_var] + list(named_vars.values())) 69 | 70 | collect_ops = [] 71 | # collect gradients 72 | collect_grads_op = _collect_gradients(gradients, slot_variables) 73 | collect_ops.append(collect_grads_op) 74 | 75 | # collect other scalars 76 | for name in named_scalars: 77 | scalar = named_scalars[name] 78 | named_var = named_vars[name] 79 | collect_op = tf.assign_add(named_var, scalar) 80 | collect_ops.append(collect_op) 81 | # collect counting variable 82 | collect_count_op = tf.assign_add(count_var, 1.0) 83 | collect_ops.append(collect_count_op) 84 | 85 | collect_op = tf.group(*collect_ops, name="collect_op") 86 | scale = 1.0 / (tf.cast(count_var, tf.float32) + 1.0) 87 | gradients = [scale * (g + s) 88 | for (g, s) in zip(gradients, slot_variables)] 89 | 90 | for name in named_scalars: 91 | named_scalars[name] = scale * ( 92 | named_scalars[name] + named_vars[name]) 93 | 94 | grand_norm = tf.global_norm(gradients) 95 | param_norm = tf.global_norm(variables) 96 | 97 | # Gradient clipping 98 | if isinstance(params.clip_grad_norm or None, float): 99 | gradients, _ = tf.clip_by_global_norm(gradients, 100 | params.clip_grad_norm, 101 | use_norm=grand_norm) 102 | 103 | # Update variables 104 | grads_and_vars = list(zip(gradients, variables)) 105 | train_op = optimizer.apply_gradients(grads_and_vars, global_step) 106 | 107 | ops = { 108 | "zero_op": zero_variables_op, 109 | "collect_op": collect_op, 110 | "train_op": train_op 111 | } 112 | 113 | # apply ema 114 | if params.ema_decay > 0.: 115 | tf.logging.info('Using Exp Moving Average to train the model with decay {}.'.format(params.ema_decay)) 116 | ema = tf.train.ExponentialMovingAverage(decay=params.ema_decay, num_updates=global_step) 117 | ema_op = ema.apply(variables) 118 | with tf.control_dependencies([ops['train_op']]): 119 | ops['train_op'] = tf.group(ema_op) 120 | bck_vars = _replicate_variables(variables, suffix="CTrainOpBackUpReplica") 121 | 122 | ops['ema_backup_op'] = tf.group(*(tf.assign(bck, var.read_value()) 123 | for bck, var in zip(bck_vars, variables))) 124 | ops['ema_restore_op'] = tf.group(*(tf.assign(var, bck.read_value()) 125 | for bck, var in zip(bck_vars, variables))) 126 | 
ops['ema_assign_op'] = tf.group(*(tf.assign(var, ema.average(var).read_value()) 127 | for var in variables)) 128 | 129 | ret = named_scalars 130 | ret.update({ 131 | "gradient_norm": grand_norm, 132 | "parameter_norm": param_norm, 133 | }) 134 | 135 | return ret, ops 136 | -------------------------------------------------------------------------------- /utils/dtype.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | 10 | # Copied from Keras 11 | 12 | # the type of float to use throughout the session. 13 | _FLOATX = 'float32' 14 | _EPSILON = 1e-8 15 | _INF = 1e8 16 | 17 | 18 | def epsilon(): 19 | return _EPSILON 20 | 21 | 22 | def set_epsilon(e): 23 | global _EPSILON 24 | _EPSILON = e 25 | 26 | 27 | def inf(): 28 | return _INF 29 | 30 | 31 | def set_inf(e): 32 | global _INF 33 | _INF = e 34 | 35 | 36 | def floatx(): 37 | return _FLOATX 38 | 39 | 40 | def set_floatx(floatx): 41 | global _FLOATX 42 | if floatx not in {'float16', 'float32', 'float64'}: 43 | raise ValueError('Unknown floatx type: ' + str(floatx)) 44 | _FLOATX = str(floatx) 45 | 46 | 47 | def np_to_float(x): 48 | return np.asarray(x, dtype=_FLOATX) 49 | 50 | 51 | def tf_to_float(x): 52 | return tf.cast(x, tf.as_dtype(floatx())) 53 | 54 | 55 | def float32_variable_storage_getter(getter, name, shape=None, dtype=None, 56 | initializer=None, regularizer=None, 57 | trainable=True, 58 | *args, **kwargs): 59 | """Custom variable getter that forces trainable variables to be stored in 60 | float32 precision and then casts them to the training precision. 61 | """ 62 | storage_dtype = tf.float32 if trainable else dtype 63 | variable = getter(name, shape, dtype=storage_dtype, 64 | initializer=initializer, regularizer=regularizer, 65 | trainable=trainable, 66 | *args, **kwargs) 67 | if trainable and dtype != tf.float32: 68 | variable = tf.cast(variable, dtype) 69 | return variable 70 | -------------------------------------------------------------------------------- /utils/metric.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import os 8 | import math 9 | import sys 10 | import argparse 11 | from collections import defaultdict 12 | 13 | '''https://github.com/DeepLearnXMU/Otem-Utem''' 14 | 15 | 16 | def _get_refs(ref): 17 | """Get reference files, ref indicates the path, following the multi-bleu tradition.""" 18 | refs = [] 19 | 20 | # return the existed reference file, and assume only one reference 21 | if os.path.exists(ref): 22 | refs.append(ref) 23 | else: 24 | # the reference does not exist, check whether the indexed file exist, usually multiple references 25 | if not os.path.exists(ref + "0"): 26 | print('Error: could not find proper reference file ', ref + "0", file=sys.stderr) 27 | sys.exit(1) 28 | 29 | # enumerate all possible references 30 | while True: 31 | cur_file = ref + "%d" % len(refs) 32 | if not os.path.exists(cur_file): 33 | break 34 | refs.append(cur_file) 35 | return refs 36 | 37 | 38 | def _tokenize(s): 39 | """An interface for tokenization, currently we rely on external tokenizers 40 | i.e. 
We assume all the inputs have been well-tokenized 41 | """ 42 | return s.split() 43 | 44 | 45 | def _read(f, lc=False): 46 | """Reading all contents inside the file `f`, "lc" tells whether open the 'lower case' function.""" 47 | return [_tokenize(line.strip()) if not lc else _tokenize(line.strip().lower()) 48 | for line in open(f, 'rU').readlines()] 49 | 50 | 51 | def _get_ngram_list(sentence, ngram=4): 52 | """Read all ngrams inside the sentences, default up to 4.""" 53 | ngram_dict = defaultdict(int) 54 | for n in range(1, ngram + 1): 55 | for start in range(0, len(sentence) - (n - 1)): 56 | ngram_str = ' '.join(sentence[start:start + n]) 57 | ngram_dict[ngram_str] += 1 58 | return ngram_dict 59 | 60 | 61 | def _common_strategies(choices): 62 | """Generate some common strategies to deal with multiple references.""" 63 | return {'min': min(choices), 64 | 'max': max(choices), 65 | 'avg': sum(choices) * 1. / len(choices) 66 | } 67 | 68 | 69 | def _get_length_reference(ref_lengths, cand_length, strategy="best_match"): 70 | """When multiple references exist, return the length of a preferred references.""" 71 | 72 | # different strategy, no one is absolutely correct 73 | strategies = _common_strategies(ref_lengths) 74 | 75 | # the best matched cases 76 | length, diff = 9999, 9999 77 | for r in ref_lengths: 78 | d = abs(r - cand_length) 79 | 80 | if d < diff: 81 | length, diff = r, d 82 | elif d == diff: 83 | if r < length: 84 | length = r 85 | strategies['best_match'] = length 86 | 87 | return strategies[strategy] 88 | 89 | 90 | def _safe_log(d): 91 | """Deal with invalid inputs.""" 92 | if d <= 0: 93 | print("WARNING, a non-positive number is processed by log", file=sys.stderr) 94 | return -9999999999 95 | 96 | return math.log(d) 97 | 98 | 99 | def otem(cand, refs, bp='closest', smooth=False, n=2, weights=None): 100 | """Over-Translation Evaluation Metric, LOWER is BETTER""" 101 | len_c = 0 102 | len_ref = 0 103 | 104 | tngram_corpus, ongram_corpus = defaultdict(int), defaultdict(int) 105 | 106 | # scan all candidates in the corpus 107 | for candidate, references in zip(cand, refs): 108 | len_c += len(candidate) 109 | len_ref += _get_length_reference([len(r) for r in references], len(candidate), 110 | strategy='best_match' if bp == 'closest' else 'min') 111 | 112 | # get all n-grams in current candidate from n = 1...4 113 | cngrams = _get_ngram_list(candidate, ngram=n) 114 | 115 | tngram_sample, ongram_sample = defaultdict(int), defaultdict(int) 116 | 117 | for reference in references: 118 | rngrams = _get_ngram_list(reference, ngram=n) 119 | 120 | for ngram in cngrams: 121 | tngram_sample[ngram] = cngrams[ngram] 122 | 123 | ngram_otem = 0 124 | 125 | # case 1: current n-gram doesn't appear in current reference at all, 126 | # but appears in current candidate more than once 127 | if ngram not in rngrams: 128 | if cngrams[ngram] > 1: 129 | ngram_otem = cngrams[ngram] - 1 130 | elif cngrams[ngram] > rngrams[ngram]: 131 | # case 2: the n-gram occurs in both reference and candidate, but the occurrence is more in candidate 132 | ngram_otem = cngrams[ngram] - rngrams[ngram] 133 | 134 | if ngram_otem > 0: 135 | if ongram_sample[ngram] == 0: 136 | ongram_sample[ngram] = ngram_otem 137 | else: 138 | ongram_sample[ngram] = min(ongram_sample[ngram], ngram_otem) 139 | 140 | for ngram in cngrams: 141 | nl = len(ngram.split()) 142 | tngram_corpus[nl] += tngram_sample[ngram] 143 | ongram_corpus[nl] += ongram_sample[ngram] 144 | 145 | if len_ref == 0: 146 | return 0. 147 | 148 | lp = 1. 
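# [Editor's note: illustrative comment, not part of the original source.] The score
# assembled below is the length penalty `lp` times a (weighted) geometric mean of the
# per-order ratios ongram_corpus[i] / tngram_corpus[i], i.e. roughly the fraction of
# candidate n-grams occurring more often than the references allow; since lower OTEM
# is better, lp grows above 1 when the candidate is longer than the reference length.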
149 | multi_otem = defaultdict(int) 150 | 151 | for i in range(1, n + 1): 152 | if i in tngram_corpus: 153 | if smooth and i > 1: 154 | ongram_corpus[i] += 1 155 | tngram_corpus[i] += 1 156 | multi_otem[i] += ongram_corpus[i] * 1. / tngram_corpus[i] 157 | 158 | # Over-translation: candidate prefered to be longer, so penalize long translations 159 | if len_c >= len_ref: 160 | lp = math.exp(1. - len_ref * 1. / len_c) 161 | 162 | if weights is None: 163 | weights = [1. / n for _ in range(n)] 164 | assert len(weights) == n, 'ERROR: the length of weights ({}) should be equal to n ({})'.format(len(weights), n) 165 | 166 | score = lp * math.exp(sum(_safe_log(multi_otem[i+1]) * weights[i] for i in range(n))) 167 | 168 | return score 169 | 170 | 171 | def utem(cand, refs, bp='closest', smooth=False, n=4, weights=None): 172 | """Under-Translation Evaluation Metric, LOWER is BETTER""" 173 | len_c = 0 174 | len_ref = 0 175 | 176 | tngram_corpus, mngram_corpus = defaultdict(int), defaultdict(int) 177 | 178 | # scan all candidates in the corpus 179 | for candidate, references in zip(cand, refs): 180 | len_c += len(candidate) 181 | len_ref += _get_length_reference([len(r) for r in references], len(candidate), 182 | strategy='best_match' if bp == 'closest' else 'min') 183 | 184 | # get all n-grams in current candidate from n = 1...4 185 | cngrams = _get_ngram_list(candidate, ngram=n) 186 | 187 | tngram_sample, mngram_sample = defaultdict(list), defaultdict(list) 188 | 189 | for reference in references: 190 | rngrams = _get_ngram_list(reference, ngram=n) 191 | 192 | tngram_ref, mngram_ref = defaultdict(int), defaultdict(int) 193 | 194 | # count the number of under-translation n-grams in current candidate compared with current reference 195 | for ngram in rngrams: 196 | nl = len(ngram.split()) 197 | 198 | tngram_ref[nl] += rngrams[ngram] 199 | 200 | # case 1: current n-gram doesn't appear in the candidate at all 201 | if ngram not in cngrams: 202 | mngram_ref[nl] += rngrams[ngram] 203 | elif rngrams[ngram] > cngrams[ngram]: 204 | # case 2: the n-gram occurs in both reference and candidate, but the occurrence is more in reference 205 | mngram_ref[nl] += rngrams[ngram] - cngrams[ngram] 206 | 207 | for i in tngram_ref: 208 | tngram_sample[i].append(tngram_ref[i]) 209 | mngram_sample[i].append(mngram_ref[i]) 210 | 211 | for i in tngram_sample: 212 | m = _common_strategies(mngram_sample[i])['min'] 213 | t = _common_strategies(tngram_sample[i])['max'] 214 | 215 | mngram_corpus[i] += m 216 | tngram_corpus[i] += t 217 | 218 | if len_ref == 0: 219 | return 0. 220 | 221 | lp = 1. 222 | multi_utem = defaultdict(int) 223 | for i in range(1, n + 1): 224 | if i in tngram_corpus: 225 | if smooth and i > 1: 226 | mngram_corpus[i] += 1 227 | tngram_corpus[i] += 1 228 | multi_utem[i] += mngram_corpus[i] * 1. / tngram_corpus[i] 229 | 230 | # Under-translation: candidates perfered to be shorter, so penalize short translations 231 | if len_c <= len_ref: 232 | lp = math.exp(1. - len_c * 1. / len_ref) 233 | 234 | if weights is None: 235 | weights = [1. 
/ n for _ in range(n)] 236 | assert len(weights) == n, 'ERROR: the length of weights ({}) should be equal to n ({})'.format(len(weights), n) 237 | 238 | score = lp * math.exp(sum(_safe_log(multi_utem[i+1]) * weights[i] for i in range(n))) 239 | 240 | return score 241 | 242 | 243 | def bleu(cand, refs, bp='closest', smooth=False, n=4, weights=None): 244 | """BLEU Evaluation Metric, LARGER is BETTER""" 245 | len_c = 0 246 | len_ref = 0 247 | 248 | tngram_corpus, bngram_corpus = defaultdict(int), defaultdict(int) 249 | 250 | # scan all candidates in the corpus 251 | for candidate, references in zip(cand, refs): 252 | len_c += len(candidate) 253 | len_ref += _get_length_reference([len(r) for r in references], len(candidate), 254 | strategy='best_match' if bp == 'closest' else 'min') 255 | 256 | # get all n-grams in current candidate from n = 1...4 257 | cngrams = _get_ngram_list(candidate, ngram=n) 258 | 259 | tngram_sample, bngram_sample = defaultdict(int), defaultdict(int) 260 | 261 | for reference in references: 262 | rngrams = _get_ngram_list(reference, ngram=n) 263 | 264 | for ngram in cngrams: 265 | tngram_sample[ngram] = cngrams[ngram] 266 | if ngram in rngrams: 267 | bngram_sample[ngram] = max(bngram_sample[ngram], min(rngrams[ngram], cngrams[ngram])) 268 | 269 | for ngram in cngrams: 270 | nl = len(ngram.split()) 271 | tngram_corpus[nl] += tngram_sample[ngram] 272 | bngram_corpus[nl] += bngram_sample[ngram] 273 | 274 | if len_ref == 0: 275 | return 0. 276 | 277 | lp = 1. 278 | multi_bleu = defaultdict(int) 279 | 280 | for i in range(1, n + 1): 281 | if i in tngram_corpus: 282 | if smooth and i > 1: 283 | bngram_corpus[i] += 1 284 | tngram_corpus[i] += 1 285 | multi_bleu[i] += bngram_corpus[i] * 1. / tngram_corpus[i] 286 | 287 | # BLEU: candidate prefered to be longer, so penalize long translations 288 | if len_c <= len_ref: 289 | lp = math.exp(1. - len_ref * 1. / len_c) 290 | 291 | if weights is None: 292 | weights = [1. / n for _ in range(n)] 293 | assert len(weights) == n, 'ERROR: the length of weights ({}) should be equal to n ({})'.format(len(weights), n) 294 | 295 | score = lp * math.exp(sum(_safe_log(multi_bleu[i+1]) * weights[i] for i in range(n))) 296 | 297 | return score 298 | 299 | 300 | if __name__ == "__main__": 301 | parser = argparse.ArgumentParser( 302 | description='Over-translation evaluation metric (OTEM), under-translation evaluation metric (UTEM), ' 303 | 'BLEU on multiple references.') 304 | parser.add_argument('-lc', help='Lowercase, i.e case-insensitive setting', action='store_true') 305 | parser.add_argument('-bp', help='Length penalty', default='closest', choices=['shortest', 'closest']) 306 | parser.add_argument('candidate', help='The candidate translation generated by MT system') 307 | parser.add_argument('reference', help='The references like reference or reference0, reference1, ...') 308 | 309 | args = parser.parse_args() 310 | 311 | cand = args.candidate 312 | refs = _get_refs(args.reference) 313 | 314 | cand_sentences = _read(cand, args.lc) 315 | refs_sentences = [_read(ref, args.lc) for ref in refs] 316 | 317 | assert len(cand_sentences) == len(refs_sentences[0]), \ 318 | 'ERROR: the length of candidate and reference must be the same.' 
319 | 320 | refs_sentences = list(zip(*refs_sentences)) 321 | 322 | otem_score = otem(cand_sentences, refs_sentences, n=2) # OTEM-2 323 | utem_score = utem(cand_sentences, refs_sentences, n=4) # UTEM-4 324 | bleu_score = bleu(cand_sentences, refs_sentences, n=4) # BLEU-4 325 | 326 | print('OTEM-2/UTEM-4/BLEU-4: {}/{}/{}'.format(otem_score, utem_score, bleu_score)) 327 | -------------------------------------------------------------------------------- /utils/parallel.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import six 8 | import tensorflow as tf 9 | import tensorflow.contrib as tc 10 | 11 | from tensorflow.python.training import device_setter 12 | from tensorflow.python.framework import device as pydev 13 | from tensorflow.core.framework import node_def_pb2 14 | 15 | from utils import dtype 16 | 17 | 18 | def local_device_setter(num_devices=1, 19 | ps_device_type='cpu', 20 | worker_device='/cpu:0', 21 | ps_ops=None, 22 | ps_strategy=None): 23 | if ps_ops is None: 24 | ps_ops = ['Variable', 'VariableV2', 'VarHandleOp'] 25 | 26 | if ps_strategy is None: 27 | ps_strategy = device_setter._RoundRobinStrategy(num_devices) 28 | if not six.callable(ps_strategy): 29 | raise TypeError("ps_strategy must be callable") 30 | 31 | def _local_device_chooser(op): 32 | current_device = pydev.DeviceSpec.from_string(op.device or "") 33 | 34 | node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def 35 | if node_def.op in ps_ops: 36 | ps_device_spec = pydev.DeviceSpec.from_string( 37 | '/{}:{}'.format(ps_device_type, ps_strategy(op))) 38 | 39 | ps_device_spec.merge_from(current_device) 40 | return ps_device_spec.to_string() 41 | else: 42 | worker_device_spec = pydev.DeviceSpec.from_string(worker_device or "") 43 | worker_device_spec.merge_from(current_device) 44 | return worker_device_spec.to_string() 45 | 46 | return _local_device_chooser 47 | 48 | 49 | def _maybe_repeat(x, n): 50 | if isinstance(x, list): 51 | assert len(x) == n 52 | return x 53 | else: 54 | return [x] * n 55 | 56 | 57 | def _reshape_output(outputs): 58 | # assumption: or outputs[0] are all tensor lists/tuples, 59 | # or outputs[0] are dictionaries 60 | if isinstance(outputs[0], (tuple, list)): 61 | outputs = list(zip(*outputs)) 62 | outputs = tuple([list(o) for o in outputs]) 63 | else: 64 | if not isinstance(outputs[0], dict): 65 | return outputs 66 | 67 | assert isinstance(outputs[0], dict), \ 68 | 'invalid data type %s' % type(outputs[0]) 69 | 70 | combine_outputs = {} 71 | for key in outputs[0]: 72 | combine_outputs[key] = [o[key] for o in outputs] 73 | outputs = combine_outputs 74 | 75 | return outputs 76 | 77 | 78 | # Data-level parallelism 79 | def data_parallelism(device_type, num_devices, fn, *args, **kwargs): 80 | # Replicate args and kwargs 81 | if args: 82 | new_args = [_maybe_repeat(arg, num_devices) for arg in args] 83 | # Transpose 84 | new_args = [list(x) for x in zip(*new_args)] 85 | else: 86 | new_args = [[] for _ in range(num_devices)] 87 | 88 | new_kwargs = [{} for _ in range(num_devices)] 89 | 90 | for k, v in kwargs.items(): 91 | vals = _maybe_repeat(v, num_devices) 92 | 93 | for i in range(num_devices): 94 | new_kwargs[i][k] = vals[i] 95 | 96 | fns = _maybe_repeat(fn, num_devices) 97 | 98 | # Now make the parallel call. 
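# [Editor's note: illustrative comment, not part of the original source.] At this
# point new_args[i] / new_kwargs[i] hold tower i's inputs: scalar arguments were
# broadcast to every device, while list arguments were required to have exactly
# num_devices entries. The loop below places each call on "/gpu:i" (or the single
# cpu device) and sets reuse=bool(i != 0) so all towers share one set of variables.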
99 | outputs = [] 100 | for i in range(num_devices): 101 | worker = "/{}:{}".format(device_type, i) 102 | if device_type == 'cpu': 103 | _device_setter = local_device_setter(worker_device=worker) 104 | else: 105 | _device_setter = local_device_setter( 106 | ps_device_type='gpu', 107 | worker_device=worker, 108 | ps_strategy=tc.training.GreedyLoadBalancingStrategy( 109 | num_devices, tc.training.byte_size_load_fn) 110 | ) 111 | 112 | with tf.variable_scope(tf.get_variable_scope(), reuse=bool(i != 0), 113 | dtype=tf.as_dtype(dtype.floatx())): 114 | with tf.name_scope("tower_%d" % i): 115 | with tf.device(_device_setter): 116 | outputs.append(fns[i](*new_args[i], **new_kwargs[i])) 117 | 118 | return _reshape_output(outputs) 119 | 120 | 121 | def parallel_model(model_fn, features, devices, use_cpu=False): 122 | device_type = 'gpu' 123 | num_devices = len(devices) 124 | 125 | if use_cpu: 126 | device_type = 'cpu' 127 | num_devices = 1 128 | 129 | outputs = data_parallelism(device_type, num_devices, model_fn, features) 130 | 131 | return outputs 132 | 133 | 134 | def average_gradients(tower_grads, mask=None): 135 | """Modified from Bilm""" 136 | 137 | # optimizer for single device 138 | if len(tower_grads) == 1: 139 | return tower_grads[0] 140 | 141 | # calculate average gradient for each shared variable across all GPUs 142 | def _deduplicate_indexed_slices(values, indices): 143 | """Sums `values` associated with any non-unique `indices`.""" 144 | unique_indices, new_index_positions = tf.unique(indices) 145 | summed_values = tf.unsorted_segment_sum( 146 | values, new_index_positions, 147 | tf.shape(unique_indices)[0]) 148 | return summed_values, unique_indices 149 | 150 | average_grads = [] 151 | for grad_and_vars in zip(*tower_grads): 152 | # Note that each grad_and_vars looks like the following: 153 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) 154 | # We need to average the gradients across each GPU. 155 | 156 | g0, v0 = grad_and_vars[0] 157 | 158 | if g0 is None: 159 | # no gradient for this variable, skip it 160 | tf.logging.warn("{} has no gradient".format(v0.name)) 161 | average_grads.append((g0, v0)) 162 | continue 163 | 164 | if isinstance(g0, tf.IndexedSlices): 165 | # If the gradient is type IndexedSlices then this is a sparse 166 | # gradient with attributes indices and values. 167 | # To average, need to concat them individually then create 168 | # a new IndexedSlices object. 169 | indices = [] 170 | values = [] 171 | for g, v in grad_and_vars: 172 | indices.append(g.indices) 173 | values.append(g.values) 174 | all_indices = tf.concat(indices, 0) 175 | if mask is None: 176 | avg_values = tf.concat(values, 0) / len(grad_and_vars) 177 | else: 178 | avg_values = tf.concat(values, 0) / tf.reduce_sum(mask) 179 | # deduplicate across indices 180 | av, ai = _deduplicate_indexed_slices(avg_values, all_indices) 181 | grad = tf.IndexedSlices(av, ai, dense_shape=g0.dense_shape) 182 | else: 183 | # a normal tensor can just do a simple average 184 | grads = [] 185 | for g, v in grad_and_vars: 186 | # Add 0 dimension to the gradients to represent the tower. 187 | expanded_g = tf.expand_dims(g, 0) 188 | # Append on a 'tower' dimension which we will average over 189 | grads.append(expanded_g) 190 | 191 | # Average over the 'tower' dimension. 
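# [Editor's note: illustrative comment, not part of the original source.] For dense
# gradients the per-tower tensors stacked above are concatenated along the new tower
# axis and mean-reduced; when a `mask` is given, towers whose mask entry is 0 are
# first dropped with tf.boolean_mask so inactive replicas do not dilute the average.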
192 | grad = tf.concat(grads, 0) 193 | if mask is not None: 194 | grad = tf.boolean_mask( 195 | grad, tf.cast(mask, tf.bool), axis=0) 196 | grad = tf.reduce_mean(grad, 0) 197 | 198 | # the Variables are redundant because they are shared 199 | # across towers. So.. just return the first tower's pointer to 200 | # the Variable. 201 | v = grad_and_vars[0][1] 202 | grad_and_var = (grad, v) 203 | 204 | average_grads.append(grad_and_var) 205 | 206 | assert len(average_grads) == len(list(zip(*tower_grads))) 207 | 208 | return average_grads 209 | -------------------------------------------------------------------------------- /utils/queuer.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | The Queue function mainly deals with reading and preparing dataset in a multi-processing manner. 5 | We didnot use the built-in tensorflow function Dataset because it lacks of flexibility. 6 | The function defined below is mainly inspired by https://github.com/ixlan/machine-learning-data-pipeline. 7 | """ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import division 11 | from __future__ import print_function 12 | 13 | from multiprocessing import Process, Queue 14 | # from threading import Thread as Process 15 | # from queue import Queue 16 | 17 | TERMINATION_TOKEN = "" 18 | 19 | 20 | def create_iter_from_queue(queue, term_token): 21 | 22 | while True: 23 | input_data_chunk = queue.get() 24 | if input_data_chunk == term_token: 25 | # put it back to the queue to let other processes that feed 26 | # from the same one to know that they should also break 27 | queue.put(term_token) 28 | break 29 | else: 30 | yield input_data_chunk 31 | 32 | 33 | def combine_reader_to_processor(reader, preprocessor): 34 | for data_chunk in reader: 35 | yield preprocessor(data_chunk) 36 | 37 | 38 | class EnQueuer(object): 39 | def __init__(self, 40 | reader, 41 | preprocessor, 42 | worker_processes_num=1, 43 | input_queue_size=5, 44 | output_queue_size=5 45 | ): 46 | if worker_processes_num < 0: 47 | raise ValueError("worker_processes_num must be a " 48 | "non-negative integer.") 49 | 50 | self.worker_processes_number = worker_processes_num 51 | self.preprocessor = preprocessor 52 | self.input_queue_size = input_queue_size 53 | self.output_queue_size = output_queue_size 54 | self.reader = reader 55 | 56 | # make the queue iterable 57 | def __iter__(self): 58 | return self._create_processed_data_chunks_gen(self.reader) 59 | 60 | def _create_processed_data_chunks_gen(self, reader_gen): 61 | if self.worker_processes_number == 0: 62 | itr = self._create_single_process_gen(reader_gen) 63 | else: 64 | itr = self._create_multi_process_gen(reader_gen) 65 | return itr 66 | 67 | def _create_single_process_gen(self, data_producer): 68 | return combine_reader_to_processor(data_producer, self.preprocessor) 69 | 70 | def _create_multi_process_gen(self, reader_gen): 71 | term_tokens_received = 0 72 | output_queue = Queue(self.output_queue_size) 73 | workers = [] 74 | 75 | if self.worker_processes_number > 1: 76 | term_tokens_expected = self.worker_processes_number - 1 77 | input_queue = Queue(self.input_queue_size) 78 | reader_worker = _ParallelWorker(reader_gen, input_queue) 79 | workers.append(reader_worker) 80 | 81 | # adding workers that will process the data 82 | for _ in range(self.worker_processes_number - 1): 83 | # since data-chunks will appear in the queue, making an iterable 84 | # object over it 85 | queue_iter = create_iter_from_queue(input_queue, 
86 | TERMINATION_TOKEN) 87 | 88 | data_itr = combine_reader_to_processor(queue_iter, self.preprocessor) 89 | proc_worker = _ParallelWorker(data_chunk_iter=data_itr, 90 | queue=output_queue) 91 | workers.append(proc_worker) 92 | else: 93 | term_tokens_expected = 1 94 | 95 | data_itr = combine_reader_to_processor(reader_gen, self.preprocessor) 96 | proc_worker = _ParallelWorker(data_chunk_iter=data_itr, 97 | queue=output_queue) 98 | workers.append(proc_worker) 99 | 100 | for pr in workers: 101 | pr.daemon = True 102 | pr.start() 103 | 104 | while True: 105 | data_chunk = output_queue.get() 106 | if data_chunk == TERMINATION_TOKEN: 107 | term_tokens_received += 1 108 | # need to received all tokens in order to be sure that 109 | # all data has been processed 110 | if term_tokens_received == term_tokens_expected: 111 | for pr in workers: 112 | pr.join() 113 | break 114 | continue 115 | yield data_chunk 116 | 117 | 118 | class _ParallelWorker(Process): 119 | """Worker to execute data reading or processing on a separate process.""" 120 | 121 | def __init__(self, data_chunk_iter, queue): 122 | super(_ParallelWorker, self).__init__() 123 | self._data_chunk_iterable = data_chunk_iter 124 | self._queue = queue 125 | 126 | def run(self): 127 | for data_chunk in self._data_chunk_iterable: 128 | self._queue.put(data_chunk) 129 | self._queue.put(TERMINATION_TOKEN) 130 | -------------------------------------------------------------------------------- /utils/recorder.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import json 8 | import tensorflow as tf 9 | 10 | 11 | class Recorder(object): 12 | """To save training processes""" 13 | 14 | def load_from_json(self, file_name): 15 | tf.logging.info("Loading recoder file from {}".format(file_name)) 16 | with open(file_name, 'r', encoding='utf-8') as fh: 17 | self.__dict__.update(json.load(fh)) 18 | 19 | def save_to_json(self, file_name): 20 | tf.logging.info("Saving recorder file into {}".format(file_name)) 21 | with open(file_name, 'w', encoding='utf-8') as fh: 22 | json.dump(self.__dict__, fh, indent=2) 23 | -------------------------------------------------------------------------------- /utils/saver.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import os 8 | import tensorflow as tf 9 | 10 | 11 | class Saver(object): 12 | def __init__(self, 13 | checkpoints=5, # save the latest number of checkpoints 14 | output_dir=None, # the output directory 15 | best_score=-1, # the best bleu score before 16 | best_checkpoints=1, # the best checkpoints saved in best checkpoints directory 17 | ): 18 | if output_dir is None: 19 | output_dir = "./output" 20 | self.output_dir = output_dir 21 | self.output_best_dir = os.path.join(output_dir, "best") 22 | 23 | self.saver = tf.train.Saver( 24 | max_to_keep=checkpoints 25 | ) 26 | # handle disrupted checkpoints 27 | if tf.gfile.Exists(self.output_dir): 28 | ckpt = tf.train.get_checkpoint_state(self.output_dir) 29 | if ckpt and ckpt.all_model_checkpoint_paths: 30 | self.saver.recover_last_checkpoints(list(ckpt.all_model_checkpoint_paths)) 31 | 32 | self.best_saver = tf.train.Saver( 33 | max_to_keep=best_checkpoints, 34 | ) 35 | # handle disrupted 
checkpoints 36 | if tf.gfile.Exists(self.output_best_dir): 37 | ckpt = tf.train.get_checkpoint_state(self.output_best_dir) 38 | if ckpt and ckpt.all_model_checkpoint_paths: 39 | self.best_saver.recover_last_checkpoints(list(ckpt.all_model_checkpoint_paths)) 40 | 41 | self.best_score = best_score 42 | # check best bleu result 43 | metric_dir = os.path.join(self.output_best_dir, "metric.log") 44 | if tf.gfile.Exists(metric_dir): 45 | metric_lines = open(metric_dir).readlines() 46 | if len(metric_lines) > 0: 47 | best_score_line = metric_lines[-1] 48 | self.best_score = float(best_score_line.strip().split()[-1]) 49 | 50 | # check the top_k_best list and results 51 | self.topk_scores = [] 52 | topk_dir = os.path.join(self.output_best_dir, "topk_checkpoint") 53 | ckpt_dir = os.path.join(self.output_best_dir, "checkpoint") 54 | # direct load the topk information from topk_checkpoints 55 | if tf.gfile.Exists(topk_dir): 56 | with tf.gfile.Open(topk_dir) as reader: 57 | for line in reader: 58 | model_name, score = line.strip().split("\t") 59 | self.topk_scores.append((model_name, float(score))) 60 | # backup plan to normal checkpoints and best scores 61 | elif tf.gfile.Exists(ckpt_dir): 62 | latest_checkpoint = tf.gfile.Open(ckpt_dir).readline() 63 | model_name = latest_checkpoint.strip().split(":")[1].strip() 64 | model_name = model_name[1:-1] # remove "" 65 | self.topk_scores.append((model_name, self.best_score)) 66 | self.best_checkpoints = best_checkpoints 67 | 68 | self.score_record = tf.gfile.Open(metric_dir, mode="a+") 69 | 70 | def save(self, session, step, metric_score=None): 71 | if not tf.gfile.Exists(self.output_dir): 72 | tf.gfile.MkDir(self.output_dir) 73 | if not tf.gfile.Exists(self.output_best_dir): 74 | tf.gfile.MkDir(self.output_best_dir) 75 | 76 | self.saver.save(session, os.path.join(self.output_dir, "model"), global_step=step) 77 | 78 | def _move(path, new_path): 79 | if tf.gfile.Exists(path): 80 | if tf.gfile.Exists(new_path): 81 | tf.gfile.Remove(new_path) 82 | tf.gfile.Copy(path, new_path) 83 | 84 | if metric_score is not None and metric_score > self.best_score: 85 | self.best_score = metric_score 86 | 87 | _move(os.path.join(self.output_dir, "param.json"), 88 | os.path.join(self.output_best_dir, "param.json")) 89 | _move(os.path.join(self.output_dir, "record.json"), 90 | os.path.join(self.output_best_dir, "record.json")) 91 | 92 | # this recorder only record best scores 93 | self.score_record.write("Steps {}, Metric Score {}\n".format(step, metric_score)) 94 | self.score_record.flush() 95 | 96 | # either no model is saved, or current metric score is better than the minimum one 97 | if metric_score is not None and \ 98 | (len(self.topk_scores) == 0 or len(self.topk_scores) < self.best_checkpoints or 99 | metric_score > min([v[1] for v in self.topk_scores])): 100 | # manipulate the 'checkpoints', and change the orders 101 | ckpt_dir = os.path.join(self.output_best_dir, "checkpoint") 102 | if len(self.topk_scores) > 0: 103 | sorted_topk_scores = sorted(self.topk_scores, key=lambda x: x[1]) 104 | with tf.gfile.Open(ckpt_dir, mode='w') as writer: 105 | best_ckpt = sorted_topk_scores[-1] 106 | writer.write("model_checkpoint_path: \"{}\"\n".format(best_ckpt[0])) 107 | for model_name, _ in sorted_topk_scores: 108 | writer.write("all_model_checkpoint_paths: \"{}\"\n".format(model_name)) 109 | writer.flush() 110 | 111 | # update best_saver internal checkpoints status 112 | ckpt = tf.train.get_checkpoint_state(self.output_best_dir) 113 | if ckpt and ckpt.all_model_checkpoint_paths: 
114 | self.best_saver.recover_last_checkpoints(list(ckpt.all_model_checkpoint_paths)) 115 | 116 | # this change is mainly motivated by the fact that, for some datasets, 117 | # the best performance is achieved by averaging the top-k checkpoints 118 | self.best_saver.save( 119 | session, os.path.join(self.output_best_dir, "model"), global_step=step) 120 | 121 | # handle topk scores 122 | self.topk_scores.append(("model-{}".format(int(step)), float(metric_score))) 123 | sorted_topk_scores = sorted(self.topk_scores, key=lambda x: x[1]) 124 | self.topk_scores = sorted_topk_scores[-self.best_checkpoints:] 125 | topk_dir = os.path.join(self.output_best_dir, "topk_checkpoint") 126 | with tf.gfile.Open(topk_dir, mode='w') as writer: 127 | for model_name, score in self.topk_scores: 128 | writer.write("{}\t{}\n".format(model_name, score)) 129 | writer.flush() 130 | 131 | def restore(self, session, path=None, filter_variables=False): 132 | if path is not None and tf.gfile.Exists(path): 133 | check_dir = path 134 | else: 135 | check_dir = self.output_dir 136 | 137 | checkpoint = os.path.join(check_dir, "checkpoint") 138 | if not tf.gfile.Exists(checkpoint): 139 | tf.logging.warn("No existing model detected") 140 | else: 141 | latest_checkpoint = tf.gfile.Open(checkpoint).readline() 142 | model_name = latest_checkpoint.strip().split(":")[1].strip() 143 | model_name = model_name[1:-1] # remove "" 144 | model_path = os.path.join(check_dir, model_name) 145 | model_path = os.path.abspath(model_path) 146 | if not tf.gfile.Exists(model_path+".meta"): 147 | tf.logging.error("model '{}' does not exist" 148 | .format(model_path)) 149 | else: 150 | try: 151 | if path is not None: 152 | raise Exception('bypassing') 153 | self.saver.restore(session, model_path) 154 | except : 155 | # In this case, we simply assume that the cycle part 156 | # is mismatched, where the replicas are missing. 157 | # This would happen if you switch from un-cycle mode 158 | # to cycle mode.
159 | tf.logging.warn("Starting Backup Restore") 160 | ops = [] 161 | reader = tf.train.load_checkpoint(model_path) 162 | for var in tf.global_variables(): 163 | name = var.op.name 164 | 165 | if (not filter_variables and reader.has_tensor(name)) or ( 166 | filter_variables and reader.has_tensor(name) and 'decoder' not in name and 167 | 'global_step' not in name and 'Adam' not in name and 168 | ('embedding' not in name or 'pos_embedding' in name) 169 | ): 170 | tf.logging.info('{} gets initialization from {}' 171 | .format(name, name)) 172 | ops.append( 173 | tf.assign(var, reader.get_tensor(name))) 174 | else: 175 | if 'global_step' in name and path is not None: 176 | ops.append(tf.assign(var, 0)) 177 | tf.logging.warn("{} is missing".format(name)) 178 | restore_op = tf.group(*ops, name="restore_global_vars") 179 | session.run(restore_op) 180 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import os 8 | import time 9 | import pkgutil 10 | import collections 11 | import numpy as np 12 | import tensorflow as tf 13 | 14 | from utils import dtype 15 | 16 | 17 | def batch_indexer(datasize, batch_size): 18 | """Divide the data indices into batches of batch_size""" 19 | dataindex = np.arange(datasize).tolist() 20 | 21 | batchindex = [] 22 | for i in range(datasize // batch_size): 23 | batchindex.append(dataindex[i * batch_size: (i + 1) * batch_size]) 24 | if datasize % batch_size > 0: 25 | batchindex.append(dataindex[-(datasize % batch_size):]) 26 | 27 | return batchindex 28 | 29 | 30 | def token_indexer(dataset, token_size): 31 | """Divide the dataset into token-based batches""" 32 | # assume dataset format: [(len1, len2, ..., lenN)] 33 | dataindex = np.arange(len(dataset)).tolist() 34 | 35 | batchindex = [] 36 | 37 | _batcher = [0.] * len(dataset[0]) 38 | _counter = 0 39 | i = 0 40 | while True: 41 | if i >= len(dataset): break 42 | 43 | # attempt to put this datapoint into the batch 44 | _batcher = [max(max_l, l) 45 | for max_l, l in zip(_batcher, dataset[i])] 46 | _counter += 1 47 | for l in _batcher: 48 | if _counter * l >= token_size: 49 | # when an extreme instance occurs, handle it by making a 1-size batch 50 | if _counter > 1: 51 | batchindex.append(dataindex[i - _counter + 1: i]) 52 | i -= 1 53 | else: 54 | batchindex.append(dataindex[i: i + 1]) 55 | 56 | _counter = 0 57 | _batcher = [0.] * len(dataset[0]) 58 | break 59 | 60 | i += 1 61 | 62 | _counter = sum([len(slice) for slice in batchindex]) 63 | if _counter != len(dataset): 64 | batchindex.append(dataindex[_counter:]) 65 | return batchindex 66 | 67 | 68 | def mask_scale(value, mask, scale=None): 69 | """Prepared for masked softmax""" 70 | if scale is None: 71 | scale = dtype.inf() 72 | return value + (1. - mask) * (-scale) 73 | 74 | 75 | def valid_apply_dropout(x, dropout): 76 | """Check whether the dropout value is valid, and apply dropout if so""" 77 | if dropout is not None and 0. <= dropout <= 1.: 78 | return tf.nn.dropout(x, 1. - dropout) 79 | return x 80 | 81 | 82 | def label_smooth(labels, vocab_size, factor=0.1): 83 | """Smooth the gold label distribution""" 84 | if 0. < factor < 1.: 85 | n = tf.cast(vocab_size - 1, tf.float32) 86 | p = 1.
- factor 87 | q = factor / n 88 | 89 | t = tf.one_hot(tf.cast(tf.reshape(labels, [-1]), tf.int32), 90 | depth=vocab_size, on_value=p, off_value=q) 91 | normalizing = -(p * tf.log(p) + n * q * tf.log(q + 1e-20)) 92 | else: 93 | t = tf.one_hot(tf.cast(tf.reshape(labels, [-1]), tf.int32), 94 | depth=vocab_size) 95 | normalizing = 0. 96 | 97 | return t, normalizing 98 | 99 | 100 | def closing_dropout(params): 101 | """Remove all dropout (and label smoothing) from params""" 102 | for k, v in params.values().items(): 103 | if 'dropout' in k: 104 | setattr(params, k, 0.0) 105 | # consider closing label smoothing 106 | if 'label_smoothing' in k: 107 | setattr(params, k, 0.0) 108 | return params 109 | 110 | 111 | def dict_update(d, u): 112 | """Recursively update a dictionary""" 113 | for k, v in u.items(): 114 | if isinstance(v, collections.Mapping): 115 | d[k] = dict_update(d.get(k, {}), v) 116 | else: 117 | d[k] = v 118 | return d 119 | 120 | 121 | def embedding_to_padding(emb): 122 | """Calculates the padding mask based on which embeddings are all zero. 123 | We have hacked symbol_modality to return all-zero embeddings for padding. 124 | Args: 125 | emb: a Tensor with shape [..., depth]. 126 | Returns: 127 | a float Tensor with shape [...]. Each element is 1 if its corresponding 128 | embedding vector is all zero, and is 0 otherwise. 129 | """ 130 | emb_sum = tf.reduce_sum(tf.abs(emb), axis=-1) 131 | return tf.to_float(tf.equal(emb_sum, 0.0)) 132 | 133 | 134 | def shape_list(x): 135 | # Copied from Tensor2Tensor 136 | """Return list of dims, statically where possible.""" 137 | x = tf.convert_to_tensor(x) 138 | 139 | # If unknown rank, return dynamic shape 140 | if x.get_shape().dims is None: 141 | return tf.shape(x) 142 | 143 | static = x.get_shape().as_list() 144 | shape = tf.shape(x) 145 | 146 | ret = [] 147 | for i in range(len(static)): 148 | dim = static[i] 149 | if dim is None: 150 | dim = shape[i] 151 | ret.append(dim) 152 | return ret 153 | 154 | 155 | def get_shape_invariants(tensor): 156 | # Copied from Tensor2Tensor 157 | """Returns the shape of the tensor but sets middle dims to None.""" 158 | shape = tensor.shape.as_list() 159 | for i in range(1, len(shape) - 1): 160 | shape[i] = None 161 | 162 | return tf.TensorShape(shape) 163 | 164 | 165 | def merge_neighbor_dims(x, axis=0): 166 | """Merge neighboring dimensions of x, starting at axis""" 167 | if len(x.get_shape().as_list()) < axis + 2: 168 | return x 169 | 170 | shape = shape_list(x) 171 | shape[axis] *= shape[axis + 1] 172 | shape.pop(axis + 1) 173 | return tf.reshape(x, shape) 174 | 175 | 176 | def unmerge_neighbor_dims(x, depth, axis=0): 177 | """Inverse of merge_neighbor_dims, splitting axis by depth""" 178 | if len(x.get_shape().as_list()) < axis + 1: 179 | return x 180 | 181 | shape = shape_list(x) 182 | width = shape[axis] // depth 183 | new_shape = shape[:axis] + [depth, width] + shape[axis + 1:] 184 | return tf.reshape(x, new_shape) 185 | 186 | 187 | def expand_tile_dims(x, depth, axis=1): 188 | """Expand and tile x on axis by depth""" 189 | x = tf.expand_dims(x, axis=axis) 190 | tile_dims = [1] * x.shape.ndims 191 | tile_dims[axis] = depth 192 | 193 | return tf.tile(x, tile_dims) 194 | 195 | 196 | def gumbel_noise(shape, eps=None): 197 | """Generate gumbel noise shaped by shape""" 198 | if eps is None: 199 | eps = dtype.epsilon() 200 | 201 | u = tf.random_uniform(shape, minval=0, maxval=1) 202 | return -tf.log(-tf.log(u + eps) + eps) 203 | 204 | 205 | def log_prob_from_logits(logits): 206 | """Log-probability from un-normalized logits""" 207 | return logits -
tf.reduce_logsumexp(logits, axis=-1, keepdims=True) 208 | 209 | 210 | def batch_coordinates(batch_size, beam_size): 211 | """Batch coordinate indices under beam_size""" 212 | batch_pos = tf.range(batch_size * beam_size) // beam_size 213 | batch_pos = tf.reshape(batch_pos, [batch_size, beam_size]) 214 | 215 | return batch_pos 216 | 217 | 218 | def variable_printer(): 219 | """Print parameters""" 220 | all_weights = {v.name: v for v in tf.trainable_variables()} 221 | total_size = 0 222 | 223 | for v_name in sorted(list(all_weights)): 224 | v = all_weights[v_name] 225 | tf.logging.info("%s\tshape %s", v.name[:-2].ljust(80), 226 | str(v.shape).ljust(20)) 227 | v_size = np.prod(np.array(v.shape.as_list())).tolist() 228 | total_size += v_size 229 | tf.logging.info("Total trainable variables size: %d", total_size) 230 | 231 | 232 | def uniform_splits(total_size, num_shards): 233 | """Split the total_size into uniform num_shards lists""" 234 | size_per_shards = total_size // num_shards 235 | splits = [size_per_shards] * (num_shards - 1) + \ 236 | [total_size - (num_shards - 1) * size_per_shards] 237 | 238 | return splits 239 | 240 | 241 | def fetch_valid_ref_files(path): 242 | """Extracting valid reference files according to MT convention""" 243 | path = os.path.abspath(path) 244 | if tf.gfile.Exists(path): 245 | return [path] 246 | 247 | if not tf.gfile.Exists(path + ".ref0"): 248 | tf.logging.warn("Invalid Reference Format {}".format(path)) 249 | return None 250 | 251 | num = 0 252 | files = [] 253 | while True: 254 | file_path = path + ".ref%s" % num 255 | if tf.gfile.Exists(file_path): 256 | files.append(file_path) 257 | else: 258 | break 259 | num += 1 260 | return files 261 | 262 | 263 | def get_session(gpus): 264 | """Config session with GPUS""" 265 | 266 | sess_config = tf.ConfigProto(allow_soft_placement=True) 267 | sess_config.gpu_options.allow_growth = True 268 | if len(gpus) > 0: 269 | device_str = ",".join([str(i) for i in gpus]) 270 | sess_config.gpu_options.visible_device_list = device_str 271 | sess = tf.Session(config=sess_config) 272 | 273 | return sess 274 | 275 | 276 | def flatten_list(values): 277 | """Flatten a list""" 278 | return [v for value in values for v in value] 279 | 280 | 281 | def remove_invalid_seq(sequence, mask): 282 | """Pick valid sequence elements wrt mask""" 283 | # sequence: [batch, sequence] 284 | # mask: [batch, sequence] 285 | boolean_mask = tf.reduce_sum(mask, axis=0) 286 | 287 | # make sure that there are at least one element in the mask 288 | first_one = tf.one_hot(0, tf.shape(boolean_mask)[0], 289 | dtype=tf.as_dtype(dtype.floatx())) 290 | boolean_mask = tf.cast(boolean_mask + first_one, tf.bool) 291 | 292 | filtered_seq = tf.boolean_mask(sequence, boolean_mask, axis=1) 293 | filtered_mask = tf.boolean_mask(mask, boolean_mask, axis=1) 294 | return filtered_seq, filtered_mask 295 | 296 | 297 | def time_str(t=None): 298 | """String format of the time long data""" 299 | if t is None: 300 | t = time.time() 301 | ts = time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime(t)) 302 | return ts 303 | 304 | 305 | def dynamic_load_module(module, prefix=None): 306 | """Load submodules inside a module, mainly used for model loading, not robust!!!""" 307 | # loading all models under directory `models` dynamically 308 | if not isinstance(module, str): 309 | module = module.__path__ 310 | for importer, modname, ispkg in pkgutil.iter_modules(module): 311 | if prefix is None: 312 | __import__(modname) 313 | else: 314 | __import__("{}.{}".format(prefix, modname)) 315 | 
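To make the batching and queueing utilities above concrete, here is a small usage sketch. It is illustrative only: the toy lengths, chunks, and the doubling preprocessor are made up rather than taken from the repository's configuration, and it assumes the repository root is on `PYTHONPATH` (importing `utils.util` also pulls in the pinned TensorFlow/NumPy).

```python
# Illustrative sketch only; toy values, not taken from the repository's configs.
from utils.util import batch_indexer, token_indexer
from utils.queuer import EnQueuer

# batch_indexer: a fixed number of examples per batch
print(batch_indexer(datasize=10, batch_size=4))
# -> [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

# token_indexer: a batch grows until (#examples) * (max length) reaches token_size
lengths = [(5,), (7,), (3,), (20,), (6,)]
print(token_indexer(lengths, token_size=16))
# -> [[0, 1], [2], [3], [4]]

# EnQueuer: wrap a chunk reader and a preprocessor into an iterator
def toy_reader():
    for chunk in ([1, 2], [3, 4], [5, 6]):
        yield chunk

enqueuer = EnQueuer(reader=toy_reader(),
                    preprocessor=lambda chunk: [x * 2 for x in chunk],
                    worker_processes_num=0)  # 0 = run in the calling process
for processed in enqueuer:
    print(processed)  # [2, 4], then [6, 8], then [10, 12]
```

`worker_processes_num=0` keeps everything in the calling process, which is the easiest setting to experiment with; values above zero hand the reading and preprocessing off to daemon worker processes as implemented in `_create_multi_process_gen`.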
-------------------------------------------------------------------------------- /vocab.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import argparse 8 | 9 | 10 | class Vocab(object): 11 | def __init__(self, vocab_file=None): 12 | self.word2id = {} 13 | self.id2word = {} 14 | self.word2count = {} 15 | 16 | self.pad_sym = "<pad>" 17 | self.eos_sym = "</s>" 18 | self.unk_sym = "<unk>" 19 | 20 | self.insert(self.pad_sym) 21 | self.insert(self.unk_sym) 22 | self.insert(self.eos_sym) 23 | 24 | if vocab_file is not None: 25 | self.load_vocab(vocab_file) 26 | 27 | def insert(self, token): 28 | if token not in self.word2id: 29 | index = len(self.word2id) 30 | self.word2id[token] = index 31 | self.id2word[index] = token 32 | 33 | self.word2count[token] = 0 34 | self.word2count[token] += 1 35 | 36 | def size(self): 37 | return len(self.word2id) 38 | 39 | def load_vocab(self, vocab_file): 40 | with open(vocab_file, 'r', encoding='utf-8') as reader: 41 | for token in reader: 42 | self.insert(token.strip()) 43 | 44 | def get_token(self, id): 45 | if id in self.id2word: 46 | return self.id2word[id] 47 | return self.unk_sym 48 | 49 | def get_id(self, token): 50 | if token in self.word2id: 51 | return self.word2id[token] 52 | return self.word2id[self.unk_sym] 53 | 54 | def sort_vocab(self): 55 | sorted_word2count = sorted( 56 | self.word2count.items(), key=lambda x: - x[1]) 57 | self.word2id, self.id2word = {}, {} 58 | self.insert(self.pad_sym) 59 | self.insert(self.unk_sym) 60 | self.insert(self.eos_sym) 61 | for word, _ in sorted_word2count: 62 | self.insert(word) 63 | 64 | def save_vocab(self, vocab_file, size=1e6): 65 | with open(vocab_file, 'w') as writer: 66 | for id in range(min(self.size(), int(size))): 67 | writer.write(self.id2word[id] + "\n") 68 | 69 | def to_id(self, tokens, append_eos=True): 70 | if not append_eos: 71 | return [self.get_id(token) for token in tokens] 72 | else: 73 | return [self.get_id(token) for token in tokens + [self.eos_sym]] 74 | 75 | def to_tokens(self, ids): 76 | return [self.get_token(id) for id in ids] 77 | 78 | def eos(self): 79 | return self.get_id(self.eos_sym) 80 | 81 | def pad(self): 82 | return self.get_id(self.pad_sym) 83 | 84 | 85 | if __name__ == "__main__": 86 | parser = argparse.ArgumentParser('Vocabulary Preparation') 87 | parser.add_argument('--size', type=int, default=1e6, help='maximum vocabulary size') 88 | parser.add_argument('input', type=str, help='the input file path') 89 | parser.add_argument('output', type=str, help='the output file name') 90 | 91 | args = parser.parse_args() 92 | 93 | vocab = Vocab() 94 | with open(args.input, 'r', encoding='utf-8') as reader: 95 | for line in reader: 96 | for token in line.strip().split(): 97 | vocab.insert(token) 98 | 99 | vocab.sort_vocab() 100 | vocab.save_vocab(args.output, args.size) 101 | 102 | print("Loaded {} tokens from {}".format(vocab.size(), args.input)) 103 | --------------------------------------------------------------------------------
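The vocabulary side can be sketched in the same spirit. The snippet below is a hypothetical walk-through with placeholder file names and an arbitrary size of 32000: it mirrors what the command-line entry point above does (roughly `python vocab.py --size 32000 train.tok.txt vocab.txt`) and then reloads the vocabulary to map between tokens and ids.

```python
# Sketch only: file names and the vocabulary size are placeholders.
from vocab import Vocab

# build a vocabulary from a tokenized corpus
vocab = Vocab()
with open("train.tok.txt", encoding="utf-8") as reader:
    for line in reader:
        for token in line.strip().split():
            vocab.insert(token)
vocab.sort_vocab()                         # special symbols first, then tokens by frequency
vocab.save_vocab("vocab.txt", size=32000)

# reload it and convert between tokens and ids
vocab = Vocab("vocab.txt")
ids = vocab.to_id("ein kleines beispiel".split())   # appends the end-of-sentence symbol by default
print(ids)
print(vocab.to_tokens(ids))
print(vocab.pad(), vocab.eos())            # ids of the special symbols
```

Because `to_id` appends the end-of-sentence symbol by default, downstream code can rely on every encoded sequence being terminated with `vocab.eos()`.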