├── __init__.py
├── .gitignore
├── data
│   └── LCSTS
│       └── preprocess.py
├── config
│   ├── configurable.py
│   └── config.json
├── loss.py
├── README.md
├── model.py
├── encoder.py
├── helper.py
├── main.py
├── train_op.py
├── run.py
├── metrics.py
└── decoder.py

--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
data_transfer.py
.DS_Store
*.pyc
__pycache__

--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np

def cross_entropy_sequence_loss(logits, labels):
    # per-token cross entropy, shape (batch_size, max_len)
    cross_entropy = tf.losses.sparse_softmax_cross_entropy(logits=logits, labels=labels, reduction=tf.losses.Reduction.NONE)
    # average over time steps, shape (batch_size,)
    loss = tf.reduce_mean(cross_entropy, axis=1)

    return loss

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Deep Recurrent Generative Decoder for Abstractive Text Summarization

#### Requirements
TensorFlow 1.12

#### Preprocess data

```bash
python3 main.py -p
```

#### Train model

```bash
python3 main.py -t
```

#### Data

You can download the LCSTS dataset from [LCSTS](https://pan.baidu.com/s/1eZyNC7Ult2QvMlBWbjGE8Q); the access code is: 1ny8

--------------------------------------------------------------------------------
/config/configurable.py:
--------------------------------------------------------------------------------
import json

class Configurable(object):
    def __init__(self, section):
        self._config_filename = '/home/stevenwd/DRGD/config/config.json'
        # self._config_filename = '/Users/stevenwd/DRGD/config/config.json'
        self._config = json.load(open(self._config_filename, 'r'))
        self._section = section

    @property
    def config(self):
        return self._config[self._section]

    def get_config(self, section=None, key=None):
        if section is None:
            section = self._section
        return self._config[section][key]

    def update_config(self, key, value):
        self._config[self._section][key] = value

    def save_config(self):
        json.dump(self._config, open(self._config_filename, 'w'), indent=4, ensure_ascii=False)

    @property
    def base_dir(self):
        return self._config['base_dir']
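Every module below subclasses `Configurable` to read one section of `config/config.json`. A minimal sketch of the intended access pattern (section and key names are taken from the `config.json` listed further down; the value written back is illustrative, and the hard-coded `_config_filename` above means this only works if the repo sits at that path):

```python
from config.configurable import Configurable

conf = Configurable('train')                  # bind to the "train" section
batch_size = conf.config['batch_size']        # read from the bound section
emd_dim = conf.get_config('data', 'emd_dim')  # read from another section
conf.update_config('epoch_num', 20)           # mutate in memory ...
conf.save_config()                            # ... and persist to config.json
```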
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from pydoc import locate
from config.configurable import Configurable
from encoder import Encoder
from decoder import Decoder
from train_op import build_train_op

class DRGD(Configurable):
    def __init__(self, mode=tf.contrib.learn.ModeKeys.TRAIN, name='DRGD'):
        super(DRGD, self).__init__('model')
        self.mode = mode
        self.name = name

    def build(self):
        with tf.name_scope('placeholders') as scope:
            self.source_placeholder = tf.placeholder(dtype=tf.int32, shape=(self.get_config('train', 'batch_size'), self.get_config('data', 'source_max_seq_length')), name='source_placeholder')
            self.target_input_placeholder = tf.placeholder(dtype=tf.int32, shape=(self.get_config('train', 'batch_size'), self.get_config('data', 'target_max_seq_length')), name='target_input_placeholder')
            self.target_output_placeholder = tf.placeholder(dtype=tf.int32, shape=(self.get_config('train', 'batch_size'), self.get_config('data', 'target_max_seq_length')), name='target_output_placeholder')
            self.source_length = tf.placeholder(dtype=tf.int32, shape=(self.get_config('train', 'batch_size'),))
            self.target_length = tf.placeholder(dtype=tf.int32, shape=(self.get_config('train', 'batch_size'),))

        with tf.name_scope('encoder') as scope:
            self.encoder = Encoder()

        with tf.name_scope('decoder') as scope:
            self.decoder = Decoder()

--------------------------------------------------------------------------------
/encoder.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from config.configurable import Configurable
from collections import namedtuple
from pydoc import locate

EncoderOutput = namedtuple(
    "EncoderOutput",
    "outputs final_state attention_values attention_values_length")

class Encoder(Configurable):
    def __init__(self):
        super(Encoder, self).__init__('encoder')
        self.cell_fw = self.build_cell(self.config['cell_classname'], 'cell_fw')
        self.cell_bw = self.build_cell(self.config['cell_classname'], 'cell_bw')

    def encode(self, inputs, length):
        # inputs shape: (batch_size, max_len, emd_dim)
        # length shape: (batch_size,)
        (output_fw, output_bw), (output_state_fw, output_state_bw) = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=self.cell_fw,
            cell_bw=self.cell_bw,
            inputs=inputs,
            sequence_length=length,
            scope='encoder',
            dtype=tf.float32)

        # concatenate forward and backward directions: (batch_size, max_len, 2 * num_units)
        outputs = tf.concat([output_fw, output_bw], axis=2)
        final_state = tf.concat([output_state_fw, output_state_bw], axis=1)

        return EncoderOutput(outputs=outputs,
                             final_state=final_state,
                             attention_values=outputs,
                             attention_values_length=length)

    def build_cell(self, cell_classname, cell_name):
        cell_class = locate(cell_classname)
        return cell_class(num_units=self.config['cell']['num_units'],
                          name=cell_name,
                          **self.config['cell']['cell_params'])
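The encoder is bidirectional, so everything downstream sees hidden vectors of size 2 * num_units — with the config below (num_units 128 for the encoder, 256 for the decoder) the concatenated encoder outputs line up exactly with the decoder's state size. A runnable shape check with toy sizes (illustrative values, not the real config):

```python
import numpy as np
import tensorflow as tf

batch_size, max_len, emd_dim, num_units = 4, 7, 10, 8
inputs = tf.constant(np.random.randn(batch_size, max_len, emd_dim), dtype=tf.float32)
length = tf.constant([7, 5, 3, 2], dtype=tf.int32)

(out_fw, out_bw), (st_fw, st_bw) = tf.nn.bidirectional_dynamic_rnn(
    cell_fw=tf.nn.rnn_cell.GRUCell(num_units),
    cell_bw=tf.nn.rnn_cell.GRUCell(num_units),
    inputs=inputs, sequence_length=length, dtype=tf.float32)

outputs = tf.concat([out_fw, out_bw], axis=2)    # (4, 7, 16) == 2 * num_units
final_state = tf.concat([st_fw, st_bw], axis=1)  # (4, 16)
```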
--------------------------------------------------------------------------------
/helper.py:
--------------------------------------------------------------------------------
import numpy as np
from config.configurable import Configurable
import math
import json

class Helper(Configurable):
    def __init__(self, data_filepath, length_filepath, mode):
        # mode: source/target_input/target_output
        super(Helper, self).__init__('data')
        self.data = np.load(data_filepath)
        self.length = np.load(length_filepath)
        self.mode = mode
        self.maximum_step = math.ceil(self.data.shape[0]/self.get_config('train', 'batch_size'))
        self.cursor = 0

    def next_batch(self):
        # generate a mini-batch of data
        # if mode is 'target_input', we prepend the start token (<GO>) to each target sequence

        batch_size = self.get_config('train', 'batch_size')
        batch = self.data[self.cursor: min(self.data.shape[0], batch_size+self.cursor)]
        length = self.length[self.cursor: min(self.length.shape[0], batch_size+self.cursor)]
        self.cursor += batch_size
        if batch.shape[0] < batch_size:
            # wrap around: top up the last batch with examples from the start of the data
            supplement = self.data[: batch_size-batch.shape[0]]
            supplement_length = self.length[: batch_size-batch.shape[0]]
            batch = np.concatenate([batch, supplement], axis=0)
            length = np.concatenate([length, supplement_length], axis=0)
            self.cursor -= self.data.shape[0]

        if self.mode == 'target_input':
            # overwrite the last position with start_id, then rotate it to the front
            batch[:, -1] = self.config['start_id']
            batch = np.roll(batch, shift=1, axis=1)

        return batch, length

    def reset_cursor(self):
        self.cursor = 0

--------------------------------------------------------------------------------
/config/config.json:
--------------------------------------------------------------------------------
{
    "data": {
        "train_file": "data/LCSTS/PART_I.xml",
        "dev_file": "data/LCSTS/PART_III.xml",
        "test_file": "data/LCSTS/PART_III.xml",
        "emd_dim": 200,
        "source_max_seq_length": 120,
        "target_max_seq_length": 30,
        "source_word_num": 10607,
        "target_word_num": 5437,
        "embedding_filepath": "data/LCSTS/emd_weight.npy",
        "source_embedding_filepath": "data/LCSTS/source_emd_weight.npy",
        "target_embedding_filepath": "data/LCSTS/target_emd_weight.npy",
        "start_id": 1,
        "end_id": 2
    },
    "encoder": {
        "cell_classname": "tensorflow.nn.rnn_cell.GRUCell",
        "encoder_params": {},
        "cell": {
            "num_units": 128,
            "cell_params": {}
        }
    },
    "decoder": {
        "cell_classname": "tensorflow.nn.rnn_cell.GRUCell",
        "cell": {
            "num_units": 256,
            "cell_params": {}
        },
        "variable_size": 256,
        "beam_search_width": 10,
        "length_penalty_weight": 1
    },
    "model": {},
    "base_dir": "/home/stevenwd/DRGD/",
    "train_op": {
        "name": "Adam",
        "learning_rate": 0.001,
        "params": {},
        "lr_decay": {
            "decay_type": "",
            "decay_steps": 100,
            "decay_rate": 0.999,
            "start_decay_at": 1000,
            "stop_decay_at": 2147483647,
            "min_learning_rate": 1e-12,
            "staircase": false
        },
        "clip_gradients": 5.0,
        "sync_replicas": 0,
        "sync_replicas_to_aggregate": 0
    },
    "train": {
        "batch_size": 928,
        "epoch_num": 10,
        "metric_path": "metrics.rouge"
    }
}
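The `target_input` shift in `Helper.next_batch` is easiest to see on concrete numbers — writing `start_id` into the last column and rolling right by one yields a decoder input that trails the decoder output by one step (note that a sequence filling the full max length loses its final position to this trick):

```python
import numpy as np

start_id = 1                             # from config.json
batch = np.array([[11, 12, 13, 2, 0]])   # "w1 w2 w3 <EOS> <PAD>"
batch[:, -1] = start_id                  # [[11, 12, 13, 2, 1]]
batch = np.roll(batch, shift=1, axis=1)  # [[ 1, 11, 12, 13, 2]]
# decoder input  "<GO> w1 w2 w3 <EOS>" now sits one step behind
# decoder output "w1 w2 w3 <EOS> <PAD>"
```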
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse
from data.LCSTS.preprocess import preprocess
import json
from config.configurable import Configurable
from train_op import build_train_op
from helper import Helper
from run import train

config = Configurable('data')

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-p', '--preprocess', help="preprocess train/dev/test data", action='store_true')
    parser.add_argument('-t', '--train', help="train DRGD model", action='store_true')
    # parser.add_argument('--test', help="test trained DRGD model", action='store_true')
    # parser.add_argument('--inference', help="inference with the trained model", action='store_true')
    result = parser.parse_args()

    if result.preprocess:
        preprocess(fname=config.base_dir+config.config['train_file'], mode='train', config=config)
        preprocess(fname=config.base_dir+config.config['dev_file'], mode='dev', config=config)

    if result.train:
        source_helper = Helper(data_filepath='./data/LCSTS/train_source.npy', length_filepath='./data/LCSTS/train_source_length.npy', mode='source')
        target_input_helper = Helper(data_filepath='./data/LCSTS/train_target.npy', length_filepath='./data/LCSTS/train_target_length.npy', mode='target_input')
        target_output_helper = Helper(data_filepath='./data/LCSTS/train_target.npy', length_filepath='./data/LCSTS/train_target_length.npy', mode='target_output')
        valid_source_helper = Helper(data_filepath='./data/LCSTS/dev_source.npy', length_filepath='./data/LCSTS/dev_source_length.npy', mode='source')
        valid_target_output_helper = Helper(data_filepath='./data/LCSTS/dev_target.npy', length_filepath='./data/LCSTS/dev_target_length.npy', mode='target_output')
        char_dict = json.load(open('./data/LCSTS/target_char_dict.json', 'r'))
        train(source_helper, target_input_helper, target_output_helper, valid_source_helper, valid_target_output_helper, char_dict)
--------------------------------------------------------------------------------
/train_op.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from pydoc import locate
from config.configurable import Configurable

config = Configurable('train_op')

def create_learning_rate_decay_fn(decay_type, decay_steps, decay_rate, start_decay_at=0, stop_decay_at=1e9, min_learning_rate=None, staircase=False):
    if decay_type is None or decay_type == '':
        return None

    start_decay_at = tf.to_int32(start_decay_at)
    stop_decay_at = tf.to_int32(stop_decay_at)

    def decay_fn(learning_rate, global_step):
        global_step = tf.to_int32(global_step)

        # look up the decay schedule (e.g. tf.train.exponential_decay) by name
        decay_type_fn = getattr(tf.train, decay_type)
        decayed_learning_rate = decay_type_fn(
            learning_rate=learning_rate,
            global_step=tf.minimum(global_step, stop_decay_at) - start_decay_at,
            decay_steps=decay_steps,
            decay_rate=decay_rate,
            staircase=staircase,
            name='decayed_learning_rate')

        # keep the initial rate until start_decay_at, then switch to the decayed one
        final_lr = tf.train.piecewise_constant(
            x=global_step,
            boundaries=[start_decay_at],
            values=[learning_rate, decayed_learning_rate])

        if min_learning_rate:
            final_lr = tf.maximum(final_lr, min_learning_rate)
        return final_lr

    return decay_fn

def build_train_op(loss):
    learning_rate_decay_fn = create_learning_rate_decay_fn(
        **config.config['lr_decay'])

    optimizer = _create_optimizer()
    train_op = tf.contrib.layers.optimize_loss(
        loss=loss,
        global_step=tf.train.get_global_step(),
        learning_rate=config.config['learning_rate'],
        learning_rate_decay_fn=learning_rate_decay_fn,
        clip_gradients=_clip_gradients,
        optimizer=optimizer,
        summaries=['learning_rate', 'loss', 'gradients', 'gradient_norm'])

    return train_op

def _clip_gradients(grads_and_vars):
    gradients, variables = zip(*grads_and_vars)
    clip_gradients, _ = tf.clip_by_global_norm(
        gradients, config.config['clip_gradients'])
    return list(zip(clip_gradients, variables))

def _create_optimizer():
    name = config.config['name']
    optimizer = tf.contrib.layers.OPTIMIZER_CLS_NAMES[name](
        learning_rate=config.config['learning_rate'],
        **config.config['params'])

    if config.config['sync_replicas'] > 0:
        optimizer = tf.train.SyncReplicasOptimizer(
            opt=optimizer,
            replicas_to_aggregate=config.config['sync_replicas_to_aggregate'],
            total_num_replicas=config.config['sync_replicas'])

    return optimizer
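With the shipped config, `decay_type` is the empty string, so `create_learning_rate_decay_fn` returns None and the learning rate stays constant. If you switch `decay_type` to, say, `exponential_decay`, the resulting schedule behaves as in this pure-Python sketch (staircase=False, `stop_decay_at` ignored for brevity):

```python
def effective_lr(step, lr=0.001, decay_steps=100, decay_rate=0.999,
                 start_decay_at=1000, min_learning_rate=1e-12):
    if step <= start_decay_at:
        return lr  # piecewise_constant keeps the initial rate up to the boundary
    decayed = lr * decay_rate ** ((step - start_decay_at) / decay_steps)
    return max(decayed, min_learning_rate)

print(effective_lr(999))   # 0.001
print(effective_lr(2000))  # 0.001 * 0.999 ** 10 ~= 0.00099
```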
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
from model import DRGD
from tqdm import tqdm
from train_op import build_train_op
from loss import cross_entropy_sequence_loss
import os
from metrics import compute_metric_score
import json
import tensorflow.contrib.seq2seq as seq2seq

tf.logging.set_verbosity(tf.logging.INFO)

def train(source_helper, target_input_helper, target_output_helper, valid_source_helper, valid_target_output_helper, char_dict):
    model = DRGD()
    model.build()

    source_emd_weight = np.load(model.base_dir+model.get_config('data', 'source_embedding_filepath'))
    target_emd_weight = np.load(model.base_dir+model.get_config('data', 'target_embedding_filepath'))

    source_embedding = tf.get_variable(initializer=source_emd_weight, name='source_embedding')
    target_embedding = tf.get_variable(initializer=target_emd_weight, name='target_embedding')

    source_sequence = tf.nn.embedding_lookup(source_embedding, model.source_placeholder)
    target_sequence = tf.nn.embedding_lookup(target_embedding, model.target_input_placeholder)

    encoder_output = model.encoder.encode(source_sequence, model.source_length)
    logits, KL = model.decoder.decode(target_sequence, model.target_length, encoder_output)

    y_s = model.decoder.beam_search(encoder_output, target_embedding)

    # masked cross-entropy plus the (masked) KL term from the variational decoder
    target_mask = tf.sequence_mask(model.target_length, model.get_config('data', 'target_max_seq_length'), dtype=tf.float32)
    crossent = seq2seq.sequence_loss(logits, model.target_output_placeholder, target_mask, average_across_batch=True, average_across_timesteps=True)
    kl = tf.reduce_mean(tf.multiply(KL, target_mask), [0, 1])

    loss = tf.add(crossent, kl)

    opt = tf.train.AdamOptimizer()
    train_op = opt.minimize(loss)

    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    conf = tf.ConfigProto()
    conf.gpu_options.allow_growth = True

    with tf.Session(config=conf) as sess:
        sess.run(tf.global_variables_initializer())

        for epoch in range(model.get_config('train', 'epoch_num')):
            for step in range(source_helper.maximum_step):
                source_input, source_length = source_helper.next_batch()
                target_input, target_length = target_input_helper.next_batch()
                target_output, _ = target_output_helper.next_batch()

                _, l = sess.run([train_op, loss], feed_dict={
                    model.source_placeholder: source_input,
                    model.target_input_placeholder: target_input,
                    model.target_output_placeholder: target_output,
                    model.source_length: source_length,
                    model.target_length: target_length})

                tf.logging.log_every_n(level=tf.logging.INFO, msg='epoch {}/{}, step {}/{}, loss: {}'.format(epoch, model.get_config('train', 'epoch_num'), step, source_helper.maximum_step, l), n=10)

            # run beam-search decoding over the validation set after every epoch
            final_output_all = list()
            target_output_all = list()
            for v_step in tqdm(range(valid_source_helper.maximum_step)):
                source_input, source_length = valid_source_helper.next_batch()
                target_output, _ = valid_target_output_helper.next_batch()

                final_output = sess.run(y_s, feed_dict={
                    model.source_placeholder: source_input,
                    model.source_length: source_length})

                final_output_all.append(final_output)
                target_output_all.append(target_output)

            final_output = np.concatenate(final_output_all, axis=0)
            target_output = np.concatenate(target_output_all, axis=0)
            np.save('./final_output.npy', final_output)
            np.save('./target_output.npy', target_output)

            score = compute_metric_score(model.get_config('train', 'metric_path'), final_output, target_output, char_dict)
            tf.logging.info('ROUGE metric score:\n{}'.format(score))
            valid_source_helper.reset_cursor()
            valid_target_output_helper.reset_cursor()
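One subtlety in the loss above: `kl = tf.reduce_mean(tf.multiply(KL, target_mask), [0, 1])` zeroes the KL term at padded positions but still divides by batch_size * max_len, so the KL is averaged over all positions rather than over real tokens only — a numpy illustration:

```python
import numpy as np

KL = np.full((2, 3), 0.5)        # per-position KL, shape (batch, time)
mask = np.array([[1., 1., 1.],
                 [1., 0., 0.]])  # second sequence has length 1

print((KL * mask).mean())              # 0.333... -- divides by all 6 positions
print((KL * mask).sum() / mask.sum())  # 0.5      -- per-real-token average
```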
--------------------------------------------------------------------------------
/data/LCSTS/preprocess.py:
--------------------------------------------------------------------------------
import xmltodict
import numpy as np
import json
from tqdm import tqdm
from keras.preprocessing.sequence import pad_sequences
from copy import deepcopy

def build_dict(content_list, base_dir):
    print('----build character dictionary')
    # special tokens: <PAD>=0, <GO>=1 (start_id in config.json), <EOS>=2 (end_id)
    char_dict = dict()
    char_dict['<PAD>'] = len(char_dict)
    char_dict['<GO>'] = len(char_dict)
    char_dict['<EOS>'] = len(char_dict)

    source_char_dict = deepcopy(char_dict)
    target_char_dict = deepcopy(char_dict)

    char_num = dict()
    for content in content_list:
        chars = [c for c in content]
        for c in chars:
            char_num[c] = char_num.get(c, 0) + 1

    for content in content_list:
        chars = [c for c in content]
        for c in chars:
            if c not in source_char_dict:
                source_char_dict[c] = len(source_char_dict)
                # only frequent characters enter the (smaller) target vocabulary
                if char_num[c] > 100:
                    target_char_dict[c] = len(target_char_dict)

    json.dump(source_char_dict, open(base_dir+'data/LCSTS/source_char_dict.json', 'w'), indent=4, ensure_ascii=False)
    json.dump(target_char_dict, open(base_dir+'data/LCSTS/target_char_dict.json', 'w'), indent=4, ensure_ascii=False)
    return source_char_dict, target_char_dict

def build_emd(char_dict, emd_dim, base_dir, mode):
    print('----build embedding matrix')
    # row 0 (<PAD>) is all zeros; the remaining rows are random normal
    emd_weight = np.zeros(shape=(1, emd_dim)).astype(np.float32)
    emd_weight = np.concatenate((emd_weight, np.random.randn(len(char_dict)-1, emd_dim).astype(np.float32)), axis=0)
    np.save(base_dir+'data/LCSTS/'+mode+'_emd_weight.npy', emd_weight)
    embedding_filepath = 'data/LCSTS/'+mode+'_emd_weight.npy'
    return embedding_filepath

def preprocess(fname, mode, config):
    print('Process file {} in mode {}'.format(fname, mode))
    f = open(fname, 'r')
    data = xmltodict.parse(f.read())
    summary_list = list()
    text_list = list()
    for item in data['data']['doc']:
        # for dev/test, keep only pairs whose human relevance label is >= 3
        if mode != 'train' and int(item['human_label']) < 3:
            continue
        summary_list.append(item['summary'])
        text_list.append(item['short_text'])

    if mode == 'train':
        source_char_dict, target_char_dict = build_dict(text_list, config.base_dir)
        config.update_config('source_word_num', len(source_char_dict))
        config.update_config('target_word_num', len(target_char_dict))

        embedding_filepath = build_emd(source_char_dict, config.config['emd_dim'], config.base_dir, 'source')
        config.update_config('source_embedding_filepath', embedding_filepath)
        embedding_filepath = build_emd(target_char_dict, config.config['emd_dim'], config.base_dir, 'target')
        config.update_config('target_embedding_filepath', embedding_filepath)
        config.save_config()
    else:
        source_char_dict = json.load(open(config.base_dir+'data/LCSTS/source_char_dict.json', 'r'))
        target_char_dict = json.load(open(config.base_dir+'data/LCSTS/target_char_dict.json', 'r'))


    print('----converting data')
    source = list()
    source_length = list()
    target = list()
    target_length = list()
    for text in text_list:
        feature = list()
        for c in [tc for tc in text]:
            if c in source_char_dict:
                feature.append(source_char_dict[c])

        feature.append(source_char_dict['<EOS>'])
        source_length.append(min(len(feature), config.config['source_max_seq_length']))
        source.append(feature)

    source = pad_sequences(source,
                           maxlen=config.config['source_max_seq_length'],
                           dtype='int32',
                           padding='post',
                           truncating='post',
                           value=source_char_dict['<PAD>'])

    for summary in summary_list:
        feature = list()
        for c in [tc for tc in summary]:
            if c in target_char_dict:
                feature.append(target_char_dict[c])

        feature.append(target_char_dict['<EOS>'])
        target_length.append(min(len(feature), config.config['target_max_seq_length']))
        target.append(feature)

    target = pad_sequences(target,
                           maxlen=config.config['target_max_seq_length'],
                           dtype='int32',
                           padding='post',
                           truncating='post',
                           value=target_char_dict['<PAD>'])

    np.save(config.base_dir+'data/LCSTS/'+mode+'_source.npy', source)
    np.save(config.base_dir+'data/LCSTS/'+mode+'_target.npy', target)
    np.save(config.base_dir+'data/LCSTS/'+mode+'_source_length.npy', source_length)
    np.save(config.base_dir+'data/LCSTS/'+mode+'_target_length.npy', target_length)
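A toy run of the conversion step, with a deliberately tiny dictionary (hypothetical entries, same special-token layout as `build_dict` above):

```python
from keras.preprocessing.sequence import pad_sequences

char_dict = {'<PAD>': 0, '<GO>': 1, '<EOS>': 2, '中': 3, '国': 4}
text = '中国中'
feature = [char_dict[c] for c in text if c in char_dict] + [char_dict['<EOS>']]
# feature == [3, 4, 3, 2]
padded = pad_sequences([feature], maxlen=6, dtype='int32',
                       padding='post', truncating='post',
                       value=char_dict['<PAD>'])
# padded == [[3, 4, 3, 2, 0, 0]]
```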
19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | from __future__ import unicode_literals 25 | 26 | import itertools 27 | import numpy as np 28 | from pydoc import locate 29 | import json 30 | 31 | #pylint: disable=C0103 32 | 33 | 34 | def _get_ngrams(n, text): 35 | """Calcualtes n-grams. 36 | 37 | Args: 38 | n: which n-grams to calculate 39 | text: An array of tokens 40 | 41 | Returns: 42 | A set of n-grams 43 | """ 44 | ngram_set = set() 45 | text_length = len(text) 46 | max_index_ngram_start = text_length - n 47 | for i in range(max_index_ngram_start + 1): 48 | ngram_set.add(tuple(text[i:i + n])) 49 | return ngram_set 50 | 51 | 52 | def _split_into_words(sentences): 53 | """Splits multiple sentences into words and flattens the result""" 54 | return list(itertools.chain(*[_.split(" ") for _ in sentences])) 55 | 56 | 57 | def _get_word_ngrams(n, sentences): 58 | """Calculates word n-grams for multiple sentences. 59 | """ 60 | assert len(sentences) > 0 61 | assert n > 0 62 | 63 | words = _split_into_words(sentences) 64 | return _get_ngrams(n, words) 65 | 66 | 67 | def _len_lcs(x, y): 68 | """ 69 | Returns the length of the Longest Common Subsequence between sequences x 70 | and y. 71 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 72 | 73 | Args: 74 | x: sequence of words 75 | y: sequence of words 76 | 77 | Returns 78 | integer: Length of LCS between x and y 79 | """ 80 | table = _lcs(x, y) 81 | n, m = len(x), len(y) 82 | return table[n, m] 83 | 84 | 85 | def _lcs(x, y): 86 | """ 87 | Computes the length of the longest common subsequence (lcs) between two 88 | strings. The implementation below uses a DP programming algorithm and runs 89 | in O(nm) time where n = len(x) and m = len(y). 90 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 91 | 92 | Args: 93 | x: collection of words 94 | y: collection of words 95 | 96 | Returns: 97 | Table of dictionary of coord and len lcs 98 | """ 99 | n, m = len(x), len(y) 100 | table = dict() 101 | for i in range(n + 1): 102 | for j in range(m + 1): 103 | if i == 0 or j == 0: 104 | table[i, j] = 0 105 | elif x[i - 1] == y[j - 1]: 106 | table[i, j] = table[i - 1, j - 1] + 1 107 | else: 108 | table[i, j] = max(table[i - 1, j], table[i, j - 1]) 109 | return table 110 | 111 | 112 | def _recon_lcs(x, y): 113 | """ 114 | Returns the Longest Subsequence between x and y. 115 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 116 | 117 | Args: 118 | x: sequence of words 119 | y: sequence of words 120 | 121 | Returns: 122 | sequence: LCS of x and y 123 | """ 124 | i, j = len(x), len(y) 125 | table = _lcs(x, y) 126 | 127 | def _recon(i, j): 128 | """private recon calculation""" 129 | if i == 0 or j == 0: 130 | return [] 131 | elif x[i - 1] == y[j - 1]: 132 | return _recon(i - 1, j - 1) + [(x[i - 1], i)] 133 | elif table[i - 1, j] > table[i, j - 1]: 134 | return _recon(i - 1, j) 135 | else: 136 | return _recon(i, j - 1) 137 | 138 | recon_tuple = tuple(map(lambda x: x[0], _recon(i, j))) 139 | return recon_tuple 140 | 141 | 142 | def rouge_n(evaluated_sentences, reference_sentences, n=2): 143 | """ 144 | Computes ROUGE-N of two text collections of sentences. 
  Source: http://research.microsoft.com/en-us/um/people/cyl/download/
  papers/rouge-working-note-v1.3.1.pdf

  Args:
    evaluated_sentences: The sentences that have been picked by the summarizer
    reference_sentences: The sentences from the reference set
    n: Size of ngram. Defaults to 2.

  Returns:
    A tuple (f1, precision, recall) for ROUGE-N

  Raises:
    ValueError: raises exception if a param has len <= 0
  """
  if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0:
    raise ValueError("Collections must contain at least 1 sentence.")

  evaluated_ngrams = _get_word_ngrams(n, evaluated_sentences)
  reference_ngrams = _get_word_ngrams(n, reference_sentences)
  reference_count = len(reference_ngrams)
  evaluated_count = len(evaluated_ngrams)

  # Gets the overlapping ngrams between evaluated and reference
  overlapping_ngrams = evaluated_ngrams.intersection(reference_ngrams)
  overlapping_count = len(overlapping_ngrams)

  # Handle edge case. This isn't mathematically correct, but it's good enough
  if evaluated_count == 0:
    precision = 0.0
  else:
    precision = overlapping_count / evaluated_count

  if reference_count == 0:
    recall = 0.0
  else:
    recall = overlapping_count / reference_count

  f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8))

  # return overlapping_count / reference_count
  return f1_score, precision, recall


def _f_p_r_lcs(llcs, m, n):
  """
  Computes the LCS-based F-measure score
  Source: http://research.microsoft.com/en-us/um/people/cyl/download/papers/
  rouge-working-note-v1.3.1.pdf

  Args:
    llcs: Length of LCS
    m: number of words in reference summary
    n: number of words in candidate summary

  Returns:
    Float. LCS-based F-measure score
  """
  r_lcs = llcs / m
  p_lcs = llcs / n
  beta = p_lcs / (r_lcs + 1e-12)
  num = (1 + (beta**2)) * r_lcs * p_lcs
  denom = r_lcs + ((beta**2) * p_lcs)
  f_lcs = num / (denom + 1e-12)
  return f_lcs, p_lcs, r_lcs


def rouge_l_sentence_level(evaluated_sentences, reference_sentences):
  """
  Computes ROUGE-L (sentence level) of two text collections of sentences.
  http://research.microsoft.com/en-us/um/people/cyl/download/papers/
  rouge-working-note-v1.3.1.pdf

  Calculated according to:
  R_lcs = LCS(X,Y)/m
  P_lcs = LCS(X,Y)/n
  F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs)

  where:
  X = reference summary
  Y = Candidate summary
  m = length of reference summary
  n = length of candidate summary

  Args:
    evaluated_sentences: The sentences that have been picked by the summarizer
    reference_sentences: The sentences from the reference set

  Returns:
    A float: F_lcs

  Raises:
    ValueError: raises exception if a param has len <= 0
  """
  if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0:
    raise ValueError("Collections must contain at least 1 sentence.")
  reference_words = _split_into_words(reference_sentences)
  evaluated_words = _split_into_words(evaluated_sentences)
  m = len(reference_words)
  n = len(evaluated_words)
  lcs = _len_lcs(evaluated_words, reference_words)
  return _f_p_r_lcs(lcs, m, n)


def _union_lcs(evaluated_sentences, reference_sentence):
  """
  Returns LCS_u(r_i, C) which is the LCS score of the union longest common
  subsequence between reference sentence ri and candidate summary C. For example
  if r_i = w1 w2 w3 w4 w5, and C contains two sentences: c1 = w1 w2 w6 w7 w8 and
  c2 = w1 w3 w8 w9 w5, then the longest common subsequence of r_i and c1 is
  “w1 w2” and the longest common subsequence of r_i and c2 is “w1 w3 w5”. The
  union longest common subsequence of r_i, c1, and c2 is “w1 w2 w3 w5” and
  LCS_u(r_i, C) = 4/5.

  Args:
    evaluated_sentences: The sentences that have been picked by the summarizer
    reference_sentence: One of the sentences in the reference summaries

  Returns:
    float: LCS_u(r_i, C)

  ValueError:
    Raises exception if a param has len <= 0
  """
  if len(evaluated_sentences) <= 0:
    raise ValueError("Collections must contain at least 1 sentence.")

  lcs_union = set()
  reference_words = _split_into_words([reference_sentence])
  combined_lcs_length = 0
  for eval_s in evaluated_sentences:
    evaluated_words = _split_into_words([eval_s])
    lcs = set(_recon_lcs(reference_words, evaluated_words))
    combined_lcs_length += len(lcs)
    lcs_union = lcs_union.union(lcs)

  union_lcs_count = len(lcs_union)
  union_lcs_value = union_lcs_count / combined_lcs_length
  return union_lcs_value


def rouge_l_summary_level(evaluated_sentences, reference_sentences):
  """
  Computes ROUGE-L (summary level) of two text collections of sentences.
  http://research.microsoft.com/en-us/um/people/cyl/download/papers/
  rouge-working-note-v1.3.1.pdf

  Calculated according to:
  R_lcs = SUM(1, u)[LCS(r_i,C)]/m
  P_lcs = SUM(1, u)[LCS(r_i,C)]/n
  F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs)

  where:
  SUM(i,u) = SUM from i through u
  u = number of sentences in reference summary
  C = Candidate summary made up of v sentences
  m = number of words in reference summary
  n = number of words in candidate summary

  Args:
    evaluated_sentences: The sentences that have been picked by the summarizer
    reference_sentences: The sentences from the reference set

  Returns:
    A float: F_lcs

  Raises:
    ValueError: raises exception if a param has len <= 0
  """
  if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0:
    raise ValueError("Collections must contain at least 1 sentence.")

  # total number of words in reference sentences
  m = len(_split_into_words(reference_sentences))

  # total number of words in evaluated sentences
  n = len(_split_into_words(evaluated_sentences))

  union_lcs_sum_across_all_references = 0
  for ref_s in reference_sentences:
    union_lcs_sum_across_all_references += _union_lcs(evaluated_sentences,
                                                      ref_s)
  return _f_p_r_lcs(union_lcs_sum_across_all_references, m, n)


def rouge(hypotheses, references):
  """Calculates average rouge scores for a list of hypotheses and
  references"""

  # Filter out hyps that are of 0 length
  # hyps_and_refs = zip(hypotheses, references)
  # hyps_and_refs = [_ for _ in hyps_and_refs if len(_[0]) > 0]
  # hypotheses, references = zip(*hyps_and_refs)

  # Calculate ROUGE-1 F1, precision, recall scores
  rouge_1 = [
      rouge_n([hyp], [ref], 1) for hyp, ref in zip(hypotheses, references)
  ]
  rouge_1_f, rouge_1_p, rouge_1_r = map(np.mean, zip(*rouge_1))

  # Calculate ROUGE-2 F1, precision, recall scores
  rouge_2 = [
      rouge_n([hyp], [ref], 2) for hyp, ref in zip(hypotheses, references)
  ]
  rouge_2_f, rouge_2_p, rouge_2_r = map(np.mean, zip(*rouge_2))

  # Calculate ROUGE-L F1, precision, recall scores
  rouge_l = [
      rouge_l_sentence_level([hyp], [ref])
      for hyp, ref in zip(hypotheses, references)
  ]
  rouge_l_f, rouge_l_p, rouge_l_r = map(np.mean, zip(*rouge_l))

  return {
      "rouge_1/f_score": rouge_1_f,
      "rouge_1/r_score": rouge_1_r,
      "rouge_1/p_score": rouge_1_p,
      "rouge_2/f_score": rouge_2_f,
      "rouge_2/r_score": rouge_2_r,
      "rouge_2/p_score": rouge_2_p,
      "rouge_l/f_score": rouge_l_f,
      "rouge_l/r_score": rouge_l_r,
      "rouge_l/p_score": rouge_l_p,
  }


def compute_metric_score(metric_name, inference, gold, char_dict):
  reverse_char_dict = dict()
  # drop padding and end-of-sequence ids before computing ROUGE
  invalid_list = [char_dict['<PAD>'], char_dict['<EOS>']]
  for key, value in char_dict.items():
    reverse_char_dict[value] = key

  inference_list = list()
  gold_list = list()
  for index in range(inference.shape[0]):
    i_title = ' '.join([reverse_char_dict[c] for c in inference[index] if c not in invalid_list])
    g_title = ' '.join([reverse_char_dict[c] for c in gold[index] if c not in invalid_list])
    inference_list.append(i_title)
    gold_list.append(g_title)

  metric_method = locate(metric_name)
  json.dump(inference_list, open('./inference_list.json', 'w'), indent=4, ensure_ascii=False)
  json.dump(gold_list, open('./gold_list.json', 'w'), indent=4, ensure_ascii=False)
  return metric_method(inference_list, gold_list)
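`compute_metric_score` looks the metric up by the dotted path `metrics.rouge` (the `metric_path` value in config.json) and joins characters with spaces because the ROUGE helpers above expect whitespace-tokenized strings. A quick sanity check of the public entry point:

```python
from metrics import rouge

hyps = ['the cat sat on the mat']
refs = ['the cat was sitting on the mat']
scores = rouge(hyps, refs)
print(scores['rouge_1/f_score'], scores['rouge_2/f_score'], scores['rouge_l/f_score'])
```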
--------------------------------------------------------------------------------
/decoder.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from config.configurable import Configurable
from collections import namedtuple
from pydoc import locate
import numpy as np
import tensorflow.contrib.seq2seq as seq2seq


BeamSearchDecoderState = namedtuple("BeamSearchDecoderState",
                                    ("cell_state", "log_probs", "finished", "lengths")
                                    )


class Decoder(Configurable):
    def __init__(self):
        super(Decoder, self).__init__('decoder')
        self.cell_1 = self.build_cell(self.config['cell_classname'], 'cell_1')  # GRU cell 1
        self.cell_2 = self.build_cell(self.config['cell_classname'], 'cell_2')  # GRU cell 2
        self.init_z = tf.zeros((self.get_config('train', 'batch_size'), self.config['variable_size']))  # initial value for the latent variable z

    def decode_onestep(self, inputs, encoder_output, state_1, state_2, z):
        batch_size = self.get_config('train', 'batch_size')
        variable_size = self.config['variable_size']
        hidden_dim = self.config['cell']['num_units']
        word_num = self.get_config('data', 'target_word_num')
        emd_dim = self.get_config('data', 'emd_dim')
        source_max_seq_length = self.get_config('data', 'source_max_seq_length')

        source_mask = tf.sequence_mask(lengths=encoder_output.attention_values_length, maxlen=source_max_seq_length, dtype=tf.bool)

        with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE) as scope:
            W_dy_zh = tf.get_variable(name='W_dy_zh', shape=[variable_size, hidden_dim], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer())
            W_dz_hh = tf.get_variable(name='W_dz_hh', shape=[hidden_dim, hidden_dim], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer())
            b_dy_h = tf.get_variable(name='b_dy_h', shape=[hidden_dim], dtype=tf.float32, initializer=tf.zeros_initializer())

            W_d_hy = tf.get_variable(name='W_d_hy', shape=[hidden_dim, word_num], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer())
            b_d_hy = tf.get_variable(name='b_d_hy', shape=[word_num], dtype=tf.float32, initializer=tf.zeros_initializer())

            with tf.variable_scope('vae', reuse=tf.AUTO_REUSE):
                # recognition network (VAE encoder)
                W_ez_yh = tf.get_variable('W_ez_yh', shape=[emd_dim, hidden_dim], initializer=tf.contrib.layers.xavier_initializer(), dtype=tf.float32)
                W_ez_zh = tf.get_variable('W_ez_zh', shape=[variable_size, hidden_dim], initializer=tf.contrib.layers.xavier_initializer(), dtype=tf.float32)
                W_ez_hh = tf.get_variable('W_ez_hh', shape=[hidden_dim, hidden_dim], initializer=tf.contrib.layers.xavier_initializer(), dtype=tf.float32)
                b_ez_h = tf.get_variable(name='b_ez_h', shape=[hidden_dim], initializer=tf.zeros_initializer(), dtype=tf.float32)
                # mean
                W_ez_hm = tf.get_variable('W_ez_mu', shape=[hidden_dim, variable_size], initializer=tf.contrib.layers.xavier_initializer(), dtype=tf.float32)
                b_ez_m = tf.get_variable(name='b_ez_mu', shape=[variable_size], initializer=tf.zeros_initializer(), dtype=tf.float32)
                # variance
                W_h_s = tf.get_variable('W_ez_sigma', shape=[hidden_dim, variable_size], initializer=tf.contrib.layers.xavier_initializer(), dtype=tf.float32)
                b_ez_s = tf.get_variable(name='b_ez_sigma', shape=[variable_size], initializer=tf.zeros_initializer(), dtype=tf.float32)

        last_state_1 = state_1
        _, state_1 = self.cell_1(inputs=inputs, state=state_1)
        a_ij = self.compute_attention_weight(state_1, encoder_output.attention_values, source_mask)
        c_t = tf.reduce_sum(tf.multiply(encoder_output.attention_values, a_ij), axis=1)  # attention context vector

        _, state_2 = self.cell_2(inputs=tf.concat([c_t, inputs], axis=-1), state=state_2)

        h_ez_t = tf.nn.sigmoid(tf.add(tf.add_n([tf.matmul(inputs, W_ez_yh), tf.matmul(z, W_ez_zh), tf.matmul(last_state_1, W_ez_hh)]), b_ez_h))
        mean_t = tf.nn.xw_plus_b(h_ez_t, W_ez_hm, b_ez_m)
        var_t = tf.exp(tf.nn.xw_plus_b(h_ez_t, W_h_s, b_ez_s))
        sigma = tf.sqrt(var_t)

        # reparameterization trick: z_t = mu + sigma * eps, eps ~ N(0, I)
        eps = tf.random_normal((batch_size, variable_size))
        z_t = tf.add(mean_t, tf.multiply(sigma, eps))

        h_dy_t = tf.tanh(tf.add(tf.add_n([tf.matmul(z_t, W_dy_zh), tf.matmul(state_2, W_dz_hh)]), b_dy_h))
        logit_t = tf.nn.xw_plus_b(h_dy_t, W_d_hy, b_d_hy)

        # KL( N(mean_t, var_t) || N(0, I) ) per example
        KL_t = tf.to_float(-0.5) * tf.reduce_sum(tf.to_float(1.0)+tf.log(var_t)-tf.square(mean_t)-var_t, axis=-1)

        return state_1, state_2, z_t, logit_t, KL_t

    def decode(self, inputs, length, encoder_output):
        target_max_seq_length = self.get_config('data', 'target_max_seq_length')
        batch_size = self.get_config('train', 'batch_size')

        # initialize both decoder states with the mean of the encoder outputs
        state_1 = tf.reduce_mean(encoder_output.outputs, axis=1)
        state_2 = tf.reduce_mean(encoder_output.outputs, axis=1)
        z = self.init_z
        logits = list()
        KL = list()
        for step in range(target_max_seq_length):
            state_1, state_2, z, logit, KL_t = self.decode_onestep(inputs[:, step, :], encoder_output, state_1, state_2, z)
            logits.append(logit)
            KL.append(KL_t)
        logits = tf.stack(logits, axis=1)
        KL = tf.stack(KL, axis=1)

        return logits, KL
    def beam_search(self, encoder_output, embedding):
        vocab_size = self.get_config('data', 'target_word_num')
        end_id = self.get_config('data', 'end_id')
        start_id = self.get_config('data', 'start_id')
        beam_width = self.config['beam_search_width']
        batch_size = self.get_config('train', 'batch_size')
        target_max_seq_length = self.get_config('data', 'target_max_seq_length')
        length_penalty_weight = self.config['length_penalty_weight']
        hidden_dim = self.config['cell']['num_units']
        batch_size_beam_width = batch_size * beam_width

        start_tokens = tf.ones([batch_size, beam_width], tf.int32) * start_id
        start_inputs = tf.nn.embedding_lookup(embedding, start_tokens)
        inputs = start_inputs

        # only beam 0 starts with log prob 0; the others start at -inf so the
        # first step does not select duplicate hypotheses
        finished = tf.one_hot(
            indices=tf.zeros(batch_size, tf.int32),
            depth=beam_width,
            on_value=False,
            off_value=True,
            dtype=tf.bool)
        log_probs = tf.one_hot(
            indices=tf.zeros(batch_size, tf.int32),
            depth=beam_width,
            on_value=tf.convert_to_tensor(0.0, tf.float32),
            off_value=tf.convert_to_tensor(-np.Inf, tf.float32),
            dtype=tf.float32)

        tile_state = tf.tile(tf.expand_dims(tf.reduce_mean(encoder_output.outputs, axis=1), 1), [1, beam_width, 1])
        tile_z = tf.tile(tf.expand_dims(self.init_z, 1), [1, beam_width, 1])

        beam_state = BeamSearchDecoderState(
            cell_state=(tile_state, tile_state, tile_z),
            log_probs=log_probs,
            finished=finished,
            lengths=tf.zeros([batch_size, beam_width], dtype=tf.int32))

        y_s = tf.zeros([batch_size, beam_width, 0], dtype=tf.int32)

        for step in range(target_max_seq_length):
            state_1, state_2, z = beam_state.cell_state

            new_state_1 = list()
            new_state_2 = list()
            new_z = list()
            logits = list()
            for search_step in range(beam_width):
                state_1_t, state_2_t, z_t, logits_t, _ = self.decode_onestep(inputs[:, search_step, :], encoder_output, state_1[:, search_step, :], state_2[:, search_step, :], z[:, search_step, :])
                new_state_1.append(state_1_t)
                new_state_2.append(state_2_t)
                new_z.append(z_t)
                logits.append(logits_t)

            state_1 = tf.stack(new_state_1, 1)
            state_2 = tf.stack(new_state_2, 1)
            z = tf.stack(new_z, 1)
            logits = tf.stack(logits, 1)

            prediction_lengths = beam_state.lengths
            previously_finished = beam_state.finished

            step_log_probs = tf.nn.log_softmax(logits, axis=-1)
            step_log_probs = self.mask_probs(step_log_probs, end_id, previously_finished)
            total_probs = tf.expand_dims(beam_state.log_probs, 2) + step_log_probs

            # finished beams stop growing; picking <EOS> adds 0, any other word adds 1
            lengths_to_add = tf.one_hot(
                indices=tf.fill([batch_size, beam_width], end_id),
                depth=vocab_size,
                on_value=tf.to_int32(0),
                off_value=tf.to_int32(1),
                dtype=tf.int32)
            add_mask = tf.to_int32(tf.logical_not(previously_finished))
            lengths_to_add = lengths_to_add * tf.expand_dims(add_mask, 2)
            new_prediction_lengths = lengths_to_add + tf.expand_dims(prediction_lengths, 2)

            scores = self.get_scores(log_probs=total_probs,
                                     sequence_lengths=new_prediction_lengths,
                                     length_penalty_weight=length_penalty_weight)

            # pick the top-k (beam, word) pairs over the flattened beam * vocab axis
            scores_flat = tf.reshape(scores, (batch_size, -1))
            next_beam_scores, word_indices = tf.nn.top_k(scores_flat, k=beam_width, sorted=True)
            next_beam_scores = tf.reshape(next_beam_scores, (batch_size, beam_width))
            word_indices = tf.reshape(word_indices, (batch_size, beam_width))

            next_beam_probs = tf.batch_gather(tf.reshape(total_probs, (batch_size, -1)), word_indices)

            next_word_ids = tf.mod(word_indices, vocab_size)
            next_beam_ids = tf.to_int32(tf.div(word_indices, vocab_size))

            previously_finished = tf.batch_gather(previously_finished, next_beam_ids)
            next_finished = tf.logical_or(previously_finished,
                                          tf.equal(next_word_ids, end_id))

            lengths_to_add = tf.to_int32(tf.logical_not(previously_finished))
            next_prediction_len = tf.batch_gather(beam_state.lengths, next_beam_ids)
            next_prediction_len = next_prediction_len + lengths_to_add

            next_state_1 = tf.batch_gather(state_1, next_beam_ids)
            next_state_2 = tf.batch_gather(state_2, next_beam_ids)
            next_z = tf.batch_gather(z, next_beam_ids)

            beam_state = BeamSearchDecoderState(
                cell_state=(next_state_1, next_state_2, next_z),
                log_probs=next_beam_probs,
                lengths=next_prediction_len,
                finished=next_finished)

            inputs = tf.cond(
                tf.reduce_all(next_finished),
                lambda: start_inputs, lambda: tf.nn.embedding_lookup(embedding, next_word_ids))

            y_s = tf.concat([tf.batch_gather(y_s, next_beam_ids), tf.expand_dims(next_word_ids, axis=2)], axis=2)
        # return the best (first) beam for every example
        return y_s[:, 0, :]


    def mask_probs(self, probs, eos_token, finished):
        vocab_size = probs.shape[-1]
        # finished beams may only emit the end-of-sequence token
        finished_row = tf.one_hot(
            indices=eos_token,
            depth=vocab_size,
            dtype=tf.float32,
            on_value=tf.convert_to_tensor(0., dtype=tf.float32),
            off_value=tf.float32.min)

        finished_probs = tf.tile(
            tf.reshape(finished_row, [1, 1, -1]),
            tf.concat([finished.shape, [1]], 0))

        finished_mask = tf.tile(
            tf.expand_dims(finished, 2), [1, 1, vocab_size])

        return tf.where(finished_mask, finished_probs, probs)

    def get_scores(self, log_probs, sequence_lengths, length_penalty_weight):
        length_penalty_ = self.length_penalty(sequence_lengths=sequence_lengths, penalty_factor=length_penalty_weight)

        return log_probs / length_penalty_


    def length_penalty(self, sequence_lengths, penalty_factor):
        # Google NMT length penalty: ((5 + |Y|) / 6) ** penalty_factor
        penalty_factor = tf.to_float(penalty_factor)
        return tf.div((tf.to_float(5.0) + tf.to_float(sequence_lengths))**penalty_factor, (tf.to_float(1.0 + 5.0))**penalty_factor)

    def build_cell(self, cell_classname, cell_name):
        cell_class = locate(cell_classname)
        return cell_class(num_units=self.config['cell']['num_units'],
                          name=cell_name,
                          **self.config['cell']['cell_params'])

    def compute_attention_weight(self, state, hidden_states, source_mask):
        encoder_hidden_hid = self.get_config('encoder', 'cell')['num_units']
        decoder_hidden_hid = self.config['cell']['num_units']

        with tf.variable_scope('attention', reuse=tf.AUTO_REUSE) as scope:
            W_d_hh = tf.get_variable(name='W_d_hh', shape=[encoder_hidden_hid*2, decoder_hidden_hid], initializer=tf.contrib.layers.xavier_initializer(), dtype=tf.float32)
            W_e_hh = tf.get_variable(name='W_e_hh', shape=[decoder_hidden_hid, decoder_hidden_hid], initializer=tf.contrib.layers.xavier_initializer(), dtype=tf.float32)
            b_a = tf.get_variable(name='b_a', shape=[decoder_hidden_hid], initializer=tf.zeros_initializer(), dtype=tf.float32)
            v = tf.get_variable(name='v', shape=[decoder_hidden_hid], initializer=tf.contrib.layers.xavier_initializer(), dtype=tf.float32)

            # additive (Bahdanau-style) attention: e_ij = v . tanh(W_e h_j + W_d s_i + b_a)
            r_1 = tf.einsum('ijk,kl->ijl', hidden_states, W_e_hh)
            r_2 = tf.expand_dims(tf.matmul(state, W_d_hh), axis=1)
            a = tf.add(r_1, r_2)
            t = tf.tanh(tf.add(a, b_a))
            e_ij = tf.reduce_sum(tf.multiply(v, t), axis=-1)

            mask_value = tf.log(tf.to_float(0.0))
            e_ij_mask = mask_value * tf.ones_like(e_ij)

            e_ij = tf.where(source_mask, e_ij, e_ij_mask)  # replace the attention weights on <PAD> tokens with -inf, so they become zero after the softmax

            a_ij = tf.expand_dims(tf.nn.softmax(e_ij, axis=1), axis=-1)
        return a_ij

--------------------------------------------------------------------------------
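Two formulas in decoder.py worth making explicit: `decode_onestep` uses the standard KL divergence between the diagonal Gaussian q = N(mu, sigma^2) and the prior N(0, I), KL = -1/2 * sum(1 + log sigma^2 - mu^2 - sigma^2), and `length_penalty` is the Google NMT penalty lp(Y) = ((5 + |Y|) / 6)^alpha with alpha = `length_penalty_weight`. A numpy sketch of both (function names here are illustrative):

```python
import numpy as np

def kl_diag_gaussian(mean, var):
    """KL( N(mean, diag(var)) || N(0, I) ), summed over latent dimensions."""
    return -0.5 * np.sum(1.0 + np.log(var) - mean**2 - var, axis=-1)

def gnmt_length_penalty(length, alpha=1.0):
    """Beam scores above are total log-prob divided by this penalty."""
    return ((5.0 + length) / 6.0) ** alpha

print(kl_diag_gaussian(np.zeros(8), np.ones(8)))  # 0.0 -- matches the prior
print(gnmt_length_penalty(10))                    # 2.5
```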