├── requirements.txt ├── README.md ├── hyperparams.py ├── prepro.py ├── visual.py ├── eval.py ├── data_load.py ├── modules.py └── train.py /requirements.txt: -------------------------------------------------------------------------------- 1 | nltk>=3.2.4 2 | numpy>=1.13.0 3 | regex>=2017.6.7 4 | tensorflow>=1.2.0 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ReCoSa 2 | ReCoSa: Detecting the Relevant Contexts with Self-Attention for Multi-turn Dialogue Generation 3 | https://arxiv.org/abs/1907.05339 4 | 5 | Requirements: 6 | 7 | nltk>=3.2.4 8 | 9 | numpy>=1.13.0 10 | 11 | regex>=2017.6.7 12 | 13 | tensorflow>=1.2.0 14 | 15 | 16 | 1、parameter setting: 17 | hyperparams.py 18 | 19 | 20 | 2、To generate vocab: 21 | python prepro.py 22 | 23 | 24 | 3、To train: 25 | python train.py 26 | 27 | 28 | 4、To eval: 29 | python eval.py 30 | 31 | 32 | 5、The dialogue data:Hello How are you? Good, you? I'm fine, what's new? 33 | 34 | Source looks like: 35 | 36 | Hello How are you? \ 37 | 38 | Hello How are you? \ Good, you? \ 39 | 40 | Hello How are you? \ Good, you? \ I'm fine, what's new?\ 41 | 42 | 43 | Target: 44 | 45 | Good, you?\ 46 | 47 | I'm fine, what's new?\ 48 | 49 | Nothing much...\ 50 | 51 | -------------------------------------------------------------------------------- /hyperparams.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | #/usr/bin/python2 3 | ''' 4 | June 2017 by kyubyong park. 5 | kbpark.linguist@gmail.com. 
def allNum(word):
    '''Tells whether `word` consists solely of ASCII digits.

    Args:
      word: A string.

    Returns:
      True if every character lies in '0'..'9', False otherwise. An empty
      string yields True because it contains no offending character.
    '''
    return all('0' <= ch <= '9' for ch in word)


def make_vocab(fpath, fname):
    '''Constructs vocabulary.

    Counts the whitespace-separated tokens of `fpath` and writes them, most
    frequent first, as `token<TAB>count` rows. Tokens made only of ASCII
    digits are left out. The first four rows are the reserved symbols, so
    they receive the indices the rest of the project relies on
    (0: padding, 1: out-of-vocabulary, 2: start, 3: end of sentence).

    Args:
      fpath: A string. Input file path.
      fname: A string. Output file name.

    Writes vocabulary line by line to `preprocessed/fname`
    '''
    # Close the corpus deterministically instead of leaking the handle.
    with codecs.open(fpath, 'r', 'utf-8') as fin:
        words = fin.read().split()
    word2cnt = Counter(words)

    if not os.path.exists('preprocessed'):
        os.mkdir('preprocessed')

    # NOTE(review): the reserved token names were empty strings in this copy
    # of the file (angle-bracket text appears to have been stripped). Empty
    # names produce rows without a token column, which the vocab loaders
    # cannot parse. The names below follow the upstream transformer project
    # this file credits - confirm against the data files.
    reserved = (u"<PAD>", u"<UNK>", u"<S>", u"</S>")
    with codecs.open('preprocessed/{}'.format(fname), 'w', 'utf-8') as fout:
        for token in reserved:
            # The huge count keeps reserved symbols above any `min_cnt` cut.
            fout.write(u"{}\t1000000000\n".format(token))
        for word, cnt in word2cnt.most_common():
            if not allNum(word):
                fout.write(u"{}\t{}\n".format(word, cnt))
def eval():
    '''Decodes the test set with the latest checkpoint and reports corpus BLEU.

    Restores the newest checkpoint found in `hp.logdir`, greedily decodes the
    test data batch by batch, and writes source / expected / got triples plus
    the final BLEU score to `results/<checkpoint name>`.

    Only full batches are decoded: `len(X) // hp.batch_size` batches are
    processed, so trailing examples that do not fill a batch are skipped.
    Sentence pairs where the reference or the hypothesis has three tokens or
    fewer are excluded from the BLEU computation.
    '''
    # Load graph
    g = Graph(is_training=False)
    print("Graph loaded")

    # Load data
    X, X_length, Sources, Targets = load_test_data()
    _, idx2en = load_en_vocab()

    # Start session
    with g.graph.as_default():
        sv = tf.train.Supervisor()
        with sv.managed_session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            ## Restore parameters
            sv.saver.restore(sess, tf.train.latest_checkpoint(hp.logdir))
            print("Restored!")

            ## Get model name (first quoted string of the checkpoint index file)
            with open(hp.logdir + '/checkpoint', 'r') as ckpt:
                mname = ckpt.read().split('"')[1]

            ## Inference
            if not os.path.exists('results'):
                os.mkdir('results')
            with codecs.open("results/" + mname, "w", "utf-8") as fout:
                list_of_refs, hypotheses = [], []
                for i in range(len(X) // hp.batch_size):
                    ### Get mini-batches
                    start, stop = i*hp.batch_size, (i+1)*hp.batch_size
                    x = X[start:stop]
                    x_length = X_length[start:stop]
                    sources = Sources[start:stop]
                    targets = Targets[start:stop]

                    ### Autoregressive inference: position j is decided from
                    ### the tokens already written into `preds`.
                    preds = np.zeros((hp.batch_size, hp.maxlen), np.int32)
                    for j in range(hp.maxlen):
                        _preds = sess.run(g.preds, {g.x: x, g.x_length: x_length, g.y: preds})
                        preds[:, j] = _preds[:, j]

                    ### Write to file
                    for source, target, pred in zip(sources, targets, preds):  # sentence-wise
                        # Cut at the end-of-sentence symbol. The separator was
                        # an empty string in this copy, which makes str.split
                        # raise ValueError.
                        # NOTE(review): "</S>" follows the upstream transformer
                        # project - confirm it matches the vocabulary file.
                        got = " ".join(idx2en[idx] for idx in pred).split("</S>")[0].strip()
                        fout.write("- source: " + source + "\n")
                        fout.write("- expected: " + target + "\n")
                        fout.write("- got: " + got + "\n\n")
                        fout.flush()

                        # bleu score
                        ref = target.split()
                        hypothesis = got.split()
                        if len(ref) > 3 and len(hypothesis) > 3:
                            list_of_refs.append([ref])
                            hypotheses.append(hypothesis)

                ## Calculate bleu score
                score = corpus_bleu(list_of_refs, hypotheses)
                fout.write("Bleu Score = " + str(100*score))
def _load_vocab(path):
    '''Reads a `token<whitespace>count` vocabulary file into two lookup tables.

    Args:
      path: A string. Path of the vocabulary file.

    Returns:
      A tuple `(word2idx, idx2word)`. Indices follow the file order of the
      rows whose count is at least `hp.min_cnt`.
    '''
    # Close the file deterministically instead of leaking the handle.
    with codecs.open(path, 'r', 'utf-8') as fin:
        # Blank lines carry no entry; skipping them avoids an IndexError.
        rows = [line.split() for line in fin.read().splitlines() if line.strip()]
    vocab = [row[0] for row in rows if int(row[1]) >= hp.min_cnt]
    word2idx = {word: idx for idx, word in enumerate(vocab)}
    idx2word = dict(enumerate(vocab))
    return word2idx, idx2word


def load_de_vocab():
    '''Loads the source-side vocabulary from `preprocessed/de.vocab.tsv`.

    Returns:
      A tuple `(word2idx, idx2word)` of dicts.
    '''
    return _load_vocab('preprocessed/de.vocab.tsv')


def load_en_vocab():
    '''Loads the target-side vocabulary from `preprocessed/en.vocab.tsv`.

    Returns:
      A tuple `(word2idx, idx2word)` of dicts.
    '''
    return _load_vocab('preprocessed/en.vocab.tsv')
def load_dev_data():
    '''Loads and indexes the development set.

    Reads `hp.source_dev` and `hp.target_dev`, drops empty lines, strips the
    remaining ones and passes the two sentence lists to `create_data`.

    Returns:
      The five values produced by `create_data`, in order:
      `X, X_length, Y, Sources, Targets`.
    '''
    def _read_sents(path):
        # Close the file deterministically instead of leaking the handle.
        with codecs.open(path, 'r', 'utf-8') as fin:
            return [line.strip() for line in fin.read().split("\n") if line]

    # NOTE(review): empty lines are dropped from each file independently, so
    # a blank line in only one of them would shift the source/target pairing.
    de_sents = _read_sents(hp.source_dev)
    en_sents = _read_sents(hp.target_dev)

    X, X_length, Y, Sources, Targets = create_data(de_sents, en_sents)
    return X, X_length, Y, Sources, Targets
def normalize(inputs,
              epsilon = 1e-8,
              scope="ln",
              reuse=None):
    '''Applies layer normalization.

    Each vector along the last axis is shifted to zero mean and scaled to unit
    variance, then multiplied by a gain `gamma` (initialised to ones) and
    offset by a bias `beta` (initialised to zeros).

    Args:
      inputs: A tensor with 2 or more dimensions, where the first dimension has
        `batch_size`.
      epsilon: A floating number. A very small number for preventing ZeroDivision Error.
      scope: Optional scope for `variable_scope`.
      reuse: Boolean, whether to reuse the weights of a previous layer
        by the same name.

    Returns:
      A tensor with the same shape and data dtype as `inputs`.
    '''
    with tf.variable_scope(scope, reuse=reuse):
        inputs_shape = inputs.get_shape()
        # Gain and bias have one entry per feature (size of the last axis).
        params_shape = inputs_shape[-1:]

        # Statistics are taken over the last axis only.
        mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True)
        # NOTE(review): beta/gamma are created with tf.Variable rather than
        # tf.get_variable, so `reuse` presumably does not share them between
        # calls - confirm before relying on weight sharing here.
        beta= tf.Variable(tf.zeros(params_shape))
        gamma = tf.Variable(tf.ones(params_shape))
        # epsilon keeps the denominator away from zero for constant inputs.
        normalized = (inputs - mean) / ( (variance + epsilon) ** (.5) )
        outputs = gamma * normalized + beta

    return outputs
59 | reuse: Boolean, whether to reuse the weights of a previous layer 60 | by the same name. 61 | 62 | Returns: 63 | A `Tensor` with one more rank than inputs's. The last dimensionality 64 | should be `num_units`. 65 | 66 | For example, 67 | 68 | ``` 69 | import tensorflow as tf 70 | 71 | inputs = tf.to_int32(tf.reshape(tf.range(2*3), (2, 3))) 72 | outputs = embedding(inputs, 6, 2, zero_pad=True) 73 | with tf.Session() as sess: 74 | sess.run(tf.global_variables_initializer()) 75 | print sess.run(outputs) 76 | >> 77 | [[[ 0. 0. ] 78 | [ 0.09754146 0.67385566] 79 | [ 0.37864095 -0.35689294]] 80 | 81 | [[-1.01329422 -1.09939694] 82 | [ 0.7521342 0.38203377] 83 | [-0.04973143 -0.06210355]]] 84 | ``` 85 | 86 | ``` 87 | import tensorflow as tf 88 | 89 | inputs = tf.to_int32(tf.reshape(tf.range(2*3), (2, 3))) 90 | outputs = embedding(inputs, 6, 2, zero_pad=False) 91 | with tf.Session() as sess: 92 | sess.run(tf.global_variables_initializer()) 93 | print sess.run(outputs) 94 | >> 95 | [[[-0.19172323 -0.39159766] 96 | [-0.43212751 -0.66207761] 97 | [ 1.03452027 -0.26704335]] 98 | 99 | [[-0.11634696 -0.35983452] 100 | [ 0.50208133 0.53509563] 101 | [ 1.22204471 -0.96587461]]] 102 | ``` 103 | ''' 104 | with tf.variable_scope(scope, reuse=reuse): 105 | lookup_table = tf.get_variable('lookup_table', 106 | dtype=tf.float32, 107 | shape=[vocab_size, num_units], 108 | initializer=tf.contrib.layers.xavier_initializer()) 109 | if zero_pad: 110 | lookup_table = tf.concat((tf.zeros(shape=[1, num_units]), 111 | lookup_table[1:, :]), 0) 112 | outputs = tf.nn.embedding_lookup(lookup_table, inputs) 113 | 114 | if scale: 115 | outputs = outputs * (num_units ** 0.5) 116 | 117 | return outputs 118 | 119 | 120 | def positional_encoding(inputs, 121 | num_units, 122 | zero_pad=True, 123 | scale=True, 124 | scope="positional_encoding", 125 | reuse=None): 126 | '''Sinusoidal Positional_Encoding. 127 | 128 | Args: 129 | inputs: A 2d Tensor with shape of (N, T). 
def multihead_attention(queries,
                        keys,
                        num_units=None,
                        num_heads=8,
                        dropout_rate=0,
                        is_training=True,
                        causality=False,
                        scope="multihead_attention",
                        reuse=None):
    '''Applies multihead attention.

    Queries, keys and values are projected with ReLU-activated dense layers,
    split into `num_heads` heads along the feature axis, combined with scaled
    dot-product attention, merged again, added to `queries` (residual
    connection) and layer-normalised.

    Positions whose feature vector sums to zero (in `keys` or in `queries`)
    are treated as padding and masked out.

    Args:
      queries: A 3d tensor with shape of [N, T_q, C_q].
      keys: A 3d tensor with shape of [N, T_k, C_k].
      num_units: A scalar. Attention size. Defaults to the last dimension of
        `queries` when None.
      num_heads: An int. Number of heads.
      dropout_rate: A floating point number.
      is_training: Boolean. Controller of mechanism for dropout.
      causality: Boolean. If true, units that reference the future are masked.
      scope: Optional scope for `variable_scope`.
      reuse: Boolean, whether to reuse the weights of a previous layer
        by the same name.

    Returns:
      A tuple `(outputs, attn)`:
        outputs: A 3d tensor with shape of (N, T_q, C).
        attn: The softmax attention weights with shape of (h*N, T_q, T_k),
          taken before query masking and dropout.
    '''
    with tf.variable_scope(scope, reuse=reuse):
        # Set the fall back option for num_units.
        # `as_list` is a method; it was previously subscripted without being
        # called, which raised a TypeError on this path.
        if num_units is None:
            num_units = queries.get_shape().as_list()[-1]

        # Linear projections
        Q = tf.layers.dense(queries, num_units, activation=tf.nn.relu)  # (N, T_q, C)
        K = tf.layers.dense(keys, num_units, activation=tf.nn.relu)  # (N, T_k, C)
        V = tf.layers.dense(keys, num_units, activation=tf.nn.relu)  # (N, T_k, C)

        # Split and concat
        Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0)  # (h*N, T_q, C/h)
        K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0)  # (h*N, T_k, C/h)
        V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0)  # (h*N, T_k, C/h)

        # Multiplication
        outputs = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1]))  # (h*N, T_q, T_k)

        # Scale
        outputs = outputs / (K_.get_shape().as_list()[-1] ** 0.5)

        # Key Masking: keys whose features sum to zero get a large negative
        # score, so softmax gives them (almost) no weight.
        key_masks = tf.sign(tf.abs(tf.reduce_sum(keys, axis=-1)))  # (N, T_k)
        key_masks = tf.tile(key_masks, [num_heads, 1])  # (h*N, T_k)
        key_masks = tf.tile(tf.expand_dims(key_masks, 1), [1, tf.shape(queries)[1], 1])  # (h*N, T_q, T_k)

        paddings = tf.ones_like(outputs)*(-2**32+1)
        outputs = tf.where(tf.equal(key_masks, 0), paddings, outputs)  # (h*N, T_q, T_k)

        # Causality = Future blinding
        if causality:
            diag_vals = tf.ones_like(outputs[0, :, :])  # (T_q, T_k)
            tril = tf.contrib.linalg.LinearOperatorTriL(diag_vals).to_dense()  # (T_q, T_k)
            masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(outputs)[0], 1, 1])  # (h*N, T_q, T_k)

            paddings = tf.ones_like(masks)*(-2**32+1)
            outputs = tf.where(tf.equal(masks, 0), paddings, outputs)  # (h*N, T_q, T_k)

        # Activation
        outputs = tf.nn.softmax(outputs)  # (h*N, T_q, T_k)
        attn = outputs  # exposed to callers for attention analysis

        # Query Masking: rows belonging to padded queries are zeroed.
        query_masks = tf.sign(tf.abs(tf.reduce_sum(queries, axis=-1)))  # (N, T_q)
        query_masks = tf.tile(query_masks, [num_heads, 1])  # (h*N, T_q)
        query_masks = tf.tile(tf.expand_dims(query_masks, -1), [1, 1, tf.shape(keys)[1]])  # (h*N, T_q, T_k)
        outputs *= query_masks  # (h*N, T_q, T_k)

        # Dropouts
        outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=tf.convert_to_tensor(is_training))

        # Weighted sum
        outputs = tf.matmul(outputs, V_)  # (h*N, T_q, C/h)

        # Restore shape
        outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2)  # (N, T_q, C)

        # Residual connection
        outputs += queries

        # Normalize
        outputs = normalize(outputs)  # (N, T_q, C)

    return outputs, attn
def label_smoothing(inputs, epsilon=0.1):
    '''Applies label smoothing. See https://arxiv.org/abs/1512.00567.

    Every entry is scaled by `1 - epsilon` and then `epsilon / V` is added,
    where V is the size of the last axis. With `epsilon=0.1` and V = 3, a
    one-hot row [0, 0, 1] becomes roughly [0.0333, 0.0333, 0.9333].

    Args:
      inputs: A 3d tensor with shape of [N, T, V], where V is the number of vocabulary.
      epsilon: Smoothing rate.

    Returns:
      The smoothed tensor, element-wise `(1 - epsilon) * inputs + epsilon / V`.
    '''
    # Size of the last axis, read from the static shape.
    num_channels = inputs.get_shape().as_list()[-1]
    kept_mass = (1-epsilon) * inputs
    spread_mass = epsilon / num_channels
    return kept_mass + spread_mass
June 2017 by kyubyong park. 6 | kbpark.linguist@gmail.com. 7 | https://www.github.com/kyubyong/transformer 8 | ''' 9 | from __future__ import print_function 10 | import sys 11 | reload(sys) 12 | sys.setdefaultencoding('utf8') 13 | import tensorflow as tf 14 | 15 | from hyperparams import Hyperparams as hp 16 | from data_load import get_batch_data, load_de_vocab, load_en_vocab, load_dev_data 17 | from modules import * 18 | import os, codecs 19 | from tqdm import tqdm 20 | 21 | import numpy as np 22 | import codecs 23 | import nltk 24 | class Graph(): 25 | def __init__(self, is_training=True): 26 | self.graph = tf.Graph() 27 | with self.graph.as_default(): 28 | if is_training: 29 | self.x, self.x_length,self.y, self.num_batch,self.source,self.target = get_batch_data() # (N, T) 30 | else: # inference 31 | self.x = tf.placeholder(tf.int32, shape=(None,hp.max_turn,hp.maxlen)) 32 | self.x_length = tf.placeholder(tf.int32,shape=(None,hp.max_turn)) 33 | self.y = tf.placeholder(tf.int32, shape=(None, hp.maxlen)) 34 | 35 | # define decoder inputs 36 | self.decoder_inputs = tf.concat((tf.ones_like(self.y[:, :1])*2, self.y[:, :-1]), -1) # 2: 37 | 38 | # Load vocabulary 39 | de2idx, idx2de = load_de_vocab() 40 | en2idx, idx2en = load_en_vocab() 41 | 42 | # Encoder 43 | with tf.variable_scope("encoder"): 44 | ## Embedding 45 | embeddingsize = hp.hidden_units/2 46 | self.enc_embed = embedding(tf.reshape(self.x,[-1,hp.maxlen]), 47 | vocab_size=len(de2idx), 48 | num_units=embeddingsize, 49 | scale=True, 50 | scope="enc_embed") 51 | single_cell = tf.nn.rnn_cell.GRUCell(hp.hidden_units) 52 | self.rnn_cell = tf.nn.rnn_cell.MultiRNNCell([single_cell]*hp.num_layers) 53 | print (self.enc_embed.get_shape()) 54 | self.sequence_length=tf.reshape(self.x_length,[-1]) 55 | print(self.sequence_length.get_shape()) 56 | self.uttn_outputs, self.uttn_states = tf.nn.dynamic_rnn(cell=self.rnn_cell, inputs=self.enc_embed,sequence_length=self.sequence_length, dtype=tf.float32,swap_memory=True) 57 | 
self.enc = tf.reshape(self.uttn_states,[hp.batch_size,hp.max_turn,hp.hidden_units]) 58 | ## Positional Encoding 59 | if hp.sinusoid: 60 | self.enc += positional_encoding(self.x, 61 | num_units=hp.hidden_units, 62 | zero_pad=False, 63 | scale=False, 64 | scope="enc_pe") 65 | else: 66 | self.enc += embedding(tf.tile(tf.expand_dims(tf.range(tf.shape(self.x)[1]), 0), [tf.shape(self.x)[0], 1]), 67 | vocab_size=hp.maxlen, 68 | num_units=hp.hidden_units, 69 | zero_pad=False, 70 | scale=False, 71 | scope="enc_pe") 72 | 73 | 74 | ## Dropout 75 | self.enc = tf.layers.dropout(self.enc, 76 | rate=hp.dropout_rate, 77 | training=tf.convert_to_tensor(is_training)) 78 | 79 | ## Blocks 80 | for i in range(hp.num_blocks): 81 | with tf.variable_scope("num_blocks_{}".format(i)): 82 | ### Multihead Attention 83 | self.enc,_ = multihead_attention(queries=self.enc, 84 | keys=self.enc, 85 | num_units=hp.hidden_units, 86 | num_heads=hp.num_heads, 87 | dropout_rate=hp.dropout_rate, 88 | is_training=is_training, 89 | causality=False) 90 | 91 | ### Feed Forward 92 | self.enc = feedforward(self.enc, num_units=[4*hp.hidden_units, hp.hidden_units]) 93 | 94 | # Decoder 95 | with tf.variable_scope("decoder"): 96 | ## Embedding 97 | self.dec = embedding(self.decoder_inputs, 98 | vocab_size=len(en2idx), 99 | num_units=hp.hidden_units, 100 | scale=True, 101 | scope="dec_embed") 102 | 103 | ## Positional Encoding 104 | if hp.sinusoid: 105 | self.dec += positional_encoding(self.decoder_inputs, 106 | vocab_size=hp.maxlen, 107 | num_units=hp.hidden_units, 108 | zero_pad=False, 109 | scale=False, 110 | scope="dec_pe") 111 | else: 112 | self.dec += embedding(tf.tile(tf.expand_dims(tf.range(tf.shape(self.decoder_inputs)[1]), 0), [tf.shape(self.decoder_inputs)[0], 1]), 113 | vocab_size=hp.maxlen, 114 | num_units=hp.hidden_units, 115 | zero_pad=False, 116 | scale=False, 117 | scope="dec_pe") 118 | 119 | ## Dropout 120 | self.dec = tf.layers.dropout(self.dec, 121 | rate=hp.dropout_rate, 122 | 
training=tf.convert_to_tensor(is_training)) 123 | 124 | ## Blocks 125 | for i in range(hp.num_blocks): 126 | with tf.variable_scope("num_blocks_{}".format(i)): 127 | ## Multihead Attention ( self-attention) 128 | self.dec,_ = multihead_attention(queries=self.dec, 129 | keys=self.dec, 130 | num_units=hp.hidden_units, 131 | num_heads=hp.num_heads, 132 | dropout_rate=hp.dropout_rate, 133 | is_training=is_training, 134 | causality=True, 135 | scope="self_attention") 136 | 137 | ## Multihead Attention ( vanilla attention) 138 | self.dec,self.attn = multihead_attention(queries=self.dec, 139 | keys=self.enc, 140 | num_units=hp.hidden_units, 141 | num_heads=hp.num_heads, 142 | dropout_rate=hp.dropout_rate, 143 | is_training=is_training, 144 | causality=False, 145 | scope="vanilla_attention") 146 | ## Feed Forward 147 | self.dec = feedforward(self.dec, num_units=[4*hp.hidden_units, hp.hidden_units]) 148 | 149 | # Final linear projection 150 | self.logits = tf.layers.dense(self.dec, len(en2idx)) 151 | self.preds = tf.to_int32(tf.arg_max(self.logits, dimension=-1)) 152 | self.istarget = tf.to_float(tf.not_equal(self.y, 0)) 153 | self.acc = tf.reduce_sum(tf.to_float(tf.equal(self.preds, self.y))*self.istarget)/ (tf.reduce_sum(self.istarget)) 154 | tf.summary.scalar('acc', self.acc) 155 | 156 | if is_training: 157 | # Loss 158 | self.y_smoothed = label_smoothing(tf.one_hot(self.y, depth=len(en2idx))) 159 | self.loss = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.y_smoothed) 160 | self.mean_loss = tf.reduce_sum(self.loss*self.istarget) / (tf.reduce_sum(self.istarget)) 161 | 162 | # Training Scheme 163 | self.global_step = tf.Variable(0, name='global_step', trainable=False) 164 | self.optimizer = tf.train.AdamOptimizer(learning_rate=hp.lr, beta1=0.9, beta2=0.98, epsilon=1e-8) 165 | self.train_op = self.optimizer.minimize(self.mean_loss, global_step=self.global_step) 166 | 167 | # Summary 168 | tf.summary.scalar('mean_loss', self.mean_loss) 169 | 
self.merged = tf.summary.merge_all() 170 | 171 | if __name__ == '__main__': 172 | # Load vocabulary 173 | de2idx, idx2de = load_de_vocab() 174 | en2idx, idx2en = load_en_vocab() 175 | 176 | # Construct graph 177 | g = Graph("train"); print("Graph loaded") 178 | X,X_length,Y, Sources, Targets = load_dev_data() 179 | # Start session 180 | sv = tf.train.Supervisor(graph=g.graph, 181 | logdir=hp.logdir, 182 | save_model_secs=0) 183 | #preEpoch= 184 | tfconfig = tf.ConfigProto() 185 | tfconfig.gpu_options.allow_growth = True 186 | with sv.managed_session(config = tfconfig) as sess: 187 | early_break = 0 188 | old_eval_loss=10000 189 | for epoch in range(1, hp.num_epochs+1): 190 | if sv.should_stop(): break 191 | loss=[] 192 | 193 | if early_break >=4: 194 | break 195 | for step in tqdm(range(g.num_batch), total=g.num_batch, ncols=70, leave=False, unit='b'): 196 | _,loss_step,attns,sources,targets = sess.run([g.train_op,g.mean_loss,g.attn,g.source,g.target]) 197 | loss.append(loss_step) 198 | 199 | if step%2000==0: 200 | gs = sess.run(g.global_step) 201 | print("train loss:%.5lf\n"%(np.mean(loss))) 202 | sv.saver.save(sess, hp.logdir + '/model_epoch_%02d_gs_%d' % (epoch, gs)) 203 | 204 | mname = open(hp.logdir + '/checkpoint', 'r').read().split('"')[1] 205 | fout = codecs.open( mname, "w","utf-8") 206 | eval_loss=[] 207 | bleu=[] 208 | 209 | for i in range(len(X) // hp.batch_size): 210 | ### Get mini-batches 211 | x = X[i*hp.batch_size: (i+1)*hp.batch_size] 212 | x_length=X_length[i*hp.batch_size: (i+1)*hp.batch_size] 213 | y = Y[i*hp.batch_size: (i+1)*hp.batch_size] 214 | sources = Sources[i*hp.batch_size: (i+1)*hp.batch_size] 215 | targets = Targets[i*hp.batch_size: (i+1)*hp.batch_size] 216 | eval_bath = sess.run(g.mean_loss, {g.x: x,g.x_length:x_length,g.y: y}) 217 | eval_loss.append( eval_bath) 218 | 219 | preds = np.zeros((hp.batch_size, hp.maxlen), np.int32) 220 | for j in range(hp.maxlen): 221 | _preds = sess.run(g.preds, {g.x: x,g.x_length:x_length, g.y: preds}) 
222 | preds[:, j] = _preds[:, j] 223 | 224 | 225 | 226 | ### Write to file 227 | list_of_refs, hypotheses = [], [] 228 | for source, target, pred in zip(sources, targets, preds): # sentence-wise 229 | got = " ".join(idx2en[idx] for idx in pred).split("")[0].strip() 230 | fout.write("- source: " + source +"\n") 231 | fout.write("- expected: " + target + "\n") 232 | fout.write("- got: " + got + "\n\n") 233 | fout.flush() 234 | # bleu score 235 | ref = got.split() 236 | hypothesis = target.split() 237 | score = nltk.translate.bleu_score.sentence_bleu([hypothesis],ref,(0.25, 0.25, 0.25, 0.25),nltk.translate.bleu_score.SmoothingFunction().method1) 238 | bleu.append(score) 239 | fout.write("train loss = %.5lf\teval loss = %.5lf\tBleu Score = %.5lf\n" %(np.mean(loss),np.mean(eval_loss),100*np.mean(bleu))) 240 | print("eval loss:%.5lf"%(np.mean(eval_loss))) 241 | print("Bleu Score:%.5lf"%(100*np.mean(bleu))) 242 | if np.mean(eval_loss) > old_eval_loss: 243 | early_break +=1 244 | else: 245 | early_break = 0 246 | old_eval_loss=np.mean(eval_loss) 247 | if early_break>=4: 248 | break 249 | #attention analysis 250 | break 251 | print("Done") 252 | 253 | 254 | --------------------------------------------------------------------------------