├── model ├── __init__.py ├── embedding.py ├── utils.py ├── ffn.py ├── attention.py ├── transformer.py └── lib │ └── beam_search.py ├── utils ├── __init__.py ├── ranker.py └── beam_search.py ├── requires.txt ├── image ├── chat.PNG └── tensorboard.PNG ├── .gitignore ├── download_data.sh ├── config ├── seq_len_32.yml └── test_config.yml ├── README.md ├── train.py └── chat.py /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requires.txt: -------------------------------------------------------------------------------- 1 | PyYAML==4.2b1 2 | tensorflow-gpu==1.13.1 3 | numpy==1.16.1 4 | -------------------------------------------------------------------------------- /image/chat.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/st9007a/ChineseQABot/HEAD/image/chat.PNG -------------------------------------------------------------------------------- /image/tensorboard.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/st9007a/ChineseQABot/HEAD/image/tensorboard.PNG -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | __pycache__/ 3 | build/ 4 | serve/ 5 | saved_models/ 6 | tensorboard/ 7 | .ipynb_checkpoints/ 8 | -------------------------------------------------------------------------------- /download_data.sh: -------------------------------------------------------------------------------- 1 | mkdir -p data 2 | wget https://raw.githubusercontent.com/zake7749/Gossiping-Chinese-Corpus/master/data/Gossiping-QA-Dataset.txt -P data/ 3 | -------------------------------------------------------------------------------- /config/seq_len_32.yml: -------------------------------------------------------------------------------- 1 | arch: 2 | num_hidden_layers: 6 3 | hidden_size: 512 4 | filter_size: 1024 5 | num_heads: 8 6 | 7 | initializer_gain: 1.0 8 | 9 | relu_dropout: 0.1 10 | attention_dropout: 0.1 11 | layer_postprocess_dropout: 0.1 12 | 13 | allow_ffn_pad: True 14 | -------------------------------------------------------------------------------- /config/test_config.yml: -------------------------------------------------------------------------------- 1 | arch: 2 | num_hidden_layers: 6 3 | hidden_size: 256 4 | filter_size: 512 5 | num_heads: 8 6 | 7 | initializer_gain: 1.0 8 | 9 | relu_dropout: 0.1 10 | attention_dropout: 0.1 11 | layer_postprocess_dropout: 0.1 12 | 13 | allow_ffn_pad: True 14 | -------------------------------------------------------------------------------- /model/embedding.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import tensorflow as tf 3 | 4 | class ShareWeightsEmbedding(): 5 | 6 | def __init__(self, vocab_size, hidden_size): 7 | 8 | self.vocab_size = vocab_size 9 | self.hidden_size = hidden_size 10 | 11 | def __call__(self, x): 12 | 13 | with tf.variable_scope('embedding_and_softmax', reuse=tf.AUTO_REUSE): 14 | self.share_weights = tf.get_variable( 15 | 'weights', 16 | [self.vocab_size, self.hidden_size], 17 | 
initializer=tf.random_normal_initializer(0, self.hidden_size ** -0.5), 18 | ) 19 | 20 | with tf.name_scope('embedding'): 21 | mask = tf.to_float(tf.not_equal(x, 0)) 22 | 23 | embedding = tf.gather(self.share_weights, x) 24 | embedding *= tf.expand_dims(mask, -1) 25 | embedding *= self.hidden_size ** 0.5 26 | 27 | return embedding 28 | 29 | def linear(self, x): 30 | 31 | with tf.name_scope('presoftmax_linear'): 32 | shape = tf.shape(x) 33 | batch_size = shape[0] 34 | length = shape[1] 35 | 36 | x = tf.reshape(x, [-1, self.hidden_size]) 37 | logits = tf.matmul(x, self.share_weights, transpose_b=True) 38 | logits = tf.reshape(logits, [batch_size, length, self.vocab_size]) 39 | 40 | return logits 41 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import math 3 | 4 | import tensorflow as tf 5 | 6 | _NEG_INF = -1e9 7 | 8 | def get_position_encoding(length, hidden_size, min_timescale=0.1, max_timescale=1.0e4): 9 | 10 | position = tf.to_float(tf.range(length)) 11 | num_timescales = hidden_size // 2 12 | 13 | log_timescale_increment = ( 14 | math.log(float(max_timescale) / float(min_timescale)) / 15 | (tf.to_float(num_timescales) - 1)) 16 | 17 | inv_timescales = min_timescale * tf.exp( 18 | tf.to_float(tf.range(num_timescales)) * -log_timescale_increment) 19 | 20 | scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales, 0) 21 | signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1) 22 | 23 | return signal 24 | 25 | def get_decoder_self_attention_bias(length): 26 | 27 | with tf.name_scope("decoder_self_attention_bias"): 28 | valid_locs = tf.matrix_band_part(tf.ones([length, length]), -1, 0) 29 | valid_locs = tf.reshape(valid_locs, [1, 1, length, length]) 30 | decoder_bias = _NEG_INF * (1.0 - valid_locs) 31 | 32 | return decoder_bias 33 | 34 | def get_padding(x, padding_value=0): 35 | 36 | with tf.name_scope("padding"): 37 | return tf.to_float(tf.equal(x, padding_value)) 38 | 39 | def get_padding_bias(x): 40 | with tf.name_scope("attention_bias"): 41 | padding = get_padding(x) 42 | attention_bias = padding * _NEG_INF 43 | attention_bias = tf.expand_dims( 44 | tf.expand_dims(attention_bias, axis=1), axis=1) 45 | return attention_bias 46 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Chinese QA Bot 2 | 3 | Using [transformer](https://arxiv.org/abs/1706.03762) to train a Chinese chatbot. 4 | 5 | ## Dataset 6 | 7 | - Reference: [PTT 中文語料](https://github.com/zake7749/Gossiping-Chinese-Corpus) 8 | - Download dataset: `sh download_data.sh` 9 | 10 | ## Code Reference 11 | 12 | - Google transformer architecture: [tensorflow/models/official/transformer](https://github.com/tensorflow/models/tree/master/official/transformer) 13 | 14 | ## Installation 15 | 16 | `pip3 install -r requires.txt` 17 | 18 | ## Training pipeline 19 | 20 | 1. Build the data file with `.tfrecord` format: 21 | ``` 22 | python3 build_data.py 23 | ``` 24 | 25 | 2. Train your model: 26 | ``` 27 | python3 train.py config/test_config.yml 28 | ``` 29 | 30 | You can customize your model architecture by writing a new `.yml` file. 31 | For more detail, see `config/test_config.yml` 32 | 33 | If you want to change the learning rate, total training steps or other training strategies, please modify the code in `train.py`. 
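For example, the learning rate is hard-coded where `model_fn` builds its optimizer. A minimal sketch of swapping the fixed rate for a decay schedule (the `LazyAdamOptimizer` call already exists in `train.py`; the `tf.train.exponential_decay` schedule and its numbers are illustrative assumptions, not part of this repo):
```
# In model_fn() of train.py, replace the fixed 1e-3 learning rate:
global_step = tf.train.get_global_step()
learning_rate = tf.train.exponential_decay(
    1e-3, global_step, decay_steps=10000, decay_rate=0.96, staircase=True)  # example numbers
optimizer = tf.contrib.opt.LazyAdamOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(loss, global_step=global_step)
```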
34 | 35 | ## Tensorboard 36 | 37 | Run the following command, then open http://localhost:8080 in your browser: 38 | 39 | ``` 40 | tensorboard --logdir tensorboard --port 8080 41 | ``` 42 | 43 | ![](https://github.com/st9007a/ChineseQABot/blob/master/image/tensorboard.PNG) 44 | 45 | ## Run a simple chatbot 46 | 47 | `train.py` will export a [Tensorflow SavedModel](https://www.tensorflow.org/guide/saved_model#build_and_load_a_savedmodel) every 50000 training steps. 48 | Those models will be placed under the `saved_models/serve3` folder. 49 | 50 | To run a simple demo, make sure a SavedModel exists, then type the following command: 51 | ``` 52 | python3 chat.py saved_models/serve3/[YOUR MODEL FOLDER] 53 | ``` 54 | 55 | ![](https://github.com/st9007a/ChineseQABot/blob/master/image/chat.PNG) 56 | -------------------------------------------------------------------------------- /utils/ranker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from math import pow 3 | 4 | stopwords = ['嗎', '啊', '啦', '喇', '不', '是', '好', '吧', '你', '就', '的'] 5 | 6 | class Ranker(): 7 | 8 | def __init__(self, repeat_penalty=1.5): 9 | self.repeat_penalty = repeat_penalty 10 | 11 | def fit_transform(self, sequences): 12 | # tf: per-candidate term frequencies, df: document frequency across candidates. 13 | tf = [] 14 | df = {} 15 | token_set = set() 16 | avg_len = 0. 17 | scores = [] 18 | 19 | for seq in sequences: 20 | ttf = {} 21 | exist = set() 22 | for token in seq: 23 | if token in stopwords: 24 | continue 25 | 26 | ttf[token] = ttf.get(token, 0) + 1 27 | token_set.add(token) 28 | 29 | if token not in exist: 30 | df[token] = df.get(token, 0) + 1 31 | exist.add(token) 32 | 33 | tf.append(ttf) 34 | avg_len += len(seq) 35 | 36 | avg_len /= len(sequences) 37 | 38 | for i, seq in enumerate(sequences): 39 | score = 0 40 | tokens = set() 41 | for token in seq: 42 | if token in stopwords: 43 | continue 44 | 45 | # Tokens shared across candidates score higher; repeated tokens are penalized. 46 | s = df[token] * 1.0 / pow(tf[i][token], self.repeat_penalty) 47 | tokens.add(token) 48 | 49 | score += s 50 | 51 | # Reward lexical diversity within the candidate. 52 | score += len(tokens)/len(token_set) 53 | scores.append(score) 54 | 55 | return scores 56 | 57 | if __name__ == '__main__': 58 | corpus = [ 59 | '麥克風測試', 60 | '測試', 61 | ] 62 | 63 | ranker = Ranker() 64 | s = ranker.fit_transform(corpus) 65 | print(s) 66 | -------------------------------------------------------------------------------- /model/ffn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import tensorflow as tf 3 | 4 | class FeedForwardNetwork(): 5 | 6 | def __init__(self, hidden_size, filter_size, relu_dropout, train, allow_pad): 7 | 8 | self.hidden_size = hidden_size 9 | self.filter_size = filter_size 10 | self.relu_dropout = relu_dropout 11 | self.train = train 12 | self.allow_pad = allow_pad 13 | 14 | self.filter_dense_layer = tf.layers.Dense( 15 | filter_size, 16 | use_bias=True, 17 | activation=tf.nn.relu, 18 | name='filter_layer', 19 | ) 20 | 21 | self.output_dense_layer = tf.layers.Dense( 22 | hidden_size, 23 | use_bias=True, 24 | # No activation: the second FFN projection is linear in the Transformer. 25 | name='output_layer', 26 | ) 27 | 28 | def __call__(self, x, padding=None): 29 | 30 | padding = None if not self.allow_pad else padding 31 | 32 | shape = tf.shape(x) 33 | batch_size = shape[0] 34 | length = shape[1] 35 | 36 | if padding is not None: 37 | with tf.name_scope('remove_padding'): # process only non-pad positions 38 | 39 | pad_mask = tf.reshape(padding, [-1]) 40 | nonpad_ids = tf.to_int32(tf.where(pad_mask < 1e-9)) 41 | 42 | x = tf.reshape(x, [-1, self.hidden_size]) 43 | x = tf.gather_nd(x, indices=nonpad_ids) 44 | 45 | 
x.set_shape([None, self.hidden_size]) 46 | x = tf.expand_dims(x, axis=0) 47 | 48 | output = self.filter_dense_layer(x) 49 | 50 | if self.train: 51 | output = tf.nn.dropout(output, 1. - self.relu_dropout) 52 | 53 | output = self.output_dense_layer(output) 54 | 55 | if padding is not None: 56 | with tf.name_scope("re_add_padding"): # scatter outputs back to padded positions 57 | output = tf.squeeze(output, axis=0) 58 | output = tf.scatter_nd( 59 | indices=nonpad_ids, 60 | updates=output, 61 | shape=[batch_size * length, self.hidden_size] 62 | ) 63 | output = tf.reshape(output, [batch_size, length, self.hidden_size]) 64 | 65 | return output 66 | -------------------------------------------------------------------------------- /model/attention.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import tensorflow as tf 3 | 4 | class Attention(): 5 | 6 | def __init__(self, hidden_size, num_heads, attention_dropout, train): 7 | 8 | if hidden_size % num_heads != 0: 9 | raise ValueError('Hidden size must be evenly divisible by the number of heads.') 10 | 11 | self.hidden_size = hidden_size 12 | self.num_heads = num_heads 13 | self.attention_dropout = attention_dropout 14 | self.train = train 15 | 16 | self.q_dense_layer = tf.layers.Dense(hidden_size, use_bias=False, name="q") 17 | self.k_dense_layer = tf.layers.Dense(hidden_size, use_bias=False, name="k") 18 | self.v_dense_layer = tf.layers.Dense(hidden_size, use_bias=False, name="v") 19 | 20 | self.output_dense_layer = tf.layers.Dense(hidden_size, use_bias=False, name="output_transform") 21 | 22 | def split_head(self, x): 23 | 24 | with tf.name_scope('split_head'): 25 | shape = tf.shape(x) 26 | batch_size = shape[0] 27 | length = shape[1] 28 | 29 | depth = (self.hidden_size // self.num_heads) 30 | 31 | x = tf.reshape(x, [batch_size, length, self.num_heads, depth]) 32 | 33 | return tf.transpose(x, [0, 2, 1, 3]) 34 | 35 | def combine_head(self, x): 36 | 37 | with tf.name_scope('combine_head'): 38 | shape = tf.shape(x) 39 | batch_size = shape[0] 40 | length = shape[2] 41 | 42 | x = tf.transpose(x, [0, 2, 1, 3]) 43 | 44 | return tf.reshape(x, [batch_size, length, self.hidden_size]) 45 | 46 | def __call__(self, x, y, bias): 47 | 48 | q = self.q_dense_layer(x) # queries, keys, and values each use their own projection 49 | k = self.k_dense_layer(y) 50 | v = self.v_dense_layer(y) 51 | 52 | q = self.split_head(q) 53 | k = self.split_head(k) 54 | v = self.split_head(v) 55 | 56 | depth = (self.hidden_size // self.num_heads) 57 | q *= depth ** -0.5 58 | 59 | logits = tf.matmul(q, k, transpose_b=True) + bias # query-key dot products give the attention logits 60 | weights = tf.nn.softmax(logits, name='attention_weights') 61 | 62 | if self.train: 63 | weights = tf.nn.dropout(weights, 1 - self.attention_dropout) 64 | 65 | attention_output = tf.matmul(weights, v) 66 | attention_output = self.combine_head(attention_output) 67 | attention_output = self.output_dense_layer(attention_output) 68 | 69 | return attention_output 70 | 71 | class SelfAttention(Attention): 72 | 73 | def __call__(self, x, bias): 74 | return super(SelfAttention, self).__call__(x, x, bias) 75 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import sys 3 | import yaml 4 | from pprint import pprint 5 | 6 | import tensorflow as tf 7 | 8 | from model.transformer import Transformer 9 | 10 | CONFIG = sys.argv[1] 11 | 12 | def load_config(): 13 | 14 | vocab = [] 15 | 16 | with open(CONFIG, 'r') as f: 17 | params = yaml.load(f) 18 | 19 
| with open('data/vocab.txt', 'r') as f: 20 | for line in f: 21 | vocab.append(line[:-1]) 22 | 23 | params['arch']['vocab_size'] = len(vocab) 24 | 25 | return params, vocab 26 | 27 | def model_fn(features, labels, mode, params): 28 | 29 | with tf.variable_scope('model'): 30 | model = Transformer(params, mode == tf.estimator.ModeKeys.TRAIN) 31 | 32 | logits = model(features['q'], features['a']) 33 | 34 | if mode == tf.estimator.ModeKeys.PREDICT: 35 | return tf.estimator.EstimatorSpec( 36 | tf.estimator.ModeKeys.PREDICT, 37 | predictions=logits, 38 | export_outputs={ 39 | 'response': tf.estimator.export.PredictOutput(logits) 40 | }) 41 | 42 | xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels) 43 | loss = tf.reduce_sum(xentropy) # note: the sum runs over every position, padding included 44 | 45 | optimizer = tf.contrib.opt.LazyAdamOptimizer(learning_rate=1e-3) 46 | train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step()) 47 | 48 | return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op) 49 | 50 | def train_input_fn(): 51 | 52 | def _parse_example(serialized_example): 53 | data_fields = {'q': tf.FixedLenFeature((32,), tf.int64), 'a': tf.FixedLenFeature((32,), tf.int64)} 54 | parsed = tf.parse_single_example(serialized_example, data_fields) 55 | 56 | return {'q': parsed['q'], 'a': parsed['a']}, parsed['a'] 57 | 58 | dataset = tf.data.TFRecordDataset('data/qa.tfrecords') 59 | dataset = dataset.map(_parse_example, num_parallel_calls=4) 60 | dataset = dataset.shuffle(50000) 61 | dataset = dataset.repeat() 62 | dataset = dataset.batch(256) 63 | 64 | return dataset 65 | 66 | def serving_input_fn(): 67 | inputs = {'q': tf.placeholder(tf.int64, [None, 32]), 'a': tf.placeholder(tf.int64, [None, 32])} 68 | return tf.estimator.export.ServingInputReceiver(inputs, inputs) 69 | 70 | if __name__ == '__main__': 71 | 72 | params, vocab = load_config() 73 | 74 | config = tf.estimator.RunConfig(save_checkpoints_steps=5000, model_dir='tensorboard/build3/') 75 | estimator = tf.estimator.Estimator(model_fn=model_fn, params=params['arch'], config=config) 76 | 77 | for _ in range(10): # 10 rounds of 50000 steps; a SavedModel is exported after each round 78 | estimator.train(train_input_fn, steps=50000) 79 | estimator.export_savedmodel(export_dir_base='saved_models/serve3/', serving_input_receiver_fn=serving_input_fn) 80 | -------------------------------------------------------------------------------- /chat.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from math import log, pow 3 | import sys 4 | import time 5 | 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | from utils.beam_search import BeamSearch 10 | from utils.ranker import Ranker 11 | 12 | stopwords = set(['嗎', '啊', '啦', '喇', '不', '是', '好', '吧', '你', '就']) 13 | 14 | def load_model(model_dir): 15 | sess = tf.Session() 16 | meta_graph_def = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], model_dir) 17 | signature = meta_graph_def.signature_def 18 | 19 | q_name = signature['serving_default'].inputs['q'].name 20 | a_name = signature['serving_default'].inputs['a'].name 21 | 22 | output_name = signature['serving_default'].outputs['output'].name 23 | 24 | q = sess.graph.get_tensor_by_name(q_name) 25 | a = sess.graph.get_tensor_by_name(a_name) 26 | output = sess.graph.get_tensor_by_name(output_name) 27 | 28 | return sess, {'q': q, 'a': a, 'output': output} 29 | 30 | def load_vocab(): 31 | word_list = [] 32 | word_idx_map = {} 33 | 34 | with open('data/vocab.txt', 'r') as f: 35 | for line in f: 36 | line = 
line.rstrip('\n') 37 | word_idx_map[line] = len(word_list) 38 | word_list.append(line) 39 | 40 | return word_list, word_idx_map 41 | 42 | if __name__ == '__main__': 43 | 44 | idx2word, word2idx = load_vocab() 45 | sess, tensors = load_model(sys.argv[1]) 46 | beam_search = BeamSearch(session=sess, 47 | eval_tensors=tensors['output'], 48 | feed_tensors=[tensors['q'], tensors['a']], 49 | alpha=0.6, 50 | beam_width=15, 51 | max_length=32, 52 | eos_id=1) 53 | ranker = Ranker(repeat_penalty=1.5) 54 | 55 | while True: 56 | test_input = input('請輸入中文句子: ') 57 | 58 | test_input = [word2idx[el] for el in test_input if el in word2idx] 59 | while len(test_input) < 32: 60 | test_input.append(0) # pad the question with id 0 up to the fixed length 61 | test_input = test_input[:32] 62 | 63 | start = time.time() 64 | finished_seq = beam_search.search(test_input) 65 | end = time.time() 66 | 67 | print('==============') 68 | print('Found {:d} candidates'.format(len(finished_seq))) 69 | print('Search time: %.6f sec' % (end - start)) 70 | 71 | finished_seq.sort(key=lambda x: x['score'], reverse=True) 72 | responses = [] 73 | 74 | for seq in finished_seq: 75 | response = [idx2word[int(el)] for el in seq['ids']] 76 | while len(response) > 0 and response[-1] == '': # drop trailing empty-string tokens (e.g. pad/eos placeholders) 77 | response.pop() 78 | 79 | response = ''.join(response) 80 | responses.append(response) 81 | 82 | rerank_scores = ranker.fit_transform(responses) 83 | 84 | for rescore, seq, response in zip(rerank_scores, finished_seq, responses): 85 | print('Score: {:.6f}, Re-rank score: {:>9.6f}, Response: {:s}'.format(seq['score'], rescore, response)) 86 | 87 | final_id = rerank_scores.index(max(rerank_scores)) 88 | 89 | print('==============') 90 | print('Final response: {:s}'.format(responses[final_id])) 91 | print('==============') 92 | -------------------------------------------------------------------------------- /utils/beam_search.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from math import log, pow 3 | import bisect 4 | 5 | import numpy as np 6 | 7 | def softmax(x): 8 | return np.exp(x - np.max(x)) / np.sum(np.exp(x - np.max(x)), axis=0) # shift by max for numerical stability 9 | 10 | def insort(array, item, key=None): 11 | if len(array) == 0: 12 | array.append(item) 13 | return 14 | 15 | index_array = [key(el) for el in array] if key else array 16 | index_item = key(item) if key else item 17 | 18 | index = bisect.bisect_right(index_array, index_item) 19 | array.insert(index, item) 20 | 21 | class BeamSearch(): 22 | 23 | def __init__(self, session, eval_tensors, feed_tensors, alpha=0.6, beam_width=10, max_length=100, eos_id=1): 24 | self.session = session 25 | self.eval_tensors = eval_tensors 26 | self.feed_tensors = feed_tensors 27 | 28 | self.beam_width = beam_width 29 | self.max_length = max_length 30 | self.alpha = alpha 31 | self.eos_id = eos_id 32 | 33 | def _run(self, vals): 34 | feed_dict = {tensor: value for tensor, value in zip(self.feed_tensors, vals)} 35 | return self.session.run(self.eval_tensors, feed_dict=feed_dict) 36 | 37 | def search(self, input_arr): 38 | res = np.zeros((self.max_length,)) 39 | alive_seq = [] 40 | finished_seq = [] 41 | 42 | alive_seq.append({'score': 0, 'probs': [], 'ids': res}) 43 | 44 | for i in range(self.max_length): 45 | answer = np.array([el['ids'] for el in alive_seq]) 46 | question = np.tile(input_arr, (answer.shape[0], 1)) 47 | 48 | # `decode_ids` holds the decoder logits, shape (batch_size, max_length, vocab_size) 49 | decode_ids = self._run([question, answer]) 50 | 51 | new_alive_seq = [] 52 | 53 | for state, decode_proba in zip(alive_seq, 
decode_ids[:, i, :]): 54 | 55 | # shape of `decode_proba` = (vocab_size,) 56 | candidate = np.argsort(decode_proba)[-self.beam_width:] 57 | decode_proba = softmax(decode_proba) 58 | 59 | for idx in candidate: 60 | proba = log(decode_proba[idx]) 61 | seq = state['probs'] + [proba] 62 | 63 | # See http://opennmt.net/OpenNMT/translation/beam_search/#length-normalization 64 | len_norm = pow((5.+i+1.) / 6, self.alpha) 65 | score = sum(seq) / len_norm 66 | 67 | new_res = np.array(state['ids']) 68 | new_res[i] = idx 69 | 70 | if idx == self.eos_id: 71 | finished_seq.append({'score': score, 'ids': new_res}) 72 | else: 73 | new_alive_seq.append({'score': score, 'probs': seq, 'ids': new_res}) 74 | 75 | alive_seq = new_alive_seq 76 | alive_seq.sort(key=lambda el: el['score']) 77 | alive_seq = alive_seq[-self.beam_width:] 78 | 79 | if len(finished_seq) > 0: 80 | finished_seq.sort(key=lambda el: el['score']) 81 | finished_seq = finished_seq[-self.beam_width:] 82 | 83 | if finished_seq[0]['score'] > alive_seq[-1]['score']: 84 | break 85 | 86 | i = 0 87 | while len(finished_seq) < self.beam_width and i < len(alive_seq): 88 | finished_seq.append(alive_seq[i]) 89 | i += 1 90 | 91 | finished_seq.sort(key=lambda el: el['score']) 92 | 93 | return finished_seq 94 | 95 | if __name__ == '__main__': 96 | 97 | arr = [] 98 | 99 | insort(arr, (100, 'a'), key=lambda x: x[0]) 100 | print(arr) 101 | insort(arr, (300, 'a'), key=lambda x: x[0]) 102 | print(arr) 103 | insort(arr, (200, 'a'), key=lambda x: x[0]) 104 | print(arr) 105 | -------------------------------------------------------------------------------- /model/transformer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import tensorflow as tf 3 | 4 | from .attention import Attention, SelfAttention 5 | from .embedding import ShareWeightsEmbedding 6 | from .ffn import FeedForwardNetwork 7 | from .utils import get_padding, get_padding_bias, get_decoder_self_attention_bias, get_position_encoding 8 | 9 | class Transformer(): 10 | 11 | def __init__(self, params, train): 12 | self.params = params 13 | self.train = train 14 | 15 | self.embedding_softmax_layer = ShareWeightsEmbedding( 16 | params['vocab_size'], params['hidden_size']) 17 | 18 | self.encoder_stack = EncoderStack(params, train) 19 | self.decoder_stack = DecoderStack(params, train) 20 | 21 | def __call__(self, inputs, targets): 22 | initializer = tf.variance_scaling_initializer(self.params["initializer_gain"], 23 | mode="fan_avg", distribution="uniform") 24 | 25 | with tf.variable_scope('transformer', initializer=initializer): 26 | attention_bias = get_padding_bias(inputs) 27 | 28 | with tf.variable_scope('encode_stack'): 29 | encoder_outputs = self.encode(inputs, attention_bias) 30 | 31 | with tf.variable_scope('decode_stack'): 32 | logits = self.decode(targets, encoder_outputs, attention_bias) 33 | return logits 34 | 35 | def encode(self, inputs, attention_bias): 36 | 37 | with tf.name_scope('encode'): 38 | 39 | embedded_inputs = self.embedding_softmax_layer(inputs) 40 | inputs_padding = get_padding(inputs) 41 | 42 | with tf.name_scope('add_pos_encoding'): 43 | length = tf.shape(embedded_inputs)[1] 44 | pos_encoding = get_position_encoding(length, self.params['hidden_size']) 45 | 46 | encoder_inputs = embedded_inputs + pos_encoding 47 | 48 | if self.train: 49 | encoder_inputs = tf.nn.dropout(encoder_inputs, 1. 
- self.params['layer_postprocess_dropout']) 50 | 51 | return self.encoder_stack(encoder_inputs, attention_bias, inputs_padding) 52 | 53 | def decode(self, targets, encoder_outputs, attention_bias): 54 | 55 | with tf.name_scope('decode'): 56 | decoder_inputs = self.embedding_softmax_layer(targets) 57 | 58 | with tf.name_scope('shift_targets'): 59 | decoder_inputs = tf.pad(decoder_inputs, [[0, 0], [1, 0], [0, 0]])[:, :-1, :] 60 | 61 | with tf.name_scope('add_pos_encoding'): 62 | length = tf.shape(decoder_inputs)[1] 63 | decoder_inputs += get_position_encoding(length, self.params['hidden_size']) 64 | 65 | if self.train: 66 | decoder_inputs = tf.nn.dropout(decoder_inputs, 1. - self.params['layer_postprocess_dropout']) 67 | 68 | decoder_self_attention_bias = get_decoder_self_attention_bias(length) 69 | 70 | outputs = self.decoder_stack(decoder_inputs, encoder_outputs, 71 | decoder_self_attention_bias, attention_bias) 72 | 73 | logits = self.embedding_softmax_layer.linear(outputs) 74 | 75 | return logits 76 | 77 | class LayerNormalization(): 78 | 79 | def __init__(self, hidden_size): 80 | self.hidden_size = hidden_size 81 | 82 | def __call__(self, x, epsilon=1e-6): 83 | 84 | self.scale = tf.get_variable('layer_norm_scale', [self.hidden_size], 85 | initializer=tf.ones_initializer()) 86 | 87 | self.bias = tf.get_variable('layer_norm_bias', [self.hidden_size], 88 | initializer=tf.zeros_initializer()) 89 | 90 | mean = tf.reduce_mean(x, axis=[-1], keepdims=True) 91 | variance = tf.reduce_mean(tf.square(x - mean), axis=[-1], keepdims=True) 92 | norm_x = (x - mean) * tf.rsqrt(variance + epsilon) 93 | 94 | return norm_x * self.scale + self.bias 95 | 96 | class PrePostProcessingWrapper(): 97 | 98 | def __init__(self, layer, params, train): 99 | self.layer = layer 100 | self.postprocess_dropout = params['layer_postprocess_dropout'] 101 | self.train = train 102 | 103 | self.layer_norm = LayerNormalization(params['hidden_size']) 104 | 105 | def __call__(self, x, *args, **kwargs): 106 | 107 | y = self.layer_norm(x) 108 | y = self.layer(y, *args, **kwargs) 109 | 110 | if self.train: 111 | y = tf.nn.dropout(y, 1. 
- self.postprocess_dropout) 112 | 113 | return x + y 114 | 115 | class EncoderStack(): 116 | 117 | def __init__(self, params, train): 118 | 119 | self.layers = [] 120 | 121 | for _ in range(params['num_hidden_layers']): 122 | self_attention_layer = SelfAttention(params['hidden_size'], params['num_heads'], 123 | params['attention_dropout'], train) 124 | 125 | feed_forward_network = FeedForwardNetwork(params['hidden_size'], params['filter_size'], 126 | params['relu_dropout'], train, 127 | params['allow_ffn_pad']) 128 | 129 | self.layers.append([ 130 | PrePostProcessingWrapper(self_attention_layer, params, train), 131 | PrePostProcessingWrapper(feed_forward_network, params, train) 132 | ]) 133 | 134 | self.output_normalization = LayerNormalization(params['hidden_size']) 135 | 136 | def __call__(self, encoder_inputs, attention_bias, inputs_padding): 137 | 138 | for n, layer in enumerate(self.layers): 139 | self_attention_layer = layer[0] 140 | feed_forward_network = layer[1] 141 | 142 | with tf.variable_scope('layer_%d' % n): 143 | with tf.variable_scope('self_attention'): 144 | encoder_inputs = self_attention_layer(encoder_inputs, attention_bias) 145 | 146 | with tf.variable_scope('ffn'): 147 | encoder_inputs = feed_forward_network(encoder_inputs, inputs_padding) 148 | 149 | return self.output_normalization(encoder_inputs) 150 | 151 | class DecoderStack(): 152 | 153 | def __init__(self, params, train): 154 | 155 | self.layers = [] 156 | 157 | for _ in range(params['num_hidden_layers']): 158 | self_attention_layer = SelfAttention(params['hidden_size'], params['num_heads'], 159 | params['attention_dropout'], train) 160 | 161 | enc_dec_attention_layer = Attention(params['hidden_size'], params['num_heads'], 162 | params['attention_dropout'], train) 163 | 164 | feed_forward_network = FeedForwardNetwork(params['hidden_size'], params['filter_size'], 165 | params['relu_dropout'], train, 166 | params['allow_ffn_pad']) 167 | 168 | self.layers.append([ 169 | PrePostProcessingWrapper(self_attention_layer, params, train), 170 | PrePostProcessingWrapper(enc_dec_attention_layer, params, train), 171 | PrePostProcessingWrapper(feed_forward_network, params, train) 172 | ]) 173 | 174 | self.output_normalization = LayerNormalization(params['hidden_size']) 175 | 176 | def __call__(self, decoder_inputs, encoder_outputs, decoder_self_attention_bias, attention_bias): 177 | 178 | for n, layer in enumerate(self.layers): 179 | self_attention = layer[0] 180 | enc_dec_attention = layer[1] 181 | feed_forward_network = layer[2] 182 | 183 | layer_name = 'layer_%d' % n 184 | 185 | with tf.variable_scope(layer_name): 186 | with tf.variable_scope("self_attention"): 187 | decoder_inputs = self_attention( 188 | decoder_inputs, decoder_self_attention_bias) 189 | 190 | with tf.variable_scope("encdec_attention"): 191 | decoder_inputs = enc_dec_attention( 192 | decoder_inputs, encoder_outputs, attention_bias) 193 | 194 | with tf.variable_scope("ffn"): 195 | decoder_inputs = feed_forward_network(decoder_inputs) 196 | 197 | return self.output_normalization(decoder_inputs) 198 | -------------------------------------------------------------------------------- /model/lib/beam_search.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Beam search to find the translated sequence with the highest probability. 16 | 17 | Source implementation from Tensor2Tensor: 18 | https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/beam_search.py 19 | """ 20 | 21 | import tensorflow as tf 22 | from tensorflow.python.util import nest 23 | 24 | # Default value for INF 25 | INF = 1. * 1e7 26 | 27 | 28 | class _StateKeys(object): 29 | """Keys to dictionary storing the state of the beam search loop.""" 30 | 31 | # Variable storing the loop index. 32 | CUR_INDEX = "CUR_INDEX" 33 | 34 | # Top sequences that are alive for each batch item. Alive sequences are ones 35 | # that have not generated an EOS token. Sequences that reach EOS are marked as 36 | # finished and moved to the FINISHED_SEQ tensor. 37 | # Has shape [batch_size, beam_size, CUR_INDEX + 1] 38 | ALIVE_SEQ = "ALIVE_SEQ" 39 | # Log probabilities of each alive sequence. Shape [batch_size, beam_size] 40 | ALIVE_LOG_PROBS = "ALIVE_LOG_PROBS" 41 | # Dictionary of cached values for each alive sequence. The cache stores 42 | # the encoder output, attention bias, and the decoder attention output from 43 | # the previous iteration. 44 | ALIVE_CACHE = "ALIVE_CACHE" 45 | 46 | # Top finished sequences for each batch item. 47 | # Has shape [batch_size, beam_size, CUR_INDEX + 1]. Sequences that are 48 | # shorter than CUR_INDEX + 1 are padded with 0s. 49 | FINISHED_SEQ = "FINISHED_SEQ" 50 | # Scores for each finished sequence. Score = log probability / length norm 51 | # Shape [batch_size, beam_size] 52 | FINISHED_SCORES = "FINISHED_SCORES" 53 | # Flags indicating which sequences in the finished sequences are finished. 54 | # At the beginning, all of the sequences in FINISHED_SEQ are filler values. 55 | # True -> finished sequence, False -> filler. 
Shape [batch_size, beam_size] 56 | FINISHED_FLAGS = "FINISHED_FLAGS" 57 | 58 | 59 | class SequenceBeamSearch(object): 60 | """Implementation of beam search loop.""" 61 | 62 | def __init__(self, symbols_to_logits_fn, vocab_size, batch_size, 63 | beam_size, alpha, max_decode_length, eos_id): 64 | self.symbols_to_logits_fn = symbols_to_logits_fn 65 | self.vocab_size = vocab_size 66 | self.batch_size = batch_size 67 | self.beam_size = beam_size 68 | self.alpha = alpha 69 | self.max_decode_length = max_decode_length 70 | self.eos_id = eos_id 71 | 72 | def search(self, initial_ids, initial_cache): 73 | """Beam search for sequences with highest scores.""" 74 | state, state_shapes = self._create_initial_state(initial_ids, initial_cache) 75 | 76 | finished_state = tf.while_loop( 77 | self._continue_search, self._search_step, loop_vars=[state], 78 | shape_invariants=[state_shapes], parallel_iterations=1, back_prop=False) 79 | finished_state = finished_state[0] 80 | 81 | alive_seq = finished_state[_StateKeys.ALIVE_SEQ] 82 | alive_log_probs = finished_state[_StateKeys.ALIVE_LOG_PROBS] 83 | finished_seq = finished_state[_StateKeys.FINISHED_SEQ] 84 | finished_scores = finished_state[_StateKeys.FINISHED_SCORES] 85 | finished_flags = finished_state[_StateKeys.FINISHED_FLAGS] 86 | 87 | # Account for corner case where there are no finished sequences for a 88 | # particular batch item. In that case, return alive sequences for that batch 89 | # item. 90 | finished_seq = tf.where( 91 | tf.reduce_any(finished_flags, 1), finished_seq, alive_seq) 92 | finished_scores = tf.where( 93 | tf.reduce_any(finished_flags, 1), finished_scores, alive_log_probs) 94 | return finished_seq, finished_scores 95 | 96 | def _create_initial_state(self, initial_ids, initial_cache): 97 | """Return initial state dictionary and its shape invariants. 98 | 99 | Args: 100 | initial_ids: initial ids to pass into the symbols_to_logits_fn. 101 | int tensor with shape [batch_size, 1] 102 | initial_cache: dictionary storing values to be passed into the 103 | symbols_to_logits_fn. 104 | 105 | Returns: 106 | state and shape invariant dictionaries with keys from _StateKeys 107 | """ 108 | # Current loop index (starts at 0) 109 | cur_index = tf.constant(0) 110 | 111 | # Create alive sequence with shape [batch_size, beam_size, 1] 112 | alive_seq = _expand_to_beam_size(initial_ids, self.beam_size) 113 | alive_seq = tf.expand_dims(alive_seq, axis=2) 114 | 115 | # Create tensor for storing initial log probabilities. 116 | # Assume initial_ids are prob 1.0 117 | initial_log_probs = tf.constant( 118 | [[0.] + [-float("inf")] * (self.beam_size - 1)]) 119 | alive_log_probs = tf.tile(initial_log_probs, [self.batch_size, 1]) 120 | 121 | # Expand all values stored in the dictionary to the beam size, so that each 122 | # beam has a separate cache. 123 | alive_cache = nest.map_structure( 124 | lambda t: _expand_to_beam_size(t, self.beam_size), initial_cache) 125 | 126 | # Initialize tensor storing finished sequences with filler values. 127 | finished_seq = tf.zeros(tf.shape(alive_seq), tf.int32) 128 | 129 | # Set scores of the initial finished seqs to negative infinity. 130 | finished_scores = tf.ones([self.batch_size, self.beam_size]) * -INF 131 | 132 | # Initialize finished flags with all False values. 
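# (e.g. with batch_size=2 and beam_size=3 this is a [2, 3] tensor of False; a flag flips
# to True once the matching FINISHED_SEQ row holds a sequence that ended with eos_id.)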
133 | finished_flags = tf.zeros([self.batch_size, self.beam_size], tf.bool) 134 | 135 | # Create state dictionary 136 | state = { 137 | _StateKeys.CUR_INDEX: cur_index, 138 | _StateKeys.ALIVE_SEQ: alive_seq, 139 | _StateKeys.ALIVE_LOG_PROBS: alive_log_probs, 140 | _StateKeys.ALIVE_CACHE: alive_cache, 141 | _StateKeys.FINISHED_SEQ: finished_seq, 142 | _StateKeys.FINISHED_SCORES: finished_scores, 143 | _StateKeys.FINISHED_FLAGS: finished_flags 144 | } 145 | 146 | # Create state invariants for each value in the state dictionary. Each 147 | # dimension must be a constant or None. A None dimension means either: 148 | # 1) the dimension's value is a tensor that remains the same but may 149 | # depend on the input sequence to the model (e.g. batch size). 150 | # 2) the dimension may have different values on different iterations. 151 | state_shape_invariants = { 152 | _StateKeys.CUR_INDEX: tf.TensorShape([]), 153 | _StateKeys.ALIVE_SEQ: tf.TensorShape([None, self.beam_size, None]), 154 | _StateKeys.ALIVE_LOG_PROBS: tf.TensorShape([None, self.beam_size]), 155 | _StateKeys.ALIVE_CACHE: nest.map_structure( 156 | _get_shape_keep_last_dim, alive_cache), 157 | _StateKeys.FINISHED_SEQ: tf.TensorShape([None, self.beam_size, None]), 158 | _StateKeys.FINISHED_SCORES: tf.TensorShape([None, self.beam_size]), 159 | _StateKeys.FINISHED_FLAGS: tf.TensorShape([None, self.beam_size]) 160 | } 161 | 162 | return state, state_shape_invariants 163 | 164 | def _continue_search(self, state): 165 | """Return whether to continue the search loop. 166 | 167 | The loops should terminate when 168 | 1) when decode length has been reached, or 169 | 2) when the worst score in the finished sequences is better than the best 170 | score in the alive sequences (i.e. the finished sequences are provably 171 | unchanging) 172 | 173 | Args: 174 | state: A dictionary with the current loop state. 175 | 176 | Returns: 177 | Bool tensor with value True if loop should continue, False if loop should 178 | terminate. 179 | """ 180 | i = state[_StateKeys.CUR_INDEX] 181 | alive_log_probs = state[_StateKeys.ALIVE_LOG_PROBS] 182 | finished_scores = state[_StateKeys.FINISHED_SCORES] 183 | finished_flags = state[_StateKeys.FINISHED_FLAGS] 184 | 185 | not_at_max_decode_length = tf.less(i, self.max_decode_length) 186 | 187 | # Calculate largest length penalty (the larger penalty, the better score). 188 | max_length_norm = _length_normalization(self.alpha, self.max_decode_length) 189 | # Get the best possible scores from alive sequences. 190 | best_alive_scores = alive_log_probs[:, 0] / max_length_norm 191 | 192 | # Compute worst score in finished sequences for each batch element 193 | finished_scores *= tf.to_float(finished_flags) # set filler scores to zero 194 | lowest_finished_scores = tf.reduce_min(finished_scores, axis=1) 195 | 196 | # If there are no finished sequences in a batch element, then set the lowest 197 | # finished score to -INF for that element. 198 | finished_batches = tf.reduce_any(finished_flags, 1) 199 | lowest_finished_scores += (1. - tf.to_float(finished_batches)) * -INF 200 | 201 | worst_finished_score_better_than_best_alive_score = tf.reduce_all( 202 | tf.greater(lowest_finished_scores, best_alive_scores) 203 | ) 204 | 205 | return tf.logical_and( 206 | not_at_max_decode_length, 207 | tf.logical_not(worst_finished_score_better_than_best_alive_score) 208 | ) 209 | 210 | def _search_step(self, state): 211 | """Beam search loop body. 212 | 213 | Grow alive sequences by a single ID. 
Sequences that have reached the EOS 214 | token are marked as finished. The alive and finished sequences with the 215 | highest log probabilities and scores are returned. 216 | 217 | A sequence's finished score is calculated by dividing the log probability 218 | by the length normalization factor. Without length normalization, the 219 | search is more likely to return shorter sequences. 220 | 221 | Args: 222 | state: A dictionary with the current loop state. 223 | 224 | Returns: 225 | new state dictionary. 226 | """ 227 | # Grow alive sequences by one token. 228 | new_seq, new_log_probs, new_cache = self._grow_alive_seq(state) 229 | # Collect top beam_size alive sequences 230 | alive_state = self._get_new_alive_state(new_seq, new_log_probs, new_cache) 231 | 232 | # Combine newly finished sequences with existing finished sequences, and 233 | # collect the top k scoring sequences. 234 | finished_state = self._get_new_finished_state(state, new_seq, new_log_probs) 235 | 236 | # Increment loop index and create new state dictionary 237 | new_state = {_StateKeys.CUR_INDEX: state[_StateKeys.CUR_INDEX] + 1} 238 | new_state.update(alive_state) 239 | new_state.update(finished_state) 240 | return [new_state] 241 | 242 | def _grow_alive_seq(self, state): 243 | """Grow alive sequences by one token, and collect top 2*beam_size sequences. 244 | 245 | 2*beam_size sequences are collected because some sequences may have reached 246 | the EOS token. 2*beam_size ensures that at least beam_size sequences are 247 | still alive. 248 | 249 | Args: 250 | state: A dictionary with the current loop state. 251 | Returns: 252 | Tuple of 253 | (Top 2*beam_size sequences [batch_size, 2 * beam_size, cur_index + 1], 254 | Scores of returned sequences [batch_size, 2 * beam_size], 255 | New alive cache, for each of the 2 * beam_size sequences) 256 | """ 257 | i = state[_StateKeys.CUR_INDEX] 258 | alive_seq = state[_StateKeys.ALIVE_SEQ] 259 | alive_log_probs = state[_StateKeys.ALIVE_LOG_PROBS] 260 | alive_cache = state[_StateKeys.ALIVE_CACHE] 261 | 262 | beams_to_keep = 2 * self.beam_size 263 | 264 | # Get logits for the next candidate IDs for the alive sequences. Get the new 265 | # cache values at the same time. 266 | flat_ids = _flatten_beam_dim(alive_seq) # [batch_size * beam_size] 267 | flat_cache = nest.map_structure(_flatten_beam_dim, alive_cache) 268 | 269 | flat_logits, flat_cache = self.symbols_to_logits_fn(flat_ids, i, flat_cache) 270 | 271 | # Unflatten logits to shape [batch_size, beam_size, vocab_size] 272 | logits = _unflatten_beam_dim(flat_logits, self.batch_size, self.beam_size) 273 | new_cache = nest.map_structure( 274 | lambda t: _unflatten_beam_dim(t, self.batch_size, self.beam_size), 275 | flat_cache) 276 | 277 | # Convert logits to normalized log probs 278 | candidate_log_probs = _log_prob_from_logits(logits) 279 | 280 | # Calculate new log probabilities if each of the alive sequences were 281 | # extended by the candidate IDs. 282 | # Shape [batch_size, beam_size, vocab_size] 283 | log_probs = candidate_log_probs + tf.expand_dims(alive_log_probs, axis=2) 284 | 285 | # Each batch item has beam_size * vocab_size candidate sequences. For each 286 | # batch item, get the k candidates with the highest log probabilities. 
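# (For instance, beam_size=4 with vocab_size=8000 gives 32000 candidates per batch
# item, from which top_k keeps beams_to_keep = 2 * beam_size = 8.)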
287 | flat_log_probs = tf.reshape(log_probs, 288 | [-1, self.beam_size * self.vocab_size]) 289 | topk_log_probs, topk_indices = tf.nn.top_k(flat_log_probs, k=beams_to_keep) 290 | 291 | # Extract the alive sequences that generate the highest log probabilities 292 | # after being extended. 293 | topk_beam_indices = topk_indices // self.vocab_size 294 | topk_seq, new_cache = _gather_beams( 295 | [alive_seq, new_cache], topk_beam_indices, self.batch_size, 296 | beams_to_keep) 297 | 298 | # Append the most probable IDs to the topk sequences 299 | topk_ids = topk_indices % self.vocab_size 300 | topk_ids = tf.expand_dims(topk_ids, axis=2) 301 | topk_seq = tf.concat([topk_seq, topk_ids], axis=2) 302 | return topk_seq, topk_log_probs, new_cache 303 | 304 | def _get_new_alive_state(self, new_seq, new_log_probs, new_cache): 305 | """Gather the top k sequences that are still alive. 306 | 307 | Args: 308 | new_seq: New sequences generated by growing the current alive sequences 309 | int32 tensor with shape [batch_size, 2 * beam_size, cur_index + 1] 310 | new_log_probs: Log probabilities of new sequences 311 | float32 tensor with shape [batch_size, beam_size] 312 | new_cache: Dict of cached values for each sequence. 313 | 314 | Returns: 315 | Dictionary with alive keys from _StateKeys: 316 | {Top beam_size sequences that are still alive (don't end with eos_id) 317 | Log probabilities of top alive sequences 318 | Dict cache storing decoder states for top alive sequences} 319 | """ 320 | # To prevent finished sequences from being considered, set log probs to -INF 321 | new_finished_flags = tf.equal(new_seq[:, :, -1], self.eos_id) 322 | new_log_probs += tf.to_float(new_finished_flags) * -INF 323 | 324 | top_alive_seq, top_alive_log_probs, top_alive_cache = _gather_topk_beams( 325 | [new_seq, new_log_probs, new_cache], new_log_probs, self.batch_size, 326 | self.beam_size) 327 | 328 | return { 329 | _StateKeys.ALIVE_SEQ: top_alive_seq, 330 | _StateKeys.ALIVE_LOG_PROBS: top_alive_log_probs, 331 | _StateKeys.ALIVE_CACHE: top_alive_cache 332 | } 333 | 334 | def _get_new_finished_state(self, state, new_seq, new_log_probs): 335 | """Combine new and old finished sequences, and gather the top k sequences. 336 | 337 | Args: 338 | state: A dictionary with the current loop state. 339 | new_seq: New sequences generated by growing the current alive sequences 340 | int32 tensor with shape [batch_size, beam_size, i + 1] 341 | new_log_probs: Log probabilities of new sequences 342 | float32 tensor with shape [batch_size, beam_size] 343 | 344 | Returns: 345 | Dictionary with finished keys from _StateKeys: 346 | {Top beam_size finished sequences based on score, 347 | Scores of finished sequences, 348 | Finished flags of finished sequences} 349 | """ 350 | i = state[_StateKeys.CUR_INDEX] 351 | finished_seq = state[_StateKeys.FINISHED_SEQ] 352 | finished_scores = state[_StateKeys.FINISHED_SCORES] 353 | finished_flags = state[_StateKeys.FINISHED_FLAGS] 354 | 355 | # First append a column of 0-ids to finished_seq to increment the length. 356 | # New shape of finished_seq: [batch_size, beam_size, i + 1] 357 | finished_seq = tf.concat( 358 | [finished_seq, 359 | tf.zeros([self.batch_size, self.beam_size, 1], tf.int32)], axis=2) 360 | 361 | # Calculate new seq scores from log probabilities. 362 | length_norm = _length_normalization(self.alpha, i + 1) 363 | new_scores = new_log_probs / length_norm 364 | 365 | # Set the scores of the still-alive seq in new_seq to large negative values. 
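# (Adding -INF for unfinished candidates guarantees they can never displace a
# genuinely finished sequence in the top-k gather below.)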
366 | new_finished_flags = tf.equal(new_seq[:, :, -1], self.eos_id) 367 | new_scores += (1. - tf.to_float(new_finished_flags)) * -INF 368 | 369 | # Combine sequences, scores, and flags. 370 | finished_seq = tf.concat([finished_seq, new_seq], axis=1) 371 | finished_scores = tf.concat([finished_scores, new_scores], axis=1) 372 | finished_flags = tf.concat([finished_flags, new_finished_flags], axis=1) 373 | 374 | # Return the finished sequences with the best scores. 375 | top_finished_seq, top_finished_scores, top_finished_flags = ( 376 | _gather_topk_beams([finished_seq, finished_scores, finished_flags], 377 | finished_scores, self.batch_size, self.beam_size)) 378 | 379 | return { 380 | _StateKeys.FINISHED_SEQ: top_finished_seq, 381 | _StateKeys.FINISHED_SCORES: top_finished_scores, 382 | _StateKeys.FINISHED_FLAGS: top_finished_flags 383 | } 384 | 385 | 386 | def sequence_beam_search( 387 | symbols_to_logits_fn, initial_ids, initial_cache, vocab_size, beam_size, 388 | alpha, max_decode_length, eos_id): 389 | """Search for sequence of subtoken ids with the largest probability. 390 | 391 | Args: 392 | symbols_to_logits_fn: A function that takes in ids, index, and cache as 393 | arguments. The passed in arguments will have shape: 394 | ids -> [batch_size * beam_size, index] 395 | index -> [] (scalar) 396 | cache -> nested dictionary of tensors [batch_size * beam_size, ...] 397 | The function must return logits and new cache. 398 | logits -> [batch * beam_size, vocab_size] 399 | new cache -> same shape/structure as inputted cache 400 | initial_ids: Starting ids for each batch item. 401 | int32 tensor with shape [batch_size] 402 | initial_cache: dict containing starting decoder variables information 403 | vocab_size: int size of tokens 404 | beam_size: int number of beams 405 | alpha: float defining the strength of length normalization 406 | max_decode_length: maximum length to decoded sequence 407 | eos_id: int id of eos token, used to determine when a sequence has finished 408 | 409 | Returns: 410 | Top decoded sequences [batch_size, beam_size, max_decode_length] 411 | sequence scores [batch_size, beam_size] 412 | """ 413 | batch_size = tf.shape(initial_ids)[0] 414 | sbs = SequenceBeamSearch(symbols_to_logits_fn, vocab_size, batch_size, 415 | beam_size, alpha, max_decode_length, eos_id) 416 | return sbs.search(initial_ids, initial_cache) 417 | 418 | 419 | def _log_prob_from_logits(logits): 420 | return logits - tf.reduce_logsumexp(logits, axis=2, keep_dims=True) 421 | 422 | 423 | def _length_normalization(alpha, length): 424 | """Return length normalization factor.""" 425 | return tf.pow(((5. + tf.to_float(length)) / 6.), alpha) 426 | 427 | 428 | def _expand_to_beam_size(tensor, beam_size): 429 | """Tiles a given tensor by beam_size. 430 | 431 | Args: 432 | tensor: tensor to tile [batch_size, ...] 433 | beam_size: How much to tile the tensor by. 434 | 435 | Returns: 436 | Tiled tensor [batch_size, beam_size, ...] 
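For example, a [2, 5] tensor with beam_size=3 becomes a [2, 3, 5] tensor.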
437 | """ 438 | tensor = tf.expand_dims(tensor, axis=1) 439 | tile_dims = [1] * tensor.shape.ndims 440 | tile_dims[1] = beam_size 441 | 442 | return tf.tile(tensor, tile_dims) 443 | 444 | 445 | def _shape_list(tensor): 446 | """Return a list of the tensor's shape, and ensure no None values in list.""" 447 | # Get statically known shape (may contain None's for unknown dimensions) 448 | shape = tensor.get_shape().as_list() 449 | 450 | # Ensure that the shape values are not None 451 | dynamic_shape = tf.shape(tensor) 452 | for i in range(len(shape)): # pylint: disable=consider-using-enumerate 453 | if shape[i] is None: 454 | shape[i] = dynamic_shape[i] 455 | return shape 456 | 457 | 458 | def _get_shape_keep_last_dim(tensor): 459 | shape_list = _shape_list(tensor) 460 | 461 | # Only the last 462 | for i in range(len(shape_list) - 1): 463 | shape_list[i] = None 464 | 465 | if isinstance(shape_list[-1], tf.Tensor): 466 | shape_list[-1] = None 467 | return tf.TensorShape(shape_list) 468 | 469 | 470 | def _flatten_beam_dim(tensor): 471 | """Reshapes first two dimensions in to single dimension. 472 | 473 | Args: 474 | tensor: Tensor to reshape of shape [A, B, ...] 475 | 476 | Returns: 477 | Reshaped tensor of shape [A*B, ...] 478 | """ 479 | shape = _shape_list(tensor) 480 | shape[0] *= shape[1] 481 | shape.pop(1) # Remove beam dim 482 | return tf.reshape(tensor, shape) 483 | 484 | 485 | def _unflatten_beam_dim(tensor, batch_size, beam_size): 486 | """Reshapes first dimension back to [batch_size, beam_size]. 487 | 488 | Args: 489 | tensor: Tensor to reshape of shape [batch_size*beam_size, ...] 490 | batch_size: Tensor, original batch size. 491 | beam_size: int, original beam size. 492 | 493 | Returns: 494 | Reshaped tensor of shape [batch_size, beam_size, ...] 495 | """ 496 | shape = _shape_list(tensor) 497 | new_shape = [batch_size, beam_size] + shape[1:] 498 | return tf.reshape(tensor, new_shape) 499 | 500 | 501 | def _gather_beams(nested, beam_indices, batch_size, new_beam_size): 502 | """Gather beams from nested structure of tensors. 503 | 504 | Each tensor in nested represents a batch of beams, where beam refers to a 505 | single search state (beam search involves searching through multiple states 506 | in parallel). 507 | 508 | This function is used to gather the top beams, specified by 509 | beam_indices, from the nested tensors. 510 | 511 | Args: 512 | nested: Nested structure (tensor, list, tuple or dict) containing tensors 513 | with shape [batch_size, beam_size, ...]. 514 | beam_indices: int32 tensor with shape [batch_size, new_beam_size]. Each 515 | value in beam_indices must be between [0, beam_size), and are not 516 | necessarily unique. 517 | batch_size: int size of batch 518 | new_beam_size: int number of beams to be pulled from the nested tensors. 519 | 520 | Returns: 521 | Nested structure containing tensors with shape 522 | [batch_size, new_beam_size, ...] 523 | """ 524 | # Computes the i'th coodinate that contains the batch index for gather_nd. 525 | # Batch pos is a tensor like [[0,0,0,0,],[1,1,1,1],..]. 526 | batch_pos = tf.range(batch_size * new_beam_size) // new_beam_size 527 | batch_pos = tf.reshape(batch_pos, [batch_size, new_beam_size]) 528 | 529 | # Create coordinates to be passed to tf.gather_nd. Stacking creates a tensor 530 | # with shape [batch_size, beam_size, 2], where the last dimension contains 531 | # the (i, j) gathering coordinates. 
532 | coordinates = tf.stack([batch_pos, beam_indices], axis=2) 533 | 534 | return nest.map_structure( 535 | lambda state: tf.gather_nd(state, coordinates), nested) 536 | 537 | 538 | def _gather_topk_beams(nested, score_or_log_prob, batch_size, beam_size): 539 | """Gather top beams from nested structure.""" 540 | _, topk_indexes = tf.nn.top_k(score_or_log_prob, k=beam_size) 541 | return _gather_beams(nested, topk_indexes, batch_size, beam_size) 542 | --------------------------------------------------------------------------------
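To make the contract of `sequence_beam_search` concrete, here is a minimal sketch of a caller (not part of the repo). The uniform-logits `symbols_to_logits_fn`, the toy vocabulary size, and all constants are illustrative assumptions; a real caller would run the Transformer decoder step here and thread the encoder output through `cache`:

#!/usr/bin/env python3
# Illustrative sketch only: drives sequence_beam_search with a toy scorer.
import tensorflow as tf

from model.lib.beam_search import sequence_beam_search

VOCAB_SIZE = 16  # assumed toy vocabulary

def symbols_to_logits_fn(ids, i, cache):
    # ids: [batch_size * beam_size, i + 1] tokens decoded so far; i: step index.
    # A real implementation would feed `ids` and the cached decoder state
    # through the model; this stand-in scores every token equally.
    logits = tf.zeros([tf.shape(ids)[0], VOCAB_SIZE])
    return logits, cache

initial_ids = tf.zeros([2], dtype=tf.int32)  # batch of 2, start token id 0

decoded_ids, scores = sequence_beam_search(
    symbols_to_logits_fn=symbols_to_logits_fn,
    initial_ids=initial_ids,
    initial_cache={},
    vocab_size=VOCAB_SIZE,
    beam_size=4,
    alpha=0.6,  # strength of length normalization
    max_decode_length=8,
    eos_id=1)

with tf.Session() as sess:
    ids, s = sess.run([decoded_ids, scores])
    print(ids.shape, s.shape)  # e.g. (2, 4, k) with k <= 9, and (2, 4)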