├── Evaluate.py ├── Readme.md ├── ScriptWriter-CPre.py ├── ScriptWriter.py ├── Utils.py ├── data ├── readme.md ├── sample_data.txt ├── sample_data_english.txt ├── sample_vocab.txt └── sample_vocab_english.txt ├── image ├── example.png ├── matching.png ├── result.png ├── statistics.png └── update.png ├── model └── readme.md ├── modules.py └── output └── readme.md /Evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def compute_r_n_m(scores, labels, count, at): 4 | total = 0 5 | correct = 0 6 | for i in range(len(labels)): 7 | if i % 10 == 0: 8 | total = total + 1 9 | sublist = scores[i:i + count] 10 | pos_score = sublist[0] 11 | sublist = sorted(sublist, key=lambda x: x, reverse=True) 12 | if sublist[at - 1] <= pos_score: 13 | correct += 1 14 | return float(correct) / total 15 | 16 | def compute_mrr(scores, labels, count=10): 17 | total = 0 18 | accumulate_mrr = 0 19 | for i in range(len(labels)): 20 | if i % 10 == 0: 21 | total = total + 1 22 | sublist = scores[i:i + count] 23 | arg_sort = list(np.argsort(sublist)).index(0) 24 | idx = len(sublist) - arg_sort 25 | accumulate_mrr += 1 / idx 26 | return float(accumulate_mrr) / total 27 | 28 | def compute_acc(scores, labels): 29 | scores = (np.asarray(scores) > 0.5).astype(np.int32) 30 | accuracy = sum((scores == labels).astype(np.int32)) / len(labels) 31 | return accuracy 32 | 33 | def evaluate_all(scores, labels): 34 | return compute_acc(scores, labels), compute_r_n_m(scores, labels, 2, 1), compute_r_n_m(scores, labels, 10, 1), compute_r_n_m(scores, labels, 10, 2), \ 35 | compute_r_n_m(scores, labels, 10, 5), compute_mrr(scores, labels) 36 | 37 | def evaluate_all_from_file(path): 38 | scores = [] 39 | labels = [] 40 | with open(path, "r") as f: 41 | for line in f: 42 | score, label = line.strip().split("\t") 43 | scores.append(float(score)) 44 | labels.append(float(label)) 45 | return evaluate_all(scores, labels) 46 | 47 | def recover_and_show(basic_directory): 48 | import pickle 49 | 50 | vocab = {} 51 | vocab_id2word = {} 52 | 53 | with open("./data/vocab.txt", "r", encoding="utf-8") as fr: 54 | for idx, line in enumerate(fr): 55 | line = line.strip().split("\t") 56 | vocab[line[0]] = idx + 1 57 | vocab_id2word[idx + 1] = line[0] 58 | 59 | vocab["_PAD_"] = 0 60 | vocab_id2word[0] = "_PAD_" 61 | 62 | def initialize(): 63 | all_outline = [] 64 | outline_dict = {} 65 | initial_file = basic_directory + "test.multi.0.pkl" 66 | with open(initial_file, 'rb') as f: 67 | _, _, outline, _ = pickle.load(f) 68 | for o in outline: 69 | if o not in all_outline: 70 | all_outline.append(o) 71 | sent_o = "".join([vocab_id2word[x] for x in o]) 72 | outline_dict[sent_o] = [] 73 | return all_outline, outline_dict 74 | 75 | max_turn = 10 76 | all_outline, outline_dict = initialize() 77 | 78 | for turn in range(0, max_turn + 1): 79 | score = [] 80 | with open(basic_directory + "test.result.multi." + str(turn) + ".txt", "r") as fr: 81 | for idx, line in enumerate(fr): 82 | if idx == 0: 83 | continue 84 | score.append(float(line.strip())) 85 | with open(basic_directory + "test.multi."
+ str(turn) + ".pkl", "rb") as fr: 86 | utterance, response, outline, labels = pickle.load(fr) # except for dl2r 87 | 88 | for i, o in enumerate(outline): 89 | if i % 10 == 0: 90 | score_sub_list = score[i:i + 10] 91 | response_sub_list = response[i:i + 10] 92 | max_idx = score_sub_list.index(max(score_sub_list)) 93 | selected_response = response_sub_list[max_idx] 94 | sent_o = "".join([vocab_id2word[x] for x in o]) 95 | outline_dict[sent_o] = utterance[i] + [selected_response] # for MUSwO 96 | 97 | with open(basic_directory + "test.result.multi.txt", "w", encoding="utf-8") as fw: 98 | for o in all_outline: 99 | sent_o = "".join([vocab_id2word[x] for x in o]) 100 | utterance = outline_dict[sent_o] 101 | fw.write("outline\t" + sent_o + "\n") 102 | for u in utterance: 103 | sent_u = "".join([vocab_id2word[x] for x in u]) 104 | fw.write("script\t" + sent_u + "\n") 105 | fw.write("\n") 106 | 107 | def evaluate_multi_turn_result(t_file, g_file): 108 | test_file = t_file 109 | gold_file = g_file 110 | ft = open(test_file, "r", encoding="utf-8") 111 | fg = open(gold_file, "r", encoding="utf-8") 112 | t, g = ft.readline(), fg.readline() 113 | c = -1 114 | s = -1 115 | all_s = 0 116 | r_all = 0 117 | tmp_r_all = 0 118 | o = None 119 | flag = 1 120 | t_s = [] 121 | g_s = [] 122 | while t and g: 123 | t = t.strip().split("\t") 124 | g = g.strip().split("\t") 125 | if len(t) > 1: 126 | if t[0] == "script": 127 | if t[1] == g[1]: 128 | s += 1 129 | c += 1 130 | t_s.append(t[1]) 131 | g_s.append(g[1]) 132 | if t[0] == "outline": 133 | o = t[1] 134 | else: 135 | tmp_c = -1 136 | tmp_s = -1 137 | for at in g_s: 138 | tmp_c += 1 139 | if at in t_s: 140 | tmp_s += 1 141 | tmp_r = tmp_s / tmp_c 142 | tmp_r_all += tmp_r 143 | r = s / c 144 | r_all += r 145 | all_s += 1 146 | c = -1 147 | s = -1 148 | t_s = [] 149 | g_s = [] 150 | t, g = ft.readline(), fg.readline() 151 | 152 | tmp_c = -1 153 | tmp_s = -1 154 | for at in g_s: 155 | tmp_c += 1 156 | if at in t_s: 157 | tmp_s += 1 158 | tmp_r = tmp_s / tmp_c 159 | tmp_r_all += tmp_r 160 | r = s / c 161 | r_all += r 162 | all_s += 1 163 | 164 | print("p_strict", r_all / all_s) 165 | print("p_weak", tmp_r_all / all_s) 166 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # ScriptWriter: Narrative-Guided Script Generation 2 | 3 | [![made-with-python](https://img.shields.io/badge/Made%20with-Python-red.svg)](#python) 4 | 5 | #### News 6 | - 2022-1-5: Our new work (ScriptWriter-CPre) has been accepted by TOIS! 7 | - 2021-10-19: We have uploaded the code and data for our new model ScriptWriter-CPre. 8 | - 2020-06-11: We found a minor error in the data, so we have uploaded a new version. We also provide the code for building the data files from text files, so you can now generate your own data with Utils.py. 9 | - 2020-06-09: We have uploaded the code and data. Note that we do not share the vocab.txt in the data currently due to copyright issues. We will upload it as soon as possible. 10 | 11 | ## Abstract 12 | This repository contains the source code and datasets for the ACL 2020 paper [ScriptWriter: Narrative-Guided Script Generation](https://www.aclweb.org/anthology/2020.acl-main.765.pdf) and the TOIS paper [Leveraging Narrative to Generate Movie Script](https://dl.acm.org/doi/pdf/10.1145/3507356) by Zhu et al.
13 | 14 | It is appealing to have a system that generates a story or script automatically from a storyline, even though this is still out of our reach. In dialogue systems, it would also be useful to drive dialogues by a dialogue plan. In this paper, we address a key problem involved in these applications: guiding a dialogue by a narrative. The proposed model ScriptWriter selects the best response among the candidates that fit the context as well as the given narrative. It keeps track of what in the narrative has already been said and what remains to be said. A narrative plays a different role from the context (i.e., previous utterances), which is generally used in current dialogue systems. Due to the unavailability of data for this new application, we construct a new large-scale data collection, GraphMovie, from a movie website where end-users can freely upload their narratives when watching a movie. Experimental results on the dataset show that our proposed approach based on narratives significantly outperforms the baselines that simply use the narrative as a kind of context. 15 | 16 | Authors: Yutao Zhu, Ruihua Song, Zhicheng Dou, Jian-Yun Nie, Jin Zhou 17 | 18 | ## Requirements 19 | We tested the code with the following packages.
20 | - Python 3.5
21 | - TensorFlow 1.5 (with GPU support)
22 | - Keras 2.2.4
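Besides the packages above, the scripts expect the pickled data files described below (see Usage for the download link). The following is a minimal inspection sketch based on how `ScriptWriter.py` and `ScriptWriter-CPre.py` load the data with the default hard-coded paths; it is illustrative and not part of the released code:

```python
import pickle

# ScriptWriter data: each .pkl file holds four parallel lists.
with open("./data/train.pkl", "rb") as f:
    utterance, response, narrative, labels = pickle.load(f)
# utterance[i] : previous turns of a session, each a list of word ids
# response[i]  : one candidate response as a list of word ids
# narrative[i] : the narrative (outline) as a list of word ids
# labels[i]    : 1 if the candidate is the ground-truth next line, else 0

# ScriptWriter-CPre data (train.gr.pkl / dev.gr.pkl / test.gr.pkl) carries an
# additional ground-truth response used by the narrative-tracking (KL) loss.
with open("./data/train.gr.pkl", "rb") as f:
    utterance, response, narrative, gt_response, y_true = pickle.load(f)
```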
23 | 24 | ## Usage 25 | - Unzip the compressed [data file](https://drive.google.com/file/d/1fJKI9fzUhPM2dKq2zAFWLbtltv6PT2wh/view?usp=sharing) to the data directory.
26 | - python3 ScriptWriter.py (or python3 ScriptWriter-CPre.py) 27 | 28 | ## Results 29 | | Model | R2@1 | R10@1 | R10@2 | R10@5 | MRR | P_strict | P_weak | 30 | | ----------------- | ----- | ----- | ----- | ----- | ----- | -------- | ------ | 31 | | MVLSTM | 0.651 | 0.217 | 0.384 | 0.732 | 0.395 | 0.198 | 0.224 | 32 | | DL2R | 0.643 | 0.210 | 0.321 | 0.638 | 0.314 | 0.230 | 0.243 | 33 | | SMN | 0.641 | 0.176 | 0.333 | 0.696 | 0.392 | 0.197 | 0.236 | 34 | | DAM | 0.631 | 0.240 | 0.398 | 0.733 | 0.408 | 0.226 | 0.236 | 35 | | DUA | 0.654 | 0.237 | 0.403 | 0.736 | 0.396 | 0.223 | 0.251 | 36 | | IMN | 0.686 | 0.301 | 0.450 | 0.759 | 0.463 | 0.304 | 0.325 | 37 | | IOI | 0.710 | 0.341 | 0.491 | 0.774 | 0.464 | 0.324 | 0.337 | 38 | | MSN | 0.724 | 0.329 | 0.511 | 0.794 | 0.464 | 0.314 | 0.346 | 39 | | ScriptWriter | 0.730 | 0.365 | 0.537 | 0.814 | 0.503 | 0.373 | 0.383 | 40 | | ScriptWriter-CPre | 0.756 | 0.398 | 0.557 | 0.817 | 0.504 | 0.392 | 0.409 | 41 | 42 | ## Citations 43 | If you use the code and datasets, please cite the following paper: 44 | ``` 45 | @inproceedings{zhu-etal-2020-scriptwriter, 46 | title = {{S}cript{W}riter: Narrative-Guided Script Generation}, 47 | author = {Zhu, Yutao and 48 | Song, Ruihua and 49 | Dou, Zhicheng and 50 | Nie, Jian-Yun and 51 | Zhou, Jin}, 52 | booktitle = {Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics}, 53 | month = jul, 54 | year = {2020}, 55 | address = {Online}, 56 | publisher = {Association for Computational Linguistics}, 57 | url = {https://www.aclweb.org/anthology/2020.acl-main.765}, 58 | pages = {8647--8657} 59 | } 60 | @article{zhu-etal-2022-leverage, 61 | title = {Leveraging Narrative to Generate Movie Script}, 62 | author = {Zhu, Yutao and 63 | Song, Ruihua and 64 | Nie, Jian-Yun and 65 | Du, Pan and 66 | Dou, Zhicheng and 67 | Zhou, Jin}, 68 | year = {2022}, 69 | issue_date = {October 2022}, 70 | publisher = {Association for Computing Machinery}, 71 | address = {New York, NY, USA}, 72 | volume = {40}, 73 | number = {4}, 74 | issn = {1046-8188}, 75 | url = {https://doi.org/10.1145/3507356}, 76 | doi = {10.1145/3507356}, 77 | journal = {ACM Trans. Inf. 
Syst.}, 78 | month = {mar}, 79 | articleno = {86}, 80 | numpages = {32} 81 | } 82 | 83 | ``` 84 | -------------------------------------------------------------------------------- /ScriptWriter-CPre.py: -------------------------------------------------------------------------------- 1 | from modules import * 2 | from keras.preprocessing.sequence import pad_sequences 3 | import Utils 4 | import Evaluate 5 | import pickle 6 | import os 7 | from tqdm import tqdm 8 | import logging 9 | 10 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 11 | 12 | # for train 13 | embedding_file = "./data/embeddings.pkl" 14 | train_file = "./data/train.gr.pkl" 15 | val_file = "./data/dev.gr.pkl" 16 | evaluate_file = "./data/test.gr.pkl" 17 | 18 | save_path = "./model/cpre/" 19 | result_path = "./output/cpre/" 20 | log_path = "./model/cpre/" 21 | 22 | max_sentence_len = 50 23 | max_num_utterance = 11 24 | batch_size = 50 25 | eval_batch_size = 100 26 | 27 | class ScriptWriter_cpre(): 28 | def __init__(self, eta=0.5): 29 | self.max_num_utterance = max_num_utterance 30 | self.negative_samples = 1 31 | self.max_sentence_len = max_sentence_len 32 | self.word_embedding_size = 200 33 | self.hidden_units = 200 34 | self.total_words = 43514 35 | self.batch_size = batch_size 36 | self.eval_batch_size = eval_batch_size 37 | self.learning_rate_ph = tf.placeholder(tf.float32, shape=[], name='learning_rate') 38 | self.dropout_rate = 0 39 | self.num_heads = 1 40 | self.num_blocks = 3 41 | self.eta = eta 42 | self.gamma = tf.get_variable('gamma', shape=1, dtype=tf.float32, trainable=True, initializer=tf.constant_initializer(0.5)) 43 | 44 | self.embedding_ph = tf.placeholder(tf.float32, shape=(self.total_words, self.word_embedding_size)) 45 | self.utterance_ph = tf.placeholder(tf.int32, shape=(None, max_num_utterance, max_sentence_len)) 46 | self.response_ph = tf.placeholder(tf.int32, shape=(None, max_sentence_len)) 47 | self.gt_response_ph = tf.placeholder(tf.int32, shape=(None, max_sentence_len)) 48 | self.y_true_ph = tf.placeholder(tf.int32, shape=(None,)) 49 | self.narrative_ph = tf.placeholder(tf.int32, shape=(None, max_sentence_len)) 50 | 51 | self.word_embeddings = tf.get_variable('word_embeddings_v', shape=(self.total_words, self.word_embedding_size), dtype=tf.float32, trainable=False) 52 | self.embedding_init = self.word_embeddings.assign(self.embedding_ph) 53 | self.global_step = tf.Variable(0, name='global_step', trainable=False) 54 | self.is_training = True 55 | print("current eta: ", self.eta) 56 | 57 | def load(self, previous_modelpath): 58 | sess = tf.Session() 59 | latest_ckpt = tf.train.latest_checkpoint(previous_modelpath) 60 | # print("recover from checkpoint: " + latest_ckpt) 61 | variables = tf.contrib.framework.get_variables_to_restore() 62 | saver = tf.train.Saver(variables) 63 | saver.restore(sess, latest_ckpt) 64 | return sess 65 | 66 | def build(self): 67 | all_utterances = tf.unstack(self.utterance_ph, num=self.max_num_utterance, axis=1) 68 | reuse = None 69 | alpha_1, alpha_2 = None, None 70 | 71 | response_embeddings = embedding(self.response_ph, initializer=self.word_embeddings) 72 | Hr_stack = [response_embeddings] 73 | for i in range(self.num_blocks): 74 | with tf.variable_scope("num_blocks_{}".format(i)): 75 | response_embeddings, _ = multihead_attention(queries=response_embeddings, keys=response_embeddings, num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False, dropout_rate=self.dropout_rate) 76 | response_embeddings = feedforward(response_embeddings, 
num_units=[self.hidden_units, self.hidden_units]) 77 | Hr_stack.append(response_embeddings) 78 | 79 | gt_response_embeddings = embedding(self.gt_response_ph, initializer=self.word_embeddings) 80 | Hgtr_stack = [gt_response_embeddings] 81 | for i in range(self.num_blocks): 82 | with tf.variable_scope("num_blocks_{}".format(i), reuse=True): 83 | gt_response_embeddings, _ = multihead_attention(queries=gt_response_embeddings, keys=gt_response_embeddings, num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False, dropout_rate=self.dropout_rate) 84 | gt_response_embeddings = feedforward(gt_response_embeddings, num_units=[self.hidden_units, self.hidden_units]) 85 | Hgtr_stack.append(gt_response_embeddings) 86 | 87 | narrative_embeddings = embedding(self.narrative_ph, initializer=self.word_embeddings) 88 | Hn_stack = [narrative_embeddings] 89 | for i in range(self.num_blocks): 90 | with tf.variable_scope("num_blocks_{}".format(i), reuse=True): 91 | narrative_embeddings, _ = multihead_attention(queries=narrative_embeddings, keys=narrative_embeddings, num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False, dropout_rate=self.dropout_rate) 92 | narrative_embeddings = feedforward(narrative_embeddings, num_units=[self.hidden_units, self.hidden_units]) 93 | Hn_stack.append(narrative_embeddings) 94 | 95 | Mur, Mun = [], [] 96 | self.decay_factor = [] 97 | last_u_reps = [] 98 | turn_id = 0 99 | for utterance in all_utterances: 100 | utterance_embeddings = embedding(utterance, initializer=self.word_embeddings) 101 | Hu_stack = [utterance_embeddings] 102 | for i in range(self.num_blocks): 103 | with tf.variable_scope("num_blocks_{}".format(i), reuse=True): 104 | utterance_embeddings, _ = multihead_attention(queries=utterance_embeddings, keys=utterance_embeddings, num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False, dropout_rate=self.dropout_rate) 105 | utterance_embeddings = feedforward(utterance_embeddings, num_units=[self.hidden_units, self.hidden_units]) 106 | Hu_stack.append(utterance_embeddings) 107 | 108 | if turn_id == self.max_num_utterance - 1: 109 | last_u_reps = Hu_stack 110 | 111 | r_a_u_stack = [] 112 | u_a_r_stack = [] 113 | 114 | for i in range(self.num_blocks + 1): 115 | with tf.variable_scope("utterance_attention_response_{}".format(i), reuse=reuse): 116 | u_a_r, _ = multihead_attention(queries=Hu_stack[i], keys=Hr_stack[i], num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False, dropout_rate=self.dropout_rate) 117 | u_a_r = feedforward(u_a_r, num_units=[self.hidden_units, self.hidden_units]) 118 | u_a_r_stack.append(u_a_r) 119 | with tf.variable_scope("response_attention_utterance_{}".format(i), reuse=reuse): 120 | r_a_u, _ = multihead_attention(queries=Hr_stack[i], keys=Hu_stack[i], num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False, dropout_rate=self.dropout_rate) 121 | r_a_u = feedforward(r_a_u, num_units=[self.hidden_units, self.hidden_units]) 122 | r_a_u_stack.append(r_a_u) 123 | u_a_r_stack.extend(Hu_stack) 124 | r_a_u_stack.extend(Hr_stack) 125 | 126 | n_a_u_stack = [] 127 | u_a_n_stack = [] 128 | for i in range(self.num_blocks + 1): 129 | with tf.variable_scope("narrative_attention_response_{}".format(i), reuse=reuse): 130 | n_a_u, _ = multihead_attention(queries=Hn_stack[i], keys=Hu_stack[i], num_units=self.hidden_units, num_heads=self.num_heads, 
is_training=self.is_training, causality=False, dropout_rate=self.dropout_rate) 131 | n_a_u = feedforward(n_a_u, num_units=[self.hidden_units, self.hidden_units]) 132 | n_a_u_stack.append(n_a_u) 133 | with tf.variable_scope("response_attention_narrative_{}".format(i), reuse=reuse): 134 | u_a_n, alpha_1 = multihead_attention(queries=Hu_stack[i], keys=Hn_stack[i], num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False, dropout_rate=self.dropout_rate) 135 | u_a_n = feedforward(u_a_n, num_units=[self.hidden_units, self.hidden_units]) 136 | u_a_n_stack.append(u_a_n) 137 | n_a_u_stack.extend(Hn_stack) 138 | u_a_n_stack.extend(Hu_stack) 139 | 140 | u_a_r = tf.stack(u_a_r_stack, axis=-1) 141 | r_a_u = tf.stack(r_a_u_stack, axis=-1) 142 | u_a_n = tf.stack(u_a_n_stack, axis=-1) 143 | n_a_u = tf.stack(n_a_u_stack, axis=-1) 144 | 145 | with tf.variable_scope('similarity'): 146 | # sim shape [batch, max_sent_len, max_sent_len, 2 * (stack_num + 1)] 147 | sim_ur = tf.einsum('biks,bjks->bijs', u_a_r, r_a_u) / tf.sqrt(200.0) # for no rp and normal 148 | sim_un = tf.einsum('biks,bjks->bijs', u_a_n, n_a_u) / tf.sqrt(200.0) # for no rp and normal 149 | 150 | self_n = tf.nn.l2_normalize(tf.stack(Hn_stack, axis=-1)) # #for no rp 151 | self_u = tf.nn.l2_normalize(tf.stack(Hu_stack, axis=-1)) # #for no rp 152 | Hn_stack_tensor = tf.stack(Hn_stack, axis=-1) # [batch, o_len, embedding_size, stack] 153 | with tf.variable_scope('similarity'): 154 | self_sim = tf.einsum('biks,bjks->bijs', self_u, self_n) # [batch, u_len, o_len, stack] 155 | self_sim = 1 - self.gamma * tf.reduce_sum(self_sim, axis=1) # [batch, (1), o_len, stack] 156 | Hn_stack = tf.einsum('bjkl,bjl->bjkl', Hn_stack_tensor, self_sim) 157 | Hn_stack = tf.unstack(Hn_stack, axis=-1, num=self.num_blocks + 1) 158 | 159 | Mur.append(sim_ur) 160 | Mun.append(sim_un) 161 | turn_id += 1 162 | if not reuse: 163 | reuse = True 164 | 165 | Hn_stack_for_tracking = tf.layers.dense(tf.stack(Hn_stack, axis=2), self.hidden_units) # [batch, o_len, stack, embedding_size] 166 | Hn_stack_for_tracking = tf.transpose(Hn_stack_for_tracking, perm=[0, 1, 3, 2]) # [batch, o_len, embedding_size, stack] 167 | Hlastu_stack_for_tracking = tf.stack(last_u_reps, axis=-1) # [batch, u_len, embedding_size, stack] 168 | Hr_stack_for_tracking = tf.stack(Hgtr_stack, axis=-1) # [batch, r_len, embedding_size, stack] 169 | Hlastu = tf.transpose(Hlastu_stack_for_tracking, perm=[0, 2, 3, 1]) 170 | Hlastu = tf.squeeze(tf.layers.dense(Hlastu, 1), axis=-1) # [batch, embedding_size, stack] 171 | p1_tensor = tf.nn.softmax(tf.einsum('bnds,bds->bns', Hn_stack_for_tracking, Hlastu), axis=1) # [batch, o_len, stack] 172 | Hlastur = tf.transpose(Hr_stack_for_tracking, perm=[0, 2, 3, 1]) 173 | Hlastur = tf.squeeze(tf.layers.dense(Hlastur, 1), axis=-1) # [batch, embedding_size, stack] 174 | p2_tensor = tf.nn.softmax(tf.einsum('bnds,bds->bns', Hn_stack_for_tracking, Hlastur), axis=1) # [batch, o_len, stack] 175 | p1 = tf.unstack(p1_tensor, num=self.num_blocks + 1, axis=-1) 176 | p2 = tf.unstack(p2_tensor, num=self.num_blocks + 1, axis=-1) 177 | KL_loss = 0.0 178 | for i in range(self.num_blocks + 1): 179 | KL_loss += tf.reduce_mean(tf.keras.losses.kullback_leibler_divergence(p1[i], p2[i])) 180 | KL_loss /= (self.num_blocks + 1) 181 | 182 | r_a_n_stack = [] 183 | n_a_r_stack = [] 184 | for i in range(self.num_blocks + 1): 185 | with tf.variable_scope("narrative_attention_response_{}".format(i), reuse=True): 186 | n_a_r, _ = multihead_attention(queries=Hn_stack[i], 
keys=Hr_stack[i], num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False, dropout_rate=self.dropout_rate) 187 | n_a_r = feedforward(n_a_r, num_units=[self.hidden_units, self.hidden_units]) 188 | n_a_r_stack.append(n_a_r) 189 | with tf.variable_scope("response_attention_narrative_{}".format(i), reuse=True): 190 | r_a_n, _ = multihead_attention(queries=Hr_stack[i], keys=Hn_stack[i], num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False, dropout_rate=self.dropout_rate) 191 | r_a_n = feedforward(r_a_n, num_units=[self.hidden_units, self.hidden_units]) 192 | r_a_n_stack.append(r_a_n) 193 | 194 | n_a_r_stack.extend(Hn_stack) 195 | r_a_n_stack.extend(Hr_stack) 196 | n_a_r = tf.stack(n_a_r_stack, axis=-1) 197 | r_a_n = tf.stack(r_a_n_stack, axis=-1) 198 | 199 | with tf.variable_scope('similarity'): 200 | Mrn = tf.einsum('biks,bjks->bijs', n_a_r, r_a_n) / tf.sqrt(200.0) 201 | self.rosim = Mrn 202 | Mur = tf.stack(Mur, axis=1) 203 | Mun = tf.stack(Mun, axis=1) 204 | with tf.variable_scope('cnn_aggregation'): 205 | conv3d = tf.layers.conv3d(Mur, filters=32, kernel_size=[3, 3, 3], padding="SAME", activation=tf.nn.elu, kernel_initializer=tf.random_uniform_initializer(-0.01, 0.01), name="conv1") 206 | pool3d = tf.layers.max_pooling3d(conv3d, pool_size=[3, 3, 3], strides=[3, 3, 3], padding="SAME") 207 | conv3d2 = tf.layers.conv3d(pool3d, filters=32, kernel_size=[3, 3, 3], padding="SAME", activation=tf.nn.elu, kernel_initializer=tf.random_uniform_initializer(-0.01, 0.01), name="conv2") 208 | pool3d2 = tf.layers.max_pooling3d(conv3d2, pool_size=[3, 3, 3], strides=[3, 3, 3], padding="SAME") 209 | mur = tf.contrib.layers.flatten(pool3d2) 210 | with tf.variable_scope('cnn_aggregation', reuse=True): 211 | conv3d = tf.layers.conv3d(Mun, filters=32, kernel_size=[3, 3, 3], padding="SAME", activation=tf.nn.elu, kernel_initializer=tf.random_uniform_initializer(-0.01, 0.01), name="conv1") 212 | pool3d = tf.layers.max_pooling3d(conv3d, pool_size=[3, 3, 3], strides=[3, 3, 3], padding="SAME") 213 | conv3d2 = tf.layers.conv3d(pool3d, filters=32, kernel_size=[3, 3, 3], padding="SAME", activation=tf.nn.elu, kernel_initializer=tf.random_uniform_initializer(-0.01, 0.01), name="conv2") 214 | pool3d2 = tf.layers.max_pooling3d(conv3d2, pool_size=[3, 3, 3], strides=[3, 3, 3], padding="SAME") 215 | mun = tf.contrib.layers.flatten(pool3d2) 216 | with tf.variable_scope('cnn_aggregation'): 217 | conv2d = tf.layers.conv2d(Mrn, filters=32, kernel_size=[3, 3], padding="SAME", activation=tf.nn.elu, kernel_initializer=tf.random_uniform_initializer(-0.01, 0.01), name="conv2d") 218 | pool2d = tf.layers.max_pooling2d(conv2d, pool_size=[3, 3], strides=[3, 3], padding="SAME") 219 | conv2d2 = tf.layers.conv2d(pool2d, filters=32, kernel_size=[3, 3], padding="SAME", activation=tf.nn.elu, kernel_initializer=tf.random_uniform_initializer(-0.01, 0.01), name="conv2d2") 220 | pool2d2 = tf.layers.max_pooling2d(conv2d2, pool_size=[3, 3], strides=[3, 3], padding="SAME") 221 | mrn = tf.contrib.layers.flatten(pool2d2) 222 | 223 | all_vector = tf.concat([mur, mun, mrn], axis=-1) 224 | logits = tf.reshape(tf.layers.dense(all_vector, 1, kernel_initializer=tf.orthogonal_initializer()), [-1]) 225 | 226 | self.y_pred = tf.sigmoid(logits) 227 | optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate_ph, beta1=0.9, beta2=0.98, epsilon=1e-8) 228 | RS_loss = tf.reduce_mean(tf.clip_by_value(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.cast(self.y_true_ph, 
tf.float32), logits=logits), -10, 10)) 229 | self.loss = self.eta * RS_loss + (1 - self.eta) * KL_loss 230 | self.all_variables = tf.global_variables() 231 | self.grads_and_vars = optimizer.compute_gradients(self.loss) 232 | 233 | for grad, var in self.grads_and_vars: 234 | if grad is None: 235 | print(var) 236 | 237 | self.capped_gvs = [(tf.clip_by_value(grad, -5, 5), var) for grad, var in self.grads_and_vars] 238 | self.train_op = optimizer.apply_gradients(self.capped_gvs, global_step=self.global_step) 239 | self.saver = tf.train.Saver(max_to_keep=10) 240 | self.alpha_1 = alpha_1 241 | # self.alpha_2 = alpha_2 242 | # self.train_op = self.optimizer.minimize(self.loss, global_step=self.global_step) 243 | 244 | 245 | def evaluate(model_path, eval_file, output_path, eta): 246 | with open(eval_file, 'rb') as f: 247 | utterance, response, narrative, gt_response, y_true = pickle.load(f) 248 | 249 | current_lr = 1e-3 250 | all_candidate_scores = [] 251 | dataset = tf.data.Dataset.from_tensor_slices((utterance, narrative, response, gt_response, y_true)).batch(eval_batch_size) 252 | iterator = dataset.make_initializable_iterator() 253 | data_iterator = iterator.get_next() 254 | 255 | with open(embedding_file, 'rb') as f: 256 | embeddings = pickle.load(f) 257 | 258 | model = ScriptWriter_cpre(eta) 259 | model.build() 260 | sess = model.load(model_path) 261 | sess.run(iterator.initializer) 262 | sess.run(model.embedding_init, feed_dict={model.embedding_ph: embeddings}) 263 | 264 | test_loss = 0.0 265 | step = 0 266 | try: 267 | with tqdm(total=len(y_true), ncols=100) as pbar: 268 | while True: 269 | bu, bn, br, bgtr, by = data_iterator 270 | bu, bn, br, bgtr, by = sess.run([bu, bn, br, bgtr, by]) 271 | candidate_scores, loss = sess.run([model.y_pred, model.loss], feed_dict={ 272 | model.utterance_ph: bu, 273 | model.narrative_ph: bn, 274 | model.response_ph: br, 275 | model.gt_response_ph: bgtr, 276 | model.y_true_ph: by, 277 | model.learning_rate_ph: current_lr 278 | }) 279 | all_candidate_scores.append(candidate_scores) 280 | test_loss += loss 281 | pbar.update(model.eval_batch_size) 282 | step += 1 283 | except tf.errors.OutOfRangeError: 284 | pass 285 | 286 | sess.close() 287 | tf.reset_default_graph() 288 | 289 | all_candidate_scores = np.concatenate(all_candidate_scores, axis=0) 290 | with open(output_path + "test.result.micro_session.txt", "w") as fw: 291 | for sc in all_candidate_scores.tolist(): 292 | fw.write(str(sc) + "\n") 293 | return Evaluate.evaluate_all(all_candidate_scores, y_true), test_loss / step, all_candidate_scores.tolist() 294 | 295 | 296 | def simple_evaluate(sess, model, eval_file): 297 | with open(eval_file, 'rb') as f: 298 | utterance, response, narrative, y_true = pickle.load(f) 299 | utterance, utterance_len = Utils.multi_sequences_padding(utterance, max_sentence_len, max_num_utterance=max_num_utterance) 300 | utterance = np.array(utterance) 301 | narrative = np.array(pad_sequences(narrative, padding='post', maxlen=max_sentence_len)) 302 | response = np.array(pad_sequences(response, padding='post', maxlen=max_sentence_len)) 303 | y_true = np.array(y_true) 304 | all_candidate_scores = [] 305 | dataset = tf.data.Dataset.from_tensor_slices((utterance, narrative, response, y_true)).batch(eval_batch_size) 306 | iterator = dataset.make_initializable_iterator() 307 | data_iterator = iterator.get_next() 308 | sess.run(iterator.initializer) 309 | current_lr = 1e-3 310 | test_loss = 0.0 311 | step = 0 312 | try: 313 | with tqdm(total=len(y_true), ncols=100) as pbar: 314 | while 
True: 315 | bu, bn, br, by = data_iterator 316 | bu, bn, br, by = sess.run([bu, bn, br, by]) 317 | candidate_scores, loss = sess.run([model.y_pred, model.loss], feed_dict={ 318 | model.utterance_ph: bu, 319 | model.narrative_ph: bn, 320 | model.response_ph: br, 321 | model.y_true_ph: by, 322 | model.gt_response_ph: br, 323 | model.learning_rate_ph: current_lr 324 | }) 325 | all_candidate_scores.append(candidate_scores) 326 | test_loss += loss 327 | pbar.update(eval_batch_size) 328 | step += 1 329 | except tf.errors.OutOfRangeError: 330 | pass 331 | all_candidate_scores = np.concatenate(all_candidate_scores, axis=0) 332 | return Evaluate.evaluate_all(all_candidate_scores, y_true), test_loss / step, all_candidate_scores.tolist() 333 | 334 | 335 | def evaluate_multi_turns(test_file, model_path, output_path): 336 | vocab = {} 337 | vocab_id2word = {} 338 | 339 | with open(embedding_file, 'rb') as f: 340 | embeddings = pickle.load(f) 341 | 342 | model = ScriptWriter_cpre() 343 | model.build() 344 | sess = model.load(model_path) 345 | sess.run(model.embedding_init, feed_dict={model.embedding_ph: embeddings}) 346 | 347 | with open("./data/vocab.txt", "r", encoding="utf-8") as fr: 348 | for idx, line in enumerate(fr): 349 | line = line.strip().split("\t") 350 | vocab[line[0]] = idx + 1 351 | vocab_id2word[idx + 1] = line[0] 352 | vocab["_PAD_"] = 0 353 | vocab_id2word[0] = "_PAD_" 354 | 355 | def initialize(test_file): 356 | initial_file = output_path + "test.multi.0.pkl" 357 | max_turn = 0 358 | narrative_dict = {} 359 | narrative_dict_score = {} 360 | 361 | with open(test_file, 'rb') as f: 362 | utterance, response, outline, labels = pickle.load(f) 363 | new_utterance, new_response, new_narrative, new_labels = [], [], [], [] 364 | for i in range(len(response)): 365 | ut = utterance[i] 366 | if len(ut) == 1: 367 | o = outline[i] 368 | r = response[i] 369 | l = labels[i] 370 | new_utterance.append(ut) 371 | new_response.append(r) 372 | new_narrative.append(o) 373 | new_labels.append(l) 374 | if len(ut) > max_turn: 375 | max_turn = len(ut) 376 | o = "".join([vocab_id2word[x] for x in outline[i]]) 377 | if o not in narrative_dict: 378 | narrative_dict[o] = {0: outline[i]} 379 | narrative_dict_score[o] = {0: [-1]} 380 | r = response[i] 381 | l = labels[i] 382 | if len(ut) in narrative_dict[o]: 383 | narrative_dict[o][len(ut)].append(r) 384 | narrative_dict_score[o][len(ut)].append(l) 385 | else: 386 | narrative_dict[o][len(ut)] = [r] 387 | narrative_dict_score[o][len(ut)] = [l] 388 | 389 | pickle.dump(narrative_dict, open(output_path + "response_candidate.pkl", "wb")) 390 | 391 | new_data = [new_utterance, new_response, new_narrative, new_labels] 392 | pickle.dump(new_data, open(initial_file, "wb")) 393 | 394 | (r2_1, r10_1, r10_2, r10_5, mrr), eva_loss, result = simple_evaluate(sess, model, initial_file) 395 | with open(output_path + "test.result.multi.0.txt", "w") as fw: 396 | fw.write("R2@1: %f, R10@1: %f, R10@2: %f, R10@5: %f, MRR: %f\n" % (r2_1, r10_1, r10_2, r10_5, mrr)) 397 | for r in result: 398 | fw.write(str(r) + "\n") 399 | 400 | return max_turn, narrative_dict, narrative_dict_score 401 | 402 | max_turn, narrative_dict, narrative_dict_score = initialize(test_file) 403 | for turn in range(1, max_turn): 404 | score = [] 405 | with open(output_path + "test.result.multi." + str(turn - 1) + ".txt", "r") as fr: 406 | for idx, line in enumerate(fr): 407 | if idx == 0: 408 | continue 409 | score.append(float(line.strip())) 410 | with open(output_path + "test.multi." 
+ str(turn - 1) + ".pkl", "rb") as fr: 411 | utterance, response, narrative, y_true = pickle.load(fr) 412 | 413 | new_utterance = [] 414 | new_response = [] 415 | new_narrative = [] 416 | new_labels = [] 417 | 418 | for i, o in enumerate(narrative): 419 | if i % 10 == 0: 420 | sent_o = "".join([vocab_id2word[x] for x in o]) 421 | if turn + 1 in narrative_dict[sent_o]: 422 | new_response.extend(narrative_dict[sent_o][turn + 1]) 423 | score_sub_list = score[i:i + 10] 424 | response_sub_list = response[i:i + 10] 425 | max_idx = score_sub_list.index(max(score_sub_list)) 426 | selected_response = response_sub_list[max_idx] 427 | for ut in utterance[i:i + 10]: 428 | tmp = ut + [selected_response] 429 | new_utterance.append(tmp) 430 | new_narrative.extend([o] * 10) 431 | new_labels.extend(narrative_dict_score[sent_o][turn + 1]) 432 | 433 | new_data = [new_utterance, new_response, new_narrative, new_labels] 434 | new_file = output_path + "test.multi." + str(turn) + ".pkl" 435 | pickle.dump(new_data, open(new_file, "wb")) 436 | 437 | (r2_1, r10_1, r10_2, r10_5, mrr), eva_loss, result = simple_evaluate(sess, model, new_file) 438 | with open(output_path + "test.result.multi." + str(turn) + ".txt", "w") as fw: 439 | fw.write("R2@1: %f, R10@1: %f, R10@2: %f, R10@5: %f, MRR: %f\n" % (r2_1, r10_1, r10_2, r10_5, mrr)) 440 | for r in result: 441 | fw.write(str(r) + "\n") 442 | 443 | 444 | def train(eta=0.5, load=False, model_path=None, logger=None): 445 | config = tf.ConfigProto(allow_soft_placement=True) 446 | config.gpu_options.allow_growth = True 447 | epoch = 0 448 | best_result = [0.0, 0.0, 0.0, 0.0, 0.0] 449 | with tf.Session(config=config) as sess: 450 | with open(embedding_file, 'rb') as f: 451 | embeddings = pickle.load(f, encoding="bytes") 452 | with open(train_file, 'rb') as f: 453 | utterance_train, response_train, narrative_train, gt_response_train, y_true_train = pickle.load(f) 454 | with open(val_file, "rb") as f: 455 | utterance_val, response_val, narrative_val, gt_response_val, y_true_val = pickle.load(f) 456 | 457 | train_dataset = tf.data.Dataset.from_tensor_slices((utterance_train, narrative_train, response_train, gt_response_train, y_true_train)).shuffle(1024).batch(batch_size) 458 | train_iterator = train_dataset.make_initializable_iterator() 459 | train_data_iterator = train_iterator.get_next() 460 | 461 | val_dataset = tf.data.Dataset.from_tensor_slices((utterance_val, narrative_val, response_val, gt_response_val, y_true_val)).batch(batch_size) 462 | val_iterator = val_dataset.make_initializable_iterator() 463 | val_data_iterator = val_iterator.get_next() 464 | 465 | model = ScriptWriter_cpre(eta=eta) 466 | model.build() 467 | 468 | if load: 469 | sess = model.load(model_path) 470 | 471 | sess.run(tf.global_variables_initializer()) 472 | sess.run(model.embedding_init, feed_dict={model.embedding_ph: embeddings}) 473 | current_lr = 1e-3 474 | 475 | while epoch < 4: 476 | print("\nEpoch ", epoch + 1, "/ 4") 477 | train_loss = 0.0 478 | sess.run(train_iterator.initializer) 479 | step = 0 480 | try: 481 | with tqdm(total=len(y_true_train), ncols=100) as pbar: 482 | while True: 483 | bu, bn, br, bgtr, by = train_data_iterator 484 | bu, bn, br, bgtr, by = sess.run([bu, bn, br, bgtr, by]) 485 | _, loss = sess.run([model.train_op, model.loss], feed_dict={ 486 | model.utterance_ph: bu, 487 | model.narrative_ph: bn, 488 | model.response_ph: br, 489 | model.gt_response_ph: bgtr, 490 | model.y_true_ph: by, 491 | model.learning_rate_ph: current_lr 492 | }) 493 | train_loss += loss 494 | 
pbar.set_postfix(learning_rate=current_lr, loss=loss) 495 | pbar.update(model.batch_size) 496 | step += 1 497 | if step % 500 == 0: 498 | val_loss = 0.0 499 | val_step = 0 500 | sess.run(val_iterator.initializer) 501 | all_candidate_scores = [] 502 | try: 503 | while True: 504 | bu, bn, br, bgtr, by = val_data_iterator 505 | bu, bn, br, bgtr, by = sess.run([bu, bn, br, bgtr, by]) 506 | candidate_scores, loss = sess.run([model.y_pred, model.loss], feed_dict={ 507 | model.utterance_ph: bu, 508 | model.narrative_ph: bn, 509 | model.response_ph: br, 510 | model.gt_response_ph: bgtr, 511 | model.y_true_ph: by, 512 | }) 513 | all_candidate_scores.append(candidate_scores) 514 | val_loss += loss 515 | val_step += 1 516 | except tf.errors.OutOfRangeError: 517 | pass 518 | all_candidate_scores = np.concatenate(all_candidate_scores, axis=0) 519 | result = Evaluate.evaluate_all(all_candidate_scores, y_true_val) 520 | if result[0] + result[1] + result[2] + result[3] + result[4] > best_result[0] + best_result[1] + best_result[2] + best_result[3] + best_result[4]: 521 | best_result = result 522 | tqdm.write("Current best result on validation set: r2@1 %.3f, r10@1 %.3f, r10@2 %.3f, r10@5 %.3f, mrr %.3f" % (best_result[0], best_result[1], best_result[2], best_result[3], best_result[4])) 523 | logger.info("Current best result on validation set: r2@1 %.3f, r10@1 %.3f, r10@2 %.3f, r10@5 %.3f, mrr %.3f" % (best_result[0], best_result[1], best_result[2], best_result[3], best_result[4])) 524 | model.saver.save(sess, save_path + "model") 525 | patience = 0 526 | else: 527 | patience += 1 528 | if patience >= 3: 529 | current_lr *= 0.5 530 | except tf.errors.OutOfRangeError: 531 | pass 532 | 533 | val_loss = 0.0 534 | val_step = 0 535 | sess.run(val_iterator.initializer) 536 | all_candidate_scores = [] 537 | try: 538 | while True: 539 | bu, bn, br, bgtr, by = val_data_iterator 540 | bu, bn, br, bgtr, by = sess.run([bu, bn, br, bgtr, by]) 541 | candidate_scores, loss = sess.run([model.y_pred, model.loss], feed_dict={ 542 | model.utterance_ph: bu, 543 | model.narrative_ph: bn, 544 | model.response_ph: br, 545 | model.gt_response_ph: bgtr, 546 | model.y_true_ph: by 547 | }) 548 | all_candidate_scores.append(candidate_scores) 549 | val_loss += loss 550 | val_step += 1 551 | except tf.errors.OutOfRangeError: 552 | pass 553 | all_candidate_scores = np.concatenate(all_candidate_scores, axis=0) 554 | result = Evaluate.evaluate_all(all_candidate_scores, y_true_val) 555 | if result[0] + result[1] + result[2] + result[3] + result[4] > best_result[0] + best_result[1] + best_result[2] + best_result[3] + best_result[4]: 556 | best_result = result 557 | tqdm.write("Current best result on validation set: r2@1 %.3f, r10@1 %.3f, r10@2 %.3f, r10@5 %.3f, mrr %.3f" % (best_result[0], best_result[1], best_result[2], best_result[3], best_result[4])) 558 | logger.info("Current best result on validation set: r2@1 %.3f, r10@1 %.3f, r10@2 %.3f, r10@5 %.3f, mrr %.3f" % (best_result[0], best_result[1], best_result[2], best_result[3], best_result[4])) 559 | model.saver.save(sess, save_path + "model") 560 | tqdm.write('Epoch No: %d, the train loss is %f, the dev loss is %f' % (epoch + 1, train_loss / step, val_loss / val_step)) 561 | epoch += 1 562 | sess.close() 563 | tf.reset_default_graph() 564 | 565 | 566 | if __name__ == "__main__": 567 | os.environ['CUDA_VISIBLE_DEVICES'] = "0" 568 | log_path = "./model/cpre/all_log" 569 | logging.basicConfig(filename=log_path, level=logging.INFO) 570 | logger = logging.getLogger(__name__) 571 | eta = 0.7 
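    # eta weights the two training objectives combined in ScriptWriter_cpre.build():
    # loss = eta * RS_loss (response selection) + (1 - eta) * KL_loss (narrative tracking)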
572 | save_path = "./model/cpre/" 573 | result_path = "./output/cpre/" 574 | if not os.path.exists(save_path): 575 | os.mkdir(save_path) 576 | if not os.path.exists(result_path): 577 | os.mkdir(result_path) 578 | logger.info("Current Eta: %.2f" % eta) 579 | train(eta=eta, logger=logger) 580 | 581 | (r2_1, r10_1, r10_2, r10_5, mrr), eva_loss, _ = evaluate(save_path, evaluate_file, output_path=result_path, eta=eta) 582 | print("Loss on test set: %f, R2@1: %f, R10@1: %f, R10@2: %f, R10@5: %f, MRR: %f" % (eva_loss, r2_1, r10_1, r10_2, r10_5, mrr)) 583 | 584 | # to evaluate multi-turn results, the vocab file is needed 585 | # evaluate_multi_turns(test_file=evaluate_file, model_path=save_path, output_path=result_path) 586 | # Evaluate.recover_and_show(result_path) 587 | # test_file = result_path + "test.result.multi.txt" 588 | # gold_file = "../data/ground_truth.result.mul.txt" 589 | # Evaluate.evaluate_multi_turn_result(test_file, gold_file) 590 | 591 | -------------------------------------------------------------------------------- /ScriptWriter.py: -------------------------------------------------------------------------------- 1 | from modules import * 2 | from keras.preprocessing.sequence import pad_sequences 3 | import Utils 4 | import Evaluate 5 | import pickle 6 | import os 7 | from tqdm import tqdm 8 | 9 | embedding_file = "./data/embeddings.pkl" 10 | train_file = "./data/train.pkl" 11 | val_file = "./data/dev.pkl" 12 | evaluate_file = "./data/test.pkl" 13 | evaluate_embedding_file = "./data/embeddings.pkl" 14 | 15 | max_sentence_len = 50 16 | max_num_utterance = 11 17 | batch_size = 64 18 | eval_batch_size = 64 19 | 20 | 21 | class ScripteWriter(): 22 | def __init__(self, data_iterator): 23 | self.max_num_utterance = max_num_utterance 24 | self.negative_samples = 1 25 | self.max_sentence_len = max_sentence_len 26 | self.word_embedding_size = 200 27 | self.hidden_units = 200 28 | self.total_words = 43514 29 | self.batch_size = batch_size 30 | self.eval_batch_size = eval_batch_size 31 | self.initial_learning_rate = 1e-3 32 | self.dropout_rate = 0 33 | self.num_heads = 1 34 | self.num_blocks = 3 35 | # self.gamma = 0.1 36 | self.gamma = tf.get_variable('gamma', shape=1, dtype=tf.float32, trainable=True, initializer=tf.constant_initializer(0.5)) 37 | 38 | self.utterance_ph = data_iterator[0] 39 | self.response_ph = data_iterator[4] 40 | self.y_true = data_iterator[6] 41 | self.embedding_ph = tf.placeholder(tf.float32, shape=(self.total_words, self.word_embedding_size)) 42 | self.response_len = data_iterator[5] 43 | self.all_utterance_len_ph = data_iterator[1] 44 | self.narrative_ph = data_iterator[2] 45 | self.narrative_len = data_iterator[3] 46 | self.word_embeddings = tf.get_variable('word_embeddings_v', shape=(self.total_words, self.word_embedding_size), dtype=tf.float32, trainable=False) 47 | self.embedding_init = self.word_embeddings.assign(self.embedding_ph) 48 | self.global_step = tf.Variable(0, name='global_step', trainable=False) 49 | self.is_training = True 50 | 51 | def load(self, previous_modelpath): 52 | sess = tf.Session() 53 | latest_ckpt = tf.train.latest_checkpoint(previous_modelpath) 54 | # latest_ckpt = previous_modelpath + "model.4" 55 | print("recover from checkpoint: " + latest_ckpt) 56 | variables = tf.contrib.framework.get_variables_to_restore() 57 | saver = tf.train.Saver(variables) 58 | saver.restore(sess, latest_ckpt) 59 | return sess 60 | 61 | def build(self): 62 | all_utterances = tf.unstack(self.utterance_ph, num=self.max_num_utterance, axis=1) 63 | 
all_utterance_len = tf.unstack(self.all_utterance_len_ph, num=self.max_num_utterance, axis=1) 64 | reuse = None 65 | alpha_1 = None 66 | 67 | response_embeddings = embedding(self.response_ph, initializer=self.word_embeddings) 68 | response_embeddings = tf.layers.dropout(response_embeddings, rate=self.dropout_rate, training=tf.convert_to_tensor(self.is_training)) 69 | Hr_stack = [response_embeddings] 70 | for i in range(self.num_blocks): 71 | with tf.variable_scope("num_blocks_{}".format(i)): 72 | response_embeddings, _ = multihead_attention(queries=response_embeddings, keys=response_embeddings, num_units=self.hidden_units, num_heads=self.num_heads, dropout_rate=self.dropout_rate, is_training=self.is_training, causality=False) 73 | response_embeddings = feedforward(response_embeddings, num_units=[self.hidden_units, self.hidden_units]) 74 | Hr_stack.append(response_embeddings) 75 | 76 | narrative_embeddings = embedding(self.narrative_ph, initializer=self.word_embeddings) 77 | narrative_embeddings = tf.layers.dropout(narrative_embeddings, rate=self.dropout_rate, training=tf.convert_to_tensor(self.is_training)) 78 | Hn_stack = [narrative_embeddings] 79 | for i in range(self.num_blocks): 80 | with tf.variable_scope("num_blocks_{}".format(i), reuse=True): 81 | narrative_embeddings, _ = multihead_attention(queries=narrative_embeddings, keys=narrative_embeddings, num_units=self.hidden_units, num_heads=self.num_heads, dropout_rate=self.dropout_rate, is_training=self.is_training, causality=False) 82 | narrative_embeddings = feedforward(narrative_embeddings, num_units=[self.hidden_units, self.hidden_units]) 83 | Hn_stack.append(narrative_embeddings) 84 | 85 | Mur, Mun = [], [] 86 | self.decay_factor = [] 87 | 88 | for utterance, utterance_len in zip(all_utterances, all_utterance_len): 89 | utterance_embeddings = embedding(utterance, initializer=self.word_embeddings) 90 | Hu_stack = [utterance_embeddings] 91 | for i in range(self.num_blocks): 92 | with tf.variable_scope("num_blocks_{}".format(i), reuse=True): 93 | utterance_embeddings, _ = multihead_attention(queries=utterance_embeddings, keys=utterance_embeddings, num_units=self.hidden_units, num_heads=self.num_heads, dropout_rate=self.dropout_rate, is_training=self.is_training, causality=False) 94 | utterance_embeddings = feedforward(utterance_embeddings, num_units=[self.hidden_units, self.hidden_units]) 95 | Hu_stack.append(utterance_embeddings) 96 | 97 | r_a_u_stack = [] 98 | u_a_r_stack = [] 99 | 100 | for i in range(self.num_blocks + 1): 101 | with tf.variable_scope("utterance_attention_response_{}".format(i), reuse=reuse): 102 | u_a_r, _ = multihead_attention(queries=Hu_stack[i], keys=Hr_stack[i], num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False) 103 | u_a_r = feedforward(u_a_r, num_units=[self.hidden_units, self.hidden_units]) 104 | u_a_r_stack.append(u_a_r) 105 | with tf.variable_scope("response_attention_utterance_{}".format(i), reuse=reuse): 106 | r_a_u, _ = multihead_attention(queries=Hr_stack[i], keys=Hu_stack[i], num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False) 107 | r_a_u = feedforward(r_a_u, num_units=[self.hidden_units, self.hidden_units]) 108 | r_a_u_stack.append(r_a_u) 109 | u_a_r_stack.extend(Hu_stack) 110 | r_a_u_stack.extend(Hr_stack) 111 | 112 | n_a_u_stack = [] 113 | u_a_n_stack = [] 114 | for i in range(self.num_blocks + 1): 115 | with tf.variable_scope("narrative_attention_utterance_{}".format(i), reuse=reuse): 116 | 
n_a_u, _ = multihead_attention(queries=Hn_stack[i], keys=Hu_stack[i], num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False) 117 | n_a_u = feedforward(n_a_u, num_units=[self.hidden_units, self.hidden_units]) 118 | n_a_u_stack.append(n_a_u) 119 | with tf.variable_scope("utterance_attention_narrative_{}".format(i), reuse=reuse): 120 | u_a_n, alpha_1 = multihead_attention(queries=Hu_stack[i], keys=Hn_stack[i], num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False) 121 | u_a_n = feedforward(u_a_n, num_units=[self.hidden_units, self.hidden_units]) 122 | u_a_n_stack.append(u_a_n) 123 | 124 | n_a_u_stack.extend(Hn_stack) 125 | u_a_n_stack.extend(Hu_stack) 126 | 127 | u_a_r = tf.stack(u_a_r_stack, axis=-1) 128 | r_a_u = tf.stack(r_a_u_stack, axis=-1) 129 | u_a_n = tf.stack(u_a_n_stack, axis=-1) 130 | n_a_u = tf.stack(n_a_u_stack, axis=-1) 131 | 132 | with tf.variable_scope('similarity'): 133 | # sim shape [batch, max_sent_len, max_sent_len, 2 * (stack_num + 1)] 134 | sim_ur = tf.einsum('biks,bjks->bijs', u_a_r, r_a_u) / tf.sqrt(200.0) 135 | sim_un = tf.einsum('biks,bjks->bijs', u_a_n, n_a_u) / tf.sqrt(200.0) 136 | 137 | self_n = tf.nn.l2_normalize(tf.stack(Hn_stack, axis=-1)) 138 | self_u = tf.nn.l2_normalize(tf.stack(Hu_stack, axis=-1)) 139 | with tf.variable_scope('similarity'): 140 | self_sim = tf.einsum('biks,bjks->bijs', self_u, self_n) # [batch * len * len * stack] 141 | self_sim = tf.unstack(self_sim, axis=-1, num=self.num_blocks + 1) 142 | reuse2 = reuse 143 | for i in range(self.num_blocks + 1): 144 | tmp_self_sim = tf.expand_dims(self_sim[i], axis=-1) 145 | tmp_self_sim = 1 - self.gamma * tf.layers.conv2d(tmp_self_sim, filters=1, kernel_size=[max_sentence_len, 1], padding="valid", kernel_initializer=tf.ones_initializer, use_bias=False, trainable=False, reuse=reuse2) # for auto2 146 | tmp_self_sim = tf.squeeze(tmp_self_sim, axis=1) 147 | tmp_self_sim = tf.squeeze(tmp_self_sim, axis=-1) 148 | Hn_stack[i] = tf.einsum('bik,bi->bik', Hn_stack[i], tmp_self_sim) 149 | reuse2 = True 150 | 151 | Mur.append(sim_ur) 152 | Mun.append(sim_un) 153 | 154 | if not reuse: 155 | reuse = True 156 | 157 | r_a_n_stack = [] 158 | n_a_r_stack = [] 159 | reuse2 = False 160 | for i in range(self.num_blocks + 1): 161 | with tf.variable_scope("narrative_attention_response_{}".format(i), reuse=reuse2): 162 | n_a_r, _ = multihead_attention(queries=Hn_stack[i], keys=Hr_stack[i], num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False) 163 | n_a_r = feedforward(n_a_r, num_units=[self.hidden_units, self.hidden_units]) 164 | n_a_r_stack.append(n_a_r) 165 | with tf.variable_scope("response_attention_narrative_{}".format(i), reuse=reuse2): 166 | r_a_n, _ = multihead_attention(queries=Hr_stack[i], keys=Hn_stack[i], num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False) 167 | r_a_n = feedforward(r_a_n, num_units=[self.hidden_units, self.hidden_units]) 168 | r_a_n_stack.append(r_a_n) 169 | 170 | n_a_r_stack.extend(Hn_stack) 171 | r_a_n_stack.extend(Hr_stack) 172 | n_a_r = tf.stack(n_a_r_stack, axis=-1) 173 | r_a_n = tf.stack(r_a_n_stack, axis=-1) 174 | 175 | with tf.variable_scope('similarity'): 176 | Mrn = tf.einsum('biks,bjks->bijs', n_a_r, r_a_n) / tf.sqrt(200.0) 177 | 178 | Mur = tf.stack(Mur, axis=1) 179 | Mun = tf.stack(Mun, axis=1) 180 | with tf.variable_scope('cnn_aggregation'): 181 | conv3d = tf.layers.conv3d(Mur, filters=32, 
kernel_size=[3, 3, 3], padding="SAME", activation=tf.nn.elu, kernel_initializer=tf.random_uniform_initializer(-0.01, 0.01), name="conv1") 182 | pool3d = tf.layers.max_pooling3d(conv3d, pool_size=[3, 3, 3], strides=[3, 3, 3], padding="SAME") 183 | conv3d2 = tf.layers.conv3d(pool3d, filters=16, kernel_size=[3, 3, 3], padding="SAME", activation=tf.nn.elu, kernel_initializer=tf.random_uniform_initializer(-0.01, 0.01), name="conv2") 184 | pool3d2 = tf.layers.max_pooling3d(conv3d2, pool_size=[3, 3, 3], strides=[3, 3, 3], padding="SAME") 185 | mur = tf.contrib.layers.flatten(pool3d2) 186 | with tf.variable_scope('cnn_aggregation', reuse=True): 187 | conv3d = tf.layers.conv3d(Mun, filters=32, kernel_size=[3, 3, 3], padding="SAME", activation=tf.nn.elu, 188 | kernel_initializer=tf.random_uniform_initializer(-0.01, 0.01), name="conv1") 189 | pool3d = tf.layers.max_pooling3d(conv3d, pool_size=[3, 3, 3], strides=[3, 3, 3], padding="SAME") 190 | conv3d2 = tf.layers.conv3d(pool3d, filters=16, kernel_size=[3, 3, 3], padding="SAME", activation=tf.nn.elu, 191 | kernel_initializer=tf.random_uniform_initializer(-0.01, 0.01), name="conv2") 192 | pool3d2 = tf.layers.max_pooling3d(conv3d2, pool_size=[3, 3, 3], strides=[3, 3, 3], padding="SAME") 193 | mun = tf.contrib.layers.flatten(pool3d2) 194 | 195 | with tf.variable_scope('cnn_aggregation'): 196 | conv2d = tf.layers.conv2d(Mrn, filters=32, kernel_size=[3, 3], padding="SAME", activation=tf.nn.elu, kernel_initializer=tf.random_uniform_initializer(-0.01, 0.01), name="conv2d") 197 | pool2d = tf.layers.max_pooling2d(conv2d, pool_size=[3, 3], strides=[3, 3], padding="SAME") 198 | conv2d2 = tf.layers.conv2d(pool2d, filters=16, kernel_size=[3, 3], padding="SAME", activation=tf.nn.elu, kernel_initializer=tf.random_uniform_initializer(-0.01, 0.01), name="conv2d2") 199 | pool2d2 = tf.layers.max_pooling2d(conv2d2, pool_size=[3, 3], strides=[3, 3], padding="SAME") 200 | mrn = tf.contrib.layers.flatten(pool2d2) 201 | 202 | all_vector = tf.concat([mur, mun, mrn], axis=-1) 203 | logits = tf.reshape(tf.layers.dense(all_vector, 1, kernel_initializer=tf.orthogonal_initializer()), [-1]) 204 | 205 | self.y_pred = tf.sigmoid(logits) 206 | self.learning_rate = tf.train.exponential_decay(self.initial_learning_rate, global_step=self.global_step, decay_steps=1000, decay_rate=0.9, staircase=True) 207 | optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate, beta1=0.9, beta2=0.98, epsilon=1e-8) 208 | self.loss = tf.reduce_mean(tf.clip_by_value(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.cast(self.y_true, tf.float32), logits=logits), -10, 10)) 209 | self.all_variables = tf.global_variables() 210 | self.grads_and_vars = optimizer.compute_gradients(self.loss) 211 | 212 | for grad, var in self.grads_and_vars: 213 | if grad is None: 214 | print(var) 215 | 216 | self.capped_gvs = [(tf.clip_by_value(grad, -5, 5), var) for grad, var in self.grads_and_vars] 217 | self.train_op = optimizer.apply_gradients(self.capped_gvs, global_step=self.global_step) 218 | self.saver = tf.train.Saver(max_to_keep=10) 219 | self.alpha_1 = alpha_1 220 | 221 | 222 | def evaluate(model_path, eval_file, output_path): 223 | with open(eval_file, 'rb') as f: 224 | utterance, response, narrative, labels = pickle.load(f) 225 | 226 | all_candidate_scores = [] 227 | utterance, utterance_len = Utils.multi_sequences_padding(utterance, max_sentence_len, max_num_utterance=max_num_utterance) 228 | utterance, utterance_len = np.array(utterance), np.array(utterance_len) 229 | narrative_len = 
np.array(Utils.get_sequences_length(narrative, maxlen=max_sentence_len)) 230 | narrative = np.array(pad_sequences(narrative, padding='post', maxlen=max_sentence_len)) 231 | response_len = np.array(Utils.get_sequences_length(response, maxlen=max_sentence_len)) 232 | response = np.array(pad_sequences(response, padding='post', maxlen=max_sentence_len)) 233 | y_true = np.array(labels) 234 | 235 | dataset = tf.data.Dataset.from_tensor_slices((utterance, utterance_len, narrative, narrative_len, response, response_len, y_true)) 236 | dataset = dataset.batch(eval_batch_size) 237 | iterator = dataset.make_initializable_iterator() 238 | 239 | data_iterator = iterator.get_next() 240 | 241 | with open(evaluate_embedding_file, 'rb') as f: 242 | embeddings = pickle.load(f) 243 | 244 | model = ScripteWriter(data_iterator) 245 | model.build() 246 | sess = model.load(model_path) 247 | sess.run(iterator.initializer) 248 | sess.run(model.embedding_init, feed_dict={model.embedding_ph: embeddings}) 249 | 250 | test_loss = 0.0 251 | step = 0 252 | try: 253 | with tqdm(total=len(y_true)) as pbar: 254 | while True: 255 | candidate_scores, loss = sess.run([model.y_pred, model.loss]) 256 | all_candidate_scores.append(candidate_scores) 257 | test_loss += loss 258 | pbar.update(model.eval_batch_size) 259 | step += 1 260 | except tf.errors.OutOfRangeError: 261 | pass 262 | 263 | all_candidate_scores = np.concatenate(all_candidate_scores, axis=0) 264 | with open(output_path + "test.result.micro_session.txt", "w") as fw: 265 | for sc in all_candidate_scores.tolist(): 266 | fw.write(str(sc) + "\n") 267 | return Evaluate.evaluate_all(all_candidate_scores, labels), test_loss / step, all_candidate_scores.tolist() 268 | 269 | 270 | def simple_evaluate(sess, model, iterator, utterance_ph, utterance_len_ph, narrative_ph, narrative_len_ph, response_ph, response_len_ph, y_true_ph, eval_file): 271 | with open(eval_file, 'rb') as f: 272 | utterance, response, narrative, labels = pickle.load(f) 273 | 274 | all_candidate_scores = [] 275 | utterance, utterance_len = Utils.multi_sequences_padding(utterance, max_sentence_len, max_num_utterance=max_num_utterance) 276 | utterance, utterance_len = np.array(utterance), np.array(utterance_len) 277 | narrative_len = np.array(Utils.get_sequences_length(narrative, maxlen=max_sentence_len)) 278 | narrative = np.array(pad_sequences(narrative, padding='post', maxlen=max_sentence_len)) 279 | response_len = np.array(Utils.get_sequences_length(response, maxlen=max_sentence_len)) 280 | response = np.array(pad_sequences(response, padding='post', maxlen=max_sentence_len)) 281 | y_true = np.array(labels) 282 | 283 | sess.run(iterator.initializer, feed_dict={utterance_ph: utterance, 284 | utterance_len_ph: utterance_len, 285 | narrative_ph: narrative, 286 | narrative_len_ph: narrative_len, 287 | response_ph: response, 288 | response_len_ph: response_len, 289 | y_true_ph: y_true}) 290 | 291 | test_loss = 0.0 292 | step = 0 293 | try: 294 | with tqdm(total=len(y_true)) as pbar: 295 | while True: 296 | candidate_scores, loss = sess.run([model.y_pred, model.loss]) 297 | all_candidate_scores.append(candidate_scores) 298 | test_loss += loss 299 | pbar.update(eval_batch_size) 300 | step += 1 301 | except tf.errors.OutOfRangeError: 302 | pass 303 | 304 | all_candidate_scores = np.concatenate(all_candidate_scores, axis=0) 305 | return Evaluate.evaluate_all(all_candidate_scores, labels), test_loss / step, all_candidate_scores.tolist() 306 | 307 | 308 | def evaluate_multi_turns(test_file, model_path, output_path): 
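    # Iterative multi-turn evaluation: initialize with the first turn of every session,
    # score the 10 candidates for each narrative with the trained model, append the
    # top-scored candidate to the dialogue context, and repeat for the next turn.
    # Per-turn metrics and scores are written to output_path + "test.result.multi.<turn>.txt".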
309 | vocab = {} 310 | vocab_id2word = {} 311 | 312 | utterance_ph = tf.placeholder(tf.int32, shape=(None, max_num_utterance, max_sentence_len)) 313 | response_ph = tf.placeholder(tf.int32, shape=(None, max_sentence_len)) 314 | y_true_ph = tf.placeholder(tf.int32, shape=(None,)) 315 | response_len_ph = tf.placeholder(tf.int32, shape=(None,)) 316 | utterance_len_ph = tf.placeholder(tf.int32, shape=(None, max_num_utterance)) 317 | narrative_ph = tf.placeholder(tf.int32, shape=(None, max_sentence_len)) 318 | narrative_len_ph = tf.placeholder(tf.int32, shape=(None,)) 319 | 320 | with open(evaluate_embedding_file, 'rb') as f: 321 | embeddings = pickle.load(f) 322 | 323 | dataset = tf.data.Dataset.from_tensor_slices((utterance_ph, utterance_len_ph, narrative_ph, narrative_len_ph, response_ph, response_len_ph, y_true_ph)) 324 | dataset = dataset.batch(eval_batch_size) 325 | iterator = dataset.make_initializable_iterator() 326 | 327 | data_iterator = iterator.get_next() 328 | 329 | model = ScripteWriter(data_iterator) 330 | model.build() 331 | sess = model.load(model_path) 332 | sess.run(model.embedding_init, feed_dict={model.embedding_ph: embeddings}) 333 | 334 | with open("./data/vocab.txt", "r", encoding="utf-8") as fr: 335 | for idx, line in enumerate(fr): 336 | line = line.strip().split("\t") 337 | vocab[line[0]] = idx + 1 338 | vocab_id2word[idx + 1] = line[0] 339 | vocab["_PAD_"] = 0 340 | vocab_id2word[0] = "_PAD_" 341 | 342 | def initialize(test_file): 343 | initial_file = output_path + "test.multi.0.pkl" 344 | max_turn = 0 345 | narrative_dict = {} 346 | narrative_dict_score = {} 347 | 348 | with open(test_file, 'rb') as f: 349 | utterance, response, narrative, labels = pickle.load(f) 350 | new_utterance, new_response, new_narrative, new_labels = [], [], [], [] 351 | for i in range(len(response)): 352 | ut = utterance[i] 353 | if len(ut) == 1: 354 | o = narrative[i] 355 | r = response[i] 356 | l = labels[i] 357 | new_utterance.append(ut) 358 | new_response.append(r) 359 | new_narrative.append(o) 360 | new_labels.append(l) 361 | if len(ut) > max_turn: 362 | max_turn = len(ut) 363 | o = "".join([vocab_id2word[x] for x in narrative[i]]) 364 | if o not in narrative_dict: 365 | narrative_dict[o] = {0: narrative[i]} 366 | narrative_dict_score[o] = {0: [-1]} 367 | r = response[i] 368 | l = labels[i] 369 | if len(ut) in narrative_dict[o]: 370 | narrative_dict[o][len(ut)].append(r) 371 | narrative_dict_score[o][len(ut)].append(l) 372 | else: 373 | narrative_dict[o][len(ut)] = [r] 374 | narrative_dict_score[o][len(ut)] = [l] 375 | 376 | pickle.dump(narrative_dict, open(output_path + "response_candidate.pkl", "wb")) 377 | 378 | new_data = [new_utterance, new_response, new_narrative, new_labels] 379 | pickle.dump(new_data, open(initial_file, "wb")) 380 | 381 | (acc, r2_1, r10_1, r10_2, r10_5, mrr), eva_loss, result = simple_evaluate(sess, model, iterator, utterance_ph, utterance_len_ph, narrative_ph, narrative_len_ph, response_ph, response_len_ph, y_true_ph, initial_file) 382 | with open(output_path + "test.result.multi.0.txt", "w") as fw: 383 | fw.write("R2@1: %f, R10@1: %f, R10@2: %f, R10@5: %f, MRR: %f\n" % (r2_1, r10_1, r10_2, r10_5, mrr)) 384 | for r in result: 385 | fw.write(str(r) + "\n") 386 | 387 | return max_turn, narrative_dict, narrative_dict_score 388 | 389 | max_turn, narrative_dict, narrative_dict_score = initialize(test_file) 390 | 391 | for turn in range(1, max_turn): 392 | # for turn in range(1, 2): 393 | score = [] 394 | with open(output_path + "test.result.multi." 
+ str(turn - 1) + ".txt", "r") as fr: 395 | for idx, line in enumerate(fr): 396 | if idx == 0: 397 | continue 398 | score.append(float(line.strip())) 399 | with open(output_path + "test.multi." + str(turn - 1) + ".pkl", "rb") as fr: 400 | utterance, response, narrative, labels = pickle.load(fr) 401 | 402 | new_utterance = [] 403 | new_response = [] 404 | new_narrative = [] 405 | new_labels = [] 406 | 407 | for i, o in enumerate(narrative): 408 | if i % 10 == 0: 409 | sent_o = "".join([vocab_id2word[x] for x in o]) 410 | if turn + 1 in narrative_dict[sent_o]: 411 | new_response.extend(narrative_dict[sent_o][turn + 1]) 412 | score_sub_list = score[i:i + 10] 413 | response_sub_list = response[i:i + 10] 414 | max_idx = score_sub_list.index(max(score_sub_list)) 415 | selected_response = response_sub_list[max_idx] 416 | for ut in utterance[i:i + 10]: 417 | tmp = ut + [selected_response] 418 | new_utterance.append(tmp) 419 | new_narrative.extend([o] * 10) 420 | new_labels.extend(narrative_dict_score[sent_o][turn + 1]) 421 | 422 | new_data = [new_utterance, new_response, new_narrative, new_labels] 423 | new_file = output_path + "test.multi." + str(turn) + ".pkl" 424 | pickle.dump(new_data, open(new_file, "wb")) 425 | 426 | (acc, r2_1, r10_1, r10_2, r10_5, mrr), eva_loss, result = simple_evaluate(sess, model, iterator, utterance_ph, utterance_len_ph, narrative_ph, narrative_len_ph, response_ph, response_len_ph, y_true_ph, new_file) 427 | with open(output_path + "test.result.multi." + str(turn) + ".txt", "w") as fw: 428 | fw.write("R2@1: %f, R10@1: %f, R10@2: %f, R10@5: %f, MRR: %f\n" % (r2_1, r10_1, r10_2, r10_5, mrr)) 429 | for r in result: 430 | fw.write(str(r) + "\n") 431 | 432 | 433 | def train(load=False, model_path=None): 434 | best_val_loss = 100000.0 435 | config = tf.ConfigProto(allow_soft_placement=True) 436 | config.gpu_options.allow_growth = True 437 | epoch = 0 438 | with tf.Session(config=config) as sess: 439 | with open(embedding_file, 'rb') as f: 440 | embeddings = pickle.load(f, encoding="bytes") 441 | with open(train_file, 'rb') as f: 442 | utterance, response, narrative, labels = pickle.load(f) 443 | with open(val_file, "rb") as f: 444 | utterance_val, response_val, narrative_val, labels_val = pickle.load(f) 445 | 446 | state = np.random.get_state() 447 | np.random.shuffle(utterance) 448 | np.random.set_state(state) 449 | np.random.shuffle(response) 450 | np.random.set_state(state) 451 | np.random.shuffle(labels) 452 | np.random.set_state(state) 453 | np.random.shuffle(narrative) 454 | 455 | utterance_ph = tf.placeholder(tf.int32, shape=(None, max_num_utterance, max_sentence_len)) 456 | response_ph = tf.placeholder(tf.int32, shape=(None, max_sentence_len)) 457 | y_true_ph = tf.placeholder(tf.int32, shape=(None,)) 458 | response_len_ph = tf.placeholder(tf.int32, shape=(None,)) 459 | utterance_len_ph = tf.placeholder(tf.int32, shape=(None, max_num_utterance)) 460 | narrative_ph = tf.placeholder(tf.int32, shape=(None, max_sentence_len)) 461 | narrative_len_ph = tf.placeholder(tf.int32, shape=(None,)) 462 | 463 | utterance_train, utterance_len_train = Utils.multi_sequences_padding(utterance, max_sentence_len, max_num_utterance=max_num_utterance) 464 | utterance_train, utterance_len_train = np.array(utterance_train), np.array(utterance_len_train) 465 | response_len_train = np.array(Utils.get_sequences_length(response, maxlen=max_sentence_len)) 466 | response_train = np.array(pad_sequences(response, padding='post', maxlen=max_sentence_len)) 467 | narrative_len_train = 
np.array(Utils.get_sequences_length(narrative, maxlen=max_sentence_len)) 468 | narrative_train = np.array(pad_sequences(narrative, padding='post', maxlen=max_sentence_len)) 469 | y_true_train = np.array(labels) 470 | 471 | utterance_val, utterance_len_val = Utils.multi_sequences_padding(utterance_val, max_sentence_len, max_num_utterance=max_num_utterance) 472 | utterance_val, utterance_len_val = np.array(utterance_val), np.array(utterance_len_val) 473 | response_len_val = np.array(Utils.get_sequences_length(response_val, maxlen=max_sentence_len)) 474 | response_val = np.array(pad_sequences(response_val, padding='post', maxlen=max_sentence_len)) 475 | narrative_len_val = np.array(Utils.get_sequences_length(narrative_val, maxlen=max_sentence_len)) 476 | narrative_val = np.array(pad_sequences(narrative_val, padding='post', maxlen=max_sentence_len)) 477 | y_true_val = np.array(labels_val) 478 | 479 | dataset = tf.data.Dataset.from_tensor_slices((utterance_ph, utterance_len_ph, narrative_ph, narrative_len_ph, response_ph, response_len_ph, 480 | y_true_ph)).shuffle(1000) 481 | dataset = dataset.batch(batch_size) 482 | iterator = dataset.make_initializable_iterator() 483 | 484 | data_iterator = iterator.get_next() 485 | 486 | model = ScripteWriter(data_iterator) 487 | model.build() 488 | 489 | if load: 490 | sess = model.load(model_path) 491 | 492 | sess.run(tf.global_variables_initializer()) 493 | sess.run(model.embedding_init, feed_dict={model.embedding_ph: embeddings}) 494 | 495 | while epoch < 8: 496 | train_loss = 0.0 497 | sess.run(iterator.initializer, feed_dict={utterance_ph: utterance_train, 498 | utterance_len_ph: utterance_len_train, 499 | narrative_ph: narrative_train, 500 | narrative_len_ph: narrative_len_train, 501 | response_ph: response_train, 502 | response_len_ph: response_len_train, 503 | y_true_ph: y_true_train}) 504 | step = 0 505 | try: 506 | with tqdm(total=len(y_true_train)) as pbar: 507 | while True: 508 | _, loss, lr = sess.run([model.train_op, model.loss, model.learning_rate]) 509 | train_loss += loss 510 | pbar.set_postfix(learning_rate=lr, loss=loss) 511 | pbar.update(model.batch_size) 512 | step += 1 513 | except tf.errors.OutOfRangeError: 514 | pass 515 | 516 | val_loss = 0.0 517 | val_step = 0 518 | sess.run(iterator.initializer, feed_dict={utterance_ph: utterance_val, 519 | utterance_len_ph: utterance_len_val, 520 | narrative_ph: narrative_val, 521 | narrative_len_ph: narrative_len_val, 522 | response_ph: response_val, 523 | response_len_ph: response_len_val, 524 | y_true_ph: y_true_val}) 525 | try: 526 | while True: 527 | loss = sess.run(model.loss) 528 | val_loss += loss 529 | val_step += 1 530 | except tf.errors.OutOfRangeError: 531 | pass 532 | 533 | print('Epoch No: %d, the train loss is %f, the dev loss is %f' % (epoch + 1, train_loss / step, val_loss / val_step)) 534 | if val_loss / val_step < best_val_loss: 535 | best_val_loss = val_loss / val_step 536 | model.saver.save(sess, "./model/model.{0}".format(epoch + 1)) 537 | print("Save model.{}".format(epoch + 1)) 538 | epoch += 1 539 | 540 | if __name__ == "__main__": 541 | os.environ['CUDA_VISIBLE_DEVICES'] = "0" 542 | is_train = True 543 | previous_train_modelpath = "./model/" 544 | if is_train: 545 | train(False, previous_train_modelpath) 546 | else: 547 | # check the validation loss obtained in the training process and use the saved model with the smallest validation loss 548 | (acc, r2_1, r10_1, r10_2, r10_5, mrr), eva_loss, _ = evaluate(previous_train_modelpath, evaluate_file, output_path="./output/") 
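        # NOTE (assumption): train() saves checkpoints as ./model/model.<epoch>; depending on how
        # model.load() resolves the path, previous_train_modelpath may need to point at a specific
        # checkpoint prefix (e.g. "./model/model.3", the epoch with the lowest dev loss) rather
        # than the bare directory.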
549 | print("Loss on test set: %f, Accuracy: %f, R2@1: %f, R10@1: %f, R10@2: %f, R10@5: %f, MRR: %f" % (eva_loss, acc, r2_1, r10_1, r10_2, r10_5, mrr)) 550 | 551 | # to evaluate multi-turn results, the vocab file is needed 552 | # evaluate_multi_turns(test_file=evaluate_file, model_path=previous_train_modelpath, output_path="./output/") 553 | # Evaluate.recover_and_show(basic_directory="./output/") 554 | # test_file = basic_directory + "test.result.multi.txt" 555 | # gold_file = "./data/ground_truth.result.mul.txt" 556 | # Evaluate.evaluate_multi_turn_result(test_file, gold_file) 557 | -------------------------------------------------------------------------------- /Utils.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | import pickle 3 | import numpy as np 4 | from keras.preprocessing.sequence import pad_sequences 5 | from keras.preprocessing.text import text_to_word_sequence 6 | 7 | 8 | def multi_sequences_padding(all_sequences, max_sentence_len=50, max_num_utterance=10): 9 | PAD_SEQUENCE = [0] * max_sentence_len 10 | padded_sequences = [] 11 | sequences_length = [] 12 | for sequences in all_sequences: 13 | sequences_len = len(sequences) 14 | sequences_length.append(get_sequences_length(sequences, maxlen=max_sentence_len)) 15 | if sequences_len < max_num_utterance: 16 | sequences += [PAD_SEQUENCE] * (max_num_utterance - sequences_len) 17 | sequences_length[-1] += [0] * (max_num_utterance - sequences_len) 18 | else: 19 | sequences = sequences[-max_num_utterance:] 20 | sequences_length[-1] = sequences_length[-1][-max_num_utterance:] 21 | sequences = pad_sequences(sequences, padding='post', maxlen=max_sentence_len) 22 | padded_sequences.append(sequences) 23 | return padded_sequences, sequences_length 24 | 25 | 26 | def get_sequences_length(sequences, maxlen): 27 | sequences_length = [min(len(sequence), maxlen) for sequence in sequences] 28 | return sequences_length 29 | 30 | 31 | def generate_data_with_random_samples(): 32 | # generate negative samples randomly 33 | # In training set, for each sample, we randomly sample a response as a negative candidate 34 | # In development and test set, for each sample, we randomly sample 9 responses as negative candidates and we add a "EOS" response as a candidate to let model select when to stop 35 | import random 36 | import pickle 37 | vocab = {} 38 | positive_data = [] 39 | EOS_ID = 7 40 | with open("./data/sample_vocab.txt", "r", encoding="utf-8") as fr: 41 | for idx, line in enumerate(fr): 42 | line = line.strip().split("\t") 43 | vocab[line[0]] = idx + 1 44 | with open("./data/sample_data.txt", "r", encoding="utf-8") as fr: 45 | tmp = [] 46 | for line in fr: 47 | line = line.strip() 48 | if len(line) > 0: 49 | line = line.split("\t") 50 | if line[0] == "narrative": 51 | tmp.append(line[1]) 52 | elif line[0] == "script": 53 | tmp.append(line[1]) 54 | else: 55 | narrative = tmp[0] 56 | context = tmp[1:] 57 | narrative_id = [vocab.get(word, 0) for word in narrative.split()] 58 | context_id = [[vocab.get(word, 0) for word in sent.split()] for sent in context] 59 | if len(narrative_id) == 0 or len(context_id) == 0: 60 | continue 61 | data = [context_id, narrative_id, 1] 62 | positive_data.append(data) 63 | tmp = [] 64 | random.shuffle(positive_data) 65 | print("all suitable sessions: ", len(positive_data)) 66 | train_num = int(len(positive_data) * 0.9) 67 | dev_test_num = int(len(positive_data) * 0.05) 68 | train, dev, test = positive_data[:train_num], positive_data[train_num: 
train_num + dev_test_num], positive_data[train_num + dev_test_num:] 69 | train_all, dev_all, test_all = [], [], [] 70 | for context_id, narrative_id, _ in train: 71 | num_context = len(context_id) 72 | for i in range(1, num_context): 73 | context = context_id[:i] 74 | response = context_id[i] 75 | train_all.append([context, response, narrative_id, 1]) 76 | flag = True 77 | while flag: 78 | random_idx = random.randint(0, len(positive_data) - 1) 79 | random_context = positive_data[random_idx][0] 80 | random_idx_2 = random.randint(0, len(random_context) - 1) 81 | random_response = random_context[random_idx_2] 82 | if len(response) != len(random_response): 83 | flag = False 84 | train_all.append([context, random_response, narrative_id, 0]) 85 | else: 86 | for idx, wid in enumerate(response): 87 | if wid != random_response[idx]: 88 | flag = False 89 | train_all.append([context, random_response, narrative_id, 0]) 90 | break 91 | print(train_all[0], train_all[1]) 92 | for context_id, narrative_id, _ in dev: 93 | num_context = len(context_id) 94 | for i in range(1, num_context): 95 | context = context_id[:i] 96 | response = context_id[i] 97 | dev_all.append([context, response, narrative_id, 1]) 98 | count = 0 99 | negative_samples = [] 100 | while count < 9: 101 | random_idx = random.randint(0, len(positive_data) - 1) 102 | random_context = positive_data[random_idx][0] 103 | random_idx_2 = random.randint(0, len(random_context) - 1) 104 | random_response = random_context[random_idx_2] 105 | if random_response not in negative_samples and random_response != [EOS_ID]: 106 | if len(response) != len(random_response): 107 | dev_all.append([context, random_response, narrative_id, 0]) 108 | count += 1 109 | negative_samples.append(random_response) 110 | else: 111 | for idx, wid in enumerate(response): 112 | if wid != random_response[idx]: 113 | dev_all.append([context, random_response, narrative_id, 0]) 114 | count += 1 115 | negative_samples.append(random_response) 116 | break 117 | if response == [EOS_ID]: 118 | dev_all.append([context, [EOS_ID], narrative_id, 1]) 119 | else: 120 | dev_all.append([context, [EOS_ID], narrative_id, 0]) 121 | print(dev_all[0], dev_all[1], dev_all[2]) 122 | for context_id, narrative_id, _ in test: 123 | num_context = len(context_id) 124 | for i in range(1, num_context): 125 | context = context_id[:i] 126 | response = context_id[i] 127 | test_all.append([context, response, narrative_id, 1]) 128 | count = 0 129 | negative_samples = [] 130 | while count < 9: 131 | random_idx = random.randint(0, len(positive_data) - 1) 132 | random_context = positive_data[random_idx][0] 133 | random_idx_2 = random.randint(0, len(random_context) - 1) 134 | random_response = random_context[random_idx_2] 135 | if random_response not in negative_samples and random_response != [EOS_ID]: 136 | if len(response) != len(random_response): 137 | test_all.append([context, random_response, narrative_id, 0]) 138 | negative_samples.append(random_response) 139 | count += 1 140 | else: 141 | for idx, id in enumerate(response): 142 | if id != random_response[idx]: 143 | test_all.append([context, random_response, narrative_id, 0]) 144 | negative_samples.append(random_response) 145 | count += 1 146 | break 147 | if response == [EOS_ID]: 148 | test_all.append([context, [EOS_ID], narrative_id, 1]) 149 | else: 150 | test_all.append([context, [EOS_ID], narrative_id, 0]) 151 | print(test_all[0], test_all[1], test_all[2]) 152 | context, response, narrative, label = [], [], [], [] 153 | print("train size: ", 
len(train_all)) 154 | for data in train_all: 155 | context.append(data[0]) 156 | response.append(data[1]) 157 | narrative.append(data[2]) 158 | label.append(data[3]) 159 | train = [context, response, narrative, label] 160 | pickle.dump(train, open("./data/train.multi.pkl", "wb")) 161 | context, response, narrative, label = [], [], [], [] 162 | print("dev size: ", len(dev_all)) 163 | for data in dev_all: 164 | context.append(data[0]) 165 | response.append(data[1]) 166 | narrative.append(data[2]) 167 | label.append(data[3]) 168 | dev = [context, response, narrative, label] 169 | pickle.dump(dev, open("./data/dev.multi.pkl", "wb")) 170 | context, response, narrative, label = [], [], [], [] 171 | print("test size: ", len(test_all)) 172 | for data in test_all: 173 | context.append(data[0]) 174 | response.append(data[1]) 175 | narrative.append(data[2]) 176 | label.append(data[3]) 177 | test = [context, response, narrative, label] 178 | pickle.dump(test, open("./data/test.multi.pkl", "wb")) 179 | 180 | 181 | def generate_data_with_solr_samples(): 182 | # generate negative samples from solr 183 | # this is only for development and test set since training set has only one negative sample 184 | import pickle 185 | import pysolr 186 | import jieba 187 | 188 | EOS_ID = 7 189 | 190 | def query_comt(post, num): 191 | # the format of the solr data: "ut1: xxxxx, ut2: xxxxx", where ut1 is the index 192 | solr = pysolr.Solr('xxxxxx', timeout=10) # write your Solr address 193 | post = "ut1:(" + post + ")" 194 | results = solr.search(q=post, **{'rows': num}) # rows equal to the number of pairs you want to retrieve 195 | return results 196 | 197 | vocab = {} 198 | vocab_id2word = {} 199 | 200 | with open("./data/vocab.txt", "r", encoding="utf-8") as fr: 201 | for idx, line in enumerate(fr): 202 | line = line.strip().split("\t") 203 | vocab[line[0]] = idx + 1 204 | vocab_id2word[idx + 1] = line[0] 205 | 206 | dev = pickle.load(open("./data/dev.multi.pkl", "rb")) 207 | context, response, narrative, label = dev[0], dev[1], dev[2], dev[3] 208 | num = len(response) 209 | dev_all = [] 210 | for i in range(num): 211 | # One positive sample, nine negative samples and a "EOS" sample. 
1 + 9 + 1 = 11 212 | if i % 11 == 0 and int(label[i]) == 1: 213 | count = 0 214 | context_ = context[i] 215 | pos_response = "".join([vocab_id2word[x] for x in response[i]]) 216 | last_ut = "".join([vocab_id2word[x] for x in context_[-1]]).replace(".", "").replace("?", "").replace("\"", "").replace(":", "") 217 | dev_all.append([context[i], response[i], narrative[i], 1]) 218 | negative_samples = query_comt(last_ut, 15) 219 | for result in negative_samples: 220 | if result['ut2'] != pos_response: 221 | negtive_sample = [vocab[x] for x in jieba.lcut(result['ut2']) if x != ' ' and x != '\xa0' and x != '\u3000'] 222 | dev_all.append([context[i], negtive_sample, narrative[i], 0]) 223 | count += 1 224 | if count == 8: 225 | break 226 | if count != 8: 227 | last = 8 - count 228 | for j in range(last): 229 | negtive_sample = response[i + 1 + j] 230 | dev_all.append([context[i], negtive_sample, narrative[i], 0]) 231 | if response[i] == [EOS_ID]: 232 | dev_all.append([context[i], [EOS_ID], narrative[i], 1]) 233 | else: 234 | dev_all.append([context[i], [EOS_ID], narrative[i], 0]) 235 | 236 | test = pickle.load(open("./data/test.multi.pkl", "rb")) 237 | context, response, narrative, label = test[0], test[1], test[2], test[3] 238 | num = len(response) 239 | test_all = [] 240 | print("start test") 241 | for i in range(num): 242 | if i % 11 == 0 and int(label[i]) == 1: 243 | count = 0 244 | context_ = context[i] 245 | pos_response = "".join([vocab_id2word[x] for x in response[i]]) 246 | last_ut = "".join([vocab_id2word[x] for x in context_[-1]]).replace(".", "").replace("?", "").replace("\"", "").replace(":", "") 247 | test_all.append([context[i], response[i], narrative[i], 1]) 248 | negative_samples = query_comt(last_ut, 15) 249 | for result in negative_samples: 250 | if result['ut2'] != pos_response: 251 | negtive_sample = [vocab[x] for x in jieba.lcut(result['ut2']) if x != ' ' and x != '\xa0' and x != '\u3000'] 252 | test_all.append([context[i], negtive_sample, narrative[i], 0]) 253 | count += 1 254 | if count == 8: 255 | break 256 | if count != 8: 257 | last = 8 - count 258 | for j in range(last): 259 | negtive_sample = response[i + 1 + j] 260 | test_all.append([context[i], negtive_sample, narrative[i], 0]) 261 | if response[i] == [EOS_ID]: 262 | test_all.append([context[i], [EOS_ID], narrative[i], 1]) 263 | else: 264 | test_all.append([context[i], [EOS_ID], narrative[i], 0]) 265 | 266 | context, response, narrative, label = [], [], [], [] 267 | print("dev size: ", len(dev_all)) 268 | for data in dev_all: 269 | context.append(data[0]) 270 | response.append(data[1]) 271 | narrative.append(data[2]) 272 | label.append(data[3]) 273 | dev = [context, response, narrative, label] 274 | pickle.dump(dev, open("./data/dev.pkl", "wb")) 275 | context, response, narrative, label = [], [], [], [] 276 | print("test size: ", len(test_all)) 277 | for data in test_all: 278 | context.append(data[0]) 279 | response.append(data[1]) 280 | narrative.append(data[2]) 281 | label.append(data[3]) 282 | test = [context, response, narrative, label] 283 | pickle.dump(test, open("./data/test.pkl", "wb")) 284 | -------------------------------------------------------------------------------- /data/readme.md: -------------------------------------------------------------------------------- 1 | Unzip data.zip to this directory. 
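The sample files in this folder illustrate the expected raw format: each session is one `narrative`-tagged line followed by several `script`-tagged lines (tab-separated), with sessions separated by blank lines (see `sample_data.txt`); `vocab.txt` / `sample_vocab.txt` hold one `word<TAB>count` pair per line. `Utils.py` turns the raw text into pickled parallel lists `[context, response, narrative, label]` of word-id sequences. Below is a minimal sketch for inspecting one of the generated pickles; it assumes `vocab.txt` and `train.multi.pkl` (the file names produced by `Utils.py`) are already present in this directory.

```python
import pickle

# vocab.txt: one "word<TAB>count" pair per line; ids start at 1, id 0 is reserved for _PAD_
id2word = {0: "_PAD_"}
with open("./data/vocab.txt", "r", encoding="utf-8") as f:
    for idx, line in enumerate(f):
        id2word[idx + 1] = line.strip().split("\t")[0]

# each pickle stores four parallel lists: context, response, narrative, label
with open("./data/train.multi.pkl", "rb") as f:
    context, response, narrative, label = pickle.load(f)

print("narrative:", "".join(id2word[w] for w in narrative[0]))
for turn in context[0]:
    print("context  :", "".join(id2word[w] for w in turn))
print("response :", "".join(id2word[w] for w in response[0]), "| label:", label[0])
```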
2 | -------------------------------------------------------------------------------- /data/sample_data.txt: -------------------------------------------------------------------------------- 1 | narrative 珍妮不喜欢回家。为了陪珍妮,甘决定晚点回家。珍妮是甘最好的朋友 2 | script 妈会担心我的 3 | script 再坐一会! 4 | script 珍妮为了某种理由不想回家 5 | script 好,珍妮,我留下来 6 | script 她是我最特别的朋友 7 | script EOS 8 | 9 | narrative 珍妮不喜欢回家。为了陪珍妮,甘决定晚点回家。珍妮是甘最好的朋友 10 | script 妈会担心我的 11 | script 再坐一会! 12 | script 珍妮为了某种理由不想回家 13 | script 好,珍妮,我留下来 14 | script 她是我最特别的朋友 15 | script EOS 16 | 17 | narrative 珍妮不喜欢回家。为了陪珍妮,甘决定晚点回家。珍妮是甘最好的朋友 18 | script 妈会担心我的 19 | script 再坐一会! 20 | script 珍妮为了某种理由不想回家 21 | script 好,珍妮,我留下来 22 | script 她是我最特别的朋友 23 | script EOS 24 | 25 | narrative 珍妮不喜欢回家。为了陪珍妮,甘决定晚点回家。珍妮是甘最好的朋友 26 | script 妈会担心我的 27 | script 再坐一会! 28 | script 珍妮为了某种理由不想回家 29 | script 好,珍妮,我留下来 30 | script 她是我最特别的朋友 31 | script EOS 32 | 33 | narrative 珍妮不喜欢回家。为了陪珍妮,甘决定晚点回家。珍妮是甘最好的朋友 34 | script 妈会担心我的 35 | script 再坐一会! 36 | script 珍妮为了某种理由不想回家 37 | script 好,珍妮,我留下来 38 | script 她是我最特别的朋友 39 | script EOS 40 | -------------------------------------------------------------------------------- /data/sample_data_english.txt: -------------------------------------------------------------------------------- 1 | narrative xxxxxxxxxxxxxxx 2 | script xxxxxxxxxxxx 3 | script xxxxxxxxxxxx 4 | script xxxxxxxxxxxx 5 | script xxxxxxxxxxxx 6 | script xxxxxxxxxxxx 7 | script xxxxxxxxxxxx 8 | 9 | narrative xxxxxxxxxxxxxxx 10 | script xxxxxxxxxxxx 11 | script xxxxxxxxxxxx 12 | script xxxxxxxxxxxx 13 | script xxxxxxxxxxxx 14 | script xxxxxxxxxxxx 15 | script xxxxxxxxxxxx 16 | 17 | narrative xxxxxxxxxxxxxxx 18 | script xxxxxxxxxxxx 19 | script xxxxxxxxxxxx 20 | script xxxxxxxxxxxx 21 | script xxxxxxxxxxxx 22 | script xxxxxxxxxxxx 23 | script xxxxxxxxxxxx 24 | 25 | narrative xxxxxxxxxxxxxxx 26 | script xxxxxxxxxxxx 27 | script xxxxxxxxxxxx 28 | script xxxxxxxxxxxx 29 | script xxxxxxxxxxxx 30 | script xxxxxxxxxxxx 31 | script xxxxxxxxxxxx -------------------------------------------------------------------------------- /data/sample_vocab.txt: -------------------------------------------------------------------------------- 1 | , 50435 2 | 的 37684 3 | 。 27141 4 | 了 22941 5 | 我 21728 6 | 你 20009 7 | EOS 16109 8 | 他 12878 9 | 是 12859 10 | ? 12549 11 | -------------------------------------------------------------------------------- /data/sample_vocab_english.txt: -------------------------------------------------------------------------------- 1 | , 50435 2 | I 37684 3 | . 27141 4 | you 22941 5 | my 21728 6 | yours 20009 7 | EOS 16109 8 | he 12878 9 | is 12859 10 | ? 
12549 -------------------------------------------------------------------------------- /image/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaoD/ScriptWriter/89808c9c8d3c7e4114643209c44c5e37602781d1/image/example.png -------------------------------------------------------------------------------- /image/matching.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaoD/ScriptWriter/89808c9c8d3c7e4114643209c44c5e37602781d1/image/matching.png -------------------------------------------------------------------------------- /image/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaoD/ScriptWriter/89808c9c8d3c7e4114643209c44c5e37602781d1/image/result.png -------------------------------------------------------------------------------- /image/statistics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaoD/ScriptWriter/89808c9c8d3c7e4114643209c44c5e37602781d1/image/statistics.png -------------------------------------------------------------------------------- /image/update.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaoD/ScriptWriter/89808c9c8d3c7e4114643209c44c5e37602781d1/image/update.png -------------------------------------------------------------------------------- /model/readme.md: -------------------------------------------------------------------------------- 1 | The trained model is saved here. 2 | -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | 5 | def normalize(inputs, epsilon=1e-8, scope="normalize", reuse=None): 6 | with tf.variable_scope(scope, reuse=reuse): 7 | inputs_shape = inputs.get_shape() 8 | params_shape = inputs_shape[-1:] 9 | mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True) 10 | beta = tf.get_variable(name='beta', shape=params_shape, dtype=tf.float32, initializer=tf.zeros_initializer()) 11 | gamma = tf.get_variable(name='scale', shape=params_shape, dtype=tf.float32, initializer=tf.ones_initializer()) 12 | normalized = (inputs - mean) / ((variance + epsilon) ** .5) 13 | outputs = gamma * normalized + beta 14 | return outputs 15 | 16 | 17 | def embedding(inputs, vocab_size=None, embedding_size=None, zero_pad=False, scale=False, scope="embedding", reuse=None, initializer=None): 18 | with tf.variable_scope(scope, reuse=reuse): 19 | if initializer: 20 | lookup_table = initializer 21 | else: 22 | lookup_table = tf.get_variable('lookup_table', dtype=tf.float32, shape=[vocab_size, embedding_size], initializer=tf.contrib.layers.xavier_initializer()) 23 | if zero_pad: 24 | lookup_table = tf.concat((tf.zeros(shape=[1, embedding_size]), lookup_table[1:, :]), 0) 25 | outputs = tf.nn.embedding_lookup(lookup_table, inputs) 26 | if scale: 27 | outputs = outputs * (embedding_size ** 0.5) 28 | return outputs 29 | 30 | 31 | def positional_encoding(inputs, num_units, max_len, zero_pad=True, scale=True, scope="positional_encoding", reuse=None): 32 | # N, T = inputs.get_shape().as_list() 33 | # N, T = tf.shape(inputs) 34 | inputs_shape = tf.shape(inputs) 35 | N = inputs_shape[0] 36 | T_real = inputs_shape[1] 37 | T = max_len 38 | 
with tf.variable_scope(scope, reuse=reuse, dtype=tf.float32): 39 | position_ind = tf.tile(tf.expand_dims(tf.range(T_real), 0), [N, 1]) 40 | position_enc = np.array([[pos / np.power(10000, 2 * (i // 2) /num_units) for i in range(num_units)] for pos in range(T)], dtype=np.float32) 41 | position_enc[:, 0::2] = np.sin(position_enc[:, 0::2]) 42 | position_enc[:, 1::2] = np.cos(position_enc[:, 1::2]) 43 | lookup_table = tf.convert_to_tensor(position_enc) 44 | if zero_pad: 45 | lookup_table = tf.concat((tf.zeros(shape=[1, num_units]), lookup_table[1:, :]), 0) 46 | outputs = tf.nn.embedding_lookup(lookup_table, position_ind) 47 | if scale: 48 | outputs = outputs * (num_units ** 0.5) 49 | return outputs 50 | 51 | 52 | def multihead_attention(queries, keys, num_units=None, num_heads=8, dropout_rate=0, is_training=True, causality=False, scope="multihead_attention", reuse=None): 53 | with tf.variable_scope(scope, reuse=reuse): 54 | if num_units is None: 55 | num_units = queries.get_shape().as_list()[-1] 56 | Q = tf.layers.dense(queries, num_units, activation=tf.nn.relu) # (N, T_q, C) 57 | K = tf.layers.dense(keys, num_units, activation=tf.nn.relu) # (N, T_k, C) 58 | V = tf.layers.dense(keys, num_units, activation=tf.nn.relu) # (N, T_k, C) 59 | 60 | Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0) # (h*N, T_q, C/h) 61 | K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0) # (h*N, T_k, C/h) 62 | V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0) # (h*N, T_k, C/h) 63 | 64 | outputs = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1])) # (h*N, T_q, T_k) 65 | outputs = outputs / (K_.get_shape().as_list()[-1] ** 0.5) 66 | key_masks = tf.sign(tf.abs(tf.reduce_sum(keys, axis=-1))) # (N, T_k) 67 | key_masks = tf.tile(key_masks, [num_heads, 1]) # (h*N, T_k) 68 | key_masks = tf.tile(tf.expand_dims(key_masks, 1), [1, tf.shape(queries)[1], 1]) # (h*N, T_q, T_k) 69 | paddings = tf.ones_like(outputs) * (-2 ** 32 + 1) 70 | outputs = tf.where(tf.equal(key_masks, 0), paddings, outputs) # (h*N, T_q, T_k) 71 | if causality: 72 | diag_vals = tf.ones_like(outputs[0, :, :]) # (T_q, T_k) 73 | tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense() # (T_q, T_k) 74 | masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(outputs)[0], 1, 1]) # (h*N, T_q, T_k) 75 | paddings = tf.ones_like(masks) * (-2 ** 32 + 1) 76 | outputs = tf.where(tf.equal(masks, 0), paddings, outputs) # (h*N, T_q, T_k) 77 | outputs = tf.nn.softmax(outputs) # (h*N, T_q, T_k) 78 | attn_weights = outputs 79 | query_masks = tf.sign(tf.abs(tf.reduce_sum(queries, axis=-1))) # (N, T_q) 80 | query_masks = tf.tile(query_masks, [num_heads, 1]) # (h*N, T_q) 81 | query_masks = tf.tile(tf.expand_dims(query_masks, -1), [1, 1, tf.shape(keys)[1]]) # (h*N, T_q, T_k) 82 | outputs *= query_masks # broadcasting. 
(N, T_q, C) 83 | outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=tf.convert_to_tensor(is_training)) 84 | outputs = tf.matmul(outputs, V_) # ( h*N, T_q, C/h) 85 | outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2) # (N, T_q, C) 86 | outputs += queries 87 | outputs = normalize(outputs, scope=scope) # (N, T_q, C) 88 | 89 | return outputs, attn_weights 90 | 91 | 92 | def feedforward(inputs, num_units=[2048, 512], scope="feed_forward", reuse=None): 93 | with tf.variable_scope(scope, reuse=reuse): 94 | params = {"inputs": inputs, "filters": num_units[0], "kernel_size": 1, "activation": tf.nn.relu, "use_bias": True} 95 | outputs = tf.layers.conv1d(**params) 96 | params = {"inputs": outputs, "filters": num_units[1], "kernel_size": 1, "activation": None, "use_bias": True} 97 | outputs = tf.layers.conv1d(**params) 98 | outputs += inputs 99 | outputs = normalize(outputs, scope=scope) 100 | return outputs 101 | 102 | 103 | def label_smoothing(inputs, epsilon=0.1): 104 | K = inputs.get_shape().as_list()[-1] # number of channels 105 | return ((1 - epsilon) * inputs) + (epsilon / K) 106 | 107 | -------------------------------------------------------------------------------- /output/readme.md: -------------------------------------------------------------------------------- 1 | The output is saved here. 2 | --------------------------------------------------------------------------------