├── Evaluate.py
├── Readme.md
├── ScriptWriter-CPre.py
├── ScriptWriter.py
├── Utils.py
├── data
│   ├── readme.md
│   ├── sample_data.txt
│   ├── sample_data_english.txt
│   ├── sample_vocab.txt
│   └── sample_vocab_english.txt
├── image
│   ├── example.png
│   ├── matching.png
│   ├── result.png
│   ├── statistics.png
│   └── update.png
├── model
│   └── readme.md
├── modules.py
└── output
    └── readme.md
/Evaluate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def compute_r_n_m(scores, labels, count, at):
4 | total = 0
5 | correct = 0
6 | for i in range(len(labels)):
7 | if i % 10 == 0:
8 | total = total + 1
9 | sublist = scores[i:i + count]
10 | pos_score = sublist[0]
11 | sublist = sorted(sublist, key=lambda x: x, reverse=True)
12 | if sublist[at - 1] <= pos_score:
13 | correct += 1
14 | return float(correct) / total
15 |
16 | def compute_mrr(scores, labels, count=10):
17 | total = 0
18 | accumulate_mrr = 0
19 | for i in range(len(labels)):
20 | if i % 10 == 0:
21 | total = total + 1
22 | sublist = scores[i:i + count]
23 | arg_sort = list(np.argsort(sublist)).index(0)
24 | idx = len(sublist) - arg_sort
25 | accumulate_mrr += 1 / idx
26 | return float(accumulate_mrr) / total
27 |
28 | def compute_acc(scores, labels):
29 | scores = (np.asarray(scores) > 0.5).astype(np.int32)
30 | accuracy = sum((scores == labels).astype(np.int32)) / len(labels)
31 | return accuracy
32 |
33 | def evaluate_all(scores, labels):
34 | return compute_acc(scores, labels), compute_r_n_m(scores, labels, 2, 1), compute_r_n_m(scores, labels, 10, 1), compute_r_n_m(scores, labels, 10, 2), \
35 | compute_r_n_m(scores, labels, 10, 5), compute_mrr(scores, labels)
36 |
37 | def evaluate_all_from_file(path):
38 | scores = []
39 | labels = []
40 | with open(path, "r") as f:
41 | for line in f:
42 | score, label = line.strip().split("\t")
43 | scores.append(float(score))
44 | labels.append(float(label))
45 | return evaluate_all(scores, labels)
46 |
47 | def recover_and_show(basic_directory):
48 | import pickle
49 |
50 | vocab = {}
51 | vocab_id2word = {}
52 |
53 | with open("./data/vocab.txt", "r", encoding="utf-8") as fr:
54 | for idx, line in enumerate(fr):
55 | line = line.strip().split("\t")
56 | vocab[line[0]] = idx + 1
57 | vocab_id2word[idx + 1] = line[0]
58 |
59 | vocab["_PAD_"] = 0
60 | vocab_id2word[0] = "_PAD_"
61 |
62 | def initialize():
63 | all_outline = []
64 | outline_dict = {}
65 | initial_file = basic_directory + "test.multi.0.pkl"
66 | with open(initial_file, 'rb') as f:
67 | _, _, outline, _ = pickle.load(f)
68 | for o in outline:
69 | if o not in all_outline:
70 | all_outline.append(o)
71 | sent_o = "".join([vocab_id2word[x] for x in o])
72 | outline_dict[sent_o] = []
73 | return all_outline, outline_dict
74 |
75 | max_turn = 10
76 | all_outline, outline_dict = initialize()
77 |
78 | for turn in range(0, max_turn + 1):
79 | score = []
80 | with open(basic_directory + "test.result.multi." + str(turn) + ".txt", "r") as fr:
81 | for idx, line in enumerate(fr):
82 | if idx == 0:
83 | continue
84 | score.append(float(line.strip()))
85 | with open(basic_directory + "test.multi." + str(turn) + ".pkl", "rb") as fr:
86 | utterance, response, outline, labels = pickle.load(fr) # except for dl2r
87 |
88 | for i, o in enumerate(outline):
89 | if i % 10 == 0:
90 | score_sub_list = score[i:i + 10]
91 | response_sub_list = response[i:i + 10]
92 | max_idx = score_sub_list.index(max(score_sub_list))
93 | selected_response = response_sub_list[max_idx]
94 | sent_o = "".join([vocab_id2word[x] for x in o])
95 | outline_dict[sent_o] = utterance[i] + [selected_response] # for MUSwO
96 |
97 | with open(basic_directory + "test.result.multi.txt", "w", encoding="utf-8") as fw:
98 | for o in all_outline:
99 | sent_o = "".join([vocab_id2word[x] for x in o])
100 | utterance = outline_dict[sent_o]
101 | fw.write("outline\t" + sent_o + "\n")
102 | for u in utterance:
103 | sent_u = "".join([vocab_id2word[x] for x in u])
104 | fw.write("script\t" + sent_u + "\n")
105 | fw.write("\n")
106 |
107 | def evaluate_multi_turn_result(t_file, g_file):
108 | test_file = t_file
109 | gold_file = g_file
110 | ft = open(test_file, "r", encoding="utf-8")
111 | fg = open(gold_file, "r", encoding="utf-8")
112 | t, g = ft.readline(), fg.readline()
113 | c = -1
114 | s = -1
115 | all_s = 0
116 | r_all = 0
117 | tmp_r_all = 0
118 | o = None
119 | flag = 1
120 | t_s = []
121 | g_s = []
122 | while t and g:
123 | t = t.strip().split("\t")
124 | g = g.strip().split("\t")
125 | if len(t) > 1:
126 | if t[0] == "script":
127 | if t[1] == g[1]:
128 | s += 1
129 | c += 1
130 | t_s.append(t[1])
131 | g_s.append(g[1])
132 | if t[0] == "outline":
133 | o = t[1]
134 | else:
135 | tmp_c = -1
136 | tmp_s = -1
137 | for at in g_s:
138 | tmp_c += 1
139 | if at in t_s:
140 | tmp_s += 1
141 | tmp_r = tmp_s / tmp_c
142 | tmp_r_all += tmp_r
143 | r = s / c
144 | r_all += r
145 | all_s += 1
146 | c = -1
147 | s = -1
148 | t_s = []
149 | g_s = []
150 | t, g = ft.readline(), fg.readline()
151 |
152 | tmp_c = -1
153 | tmp_s = -1
154 | for at in g_s:
155 | tmp_c += 1
156 | if at in t_s:
157 | tmp_s += 1
158 | tmp_r = tmp_s / tmp_c
159 | tmp_r_all += tmp_r
160 | r = s / c
161 | r_all += r
162 | all_s += 1
163 |
164 | print("p_strict", r_all / all_s)
165 | print("p_weak", tmp_r_all / all_s)
166 |
--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
1 | # ScriptWriter: Narrative-Guided Script Generation
2 |
5 | #### News
6 | - 2022-1-5: Our new work (ScriptWriter-CPre) has been accepted by TOIS!
7 | - 2021-10-19: We have uploaded the code and data for our new model ScriptWriter-CPre.
8 | - 2020-06-11: We found a minor error in the data and have uploaded a corrected version. We also provide code for building the data files from raw text, so you can now generate your own data with Utils.py.
9 | - 2020-06-09: We have uploaded the code and data. Note that we currently do not share vocab.txt due to copyright issues; we will upload it as soon as possible.
10 |
11 | ## Abstract
12 | This repository contains the source code and datasets for the ACL 2020 paper [ScriptWriter: Narrative-Guided Script Generation](https://www.aclweb.org/anthology/2020.acl-main.765.pdf) and TOIS paper [Leveraging Narrative to Generate Movie Script](https://dl.acm.org/doi/pdf/10.1145/3507356) by Zhu et al.
13 |
14 | It is appealing to have a system that generates a story or scripts automatically from a storyline, even though this is still out of our reach. In dialogue systems, it would also be useful to drive dialogues by a dialogue plan. In this paper, we address a key problem involved in these applications - guiding a dialogue by a narrative. The proposed model ScriptWriter selects the best response among the candidates that fit the context as well as the given narrative. It keeps track of what in the narrative has been said and what is to be said. A narrative plays a different role than the context (i.e., previous utterances), which is generally used in current dialogue systems. Due to the unavailability of data for this new application, we construct a new large-scale data collection GraphMovie from a movie website where end-users can upload their narratives freely when watching a movie. Experimental results on the dataset show that our proposed approach based on narratives significantly outperforms the baselines that simply use the narrative as a kind of context.
15 |
16 | Authors: Yutao Zhu, Ruihua Song, Zhicheng Dou, Jian-Yun Nie, Jin Zhou
17 |
18 | ## Requirements
19 | The code is tested with the following packages:
20 | - Python 3.5
21 | - Tensorflow 1.5 (with GPU support)
22 | - Keras 2.2.4
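
A quick sanity check for the environment (a minimal sketch, assuming the packages above are already installed):

```python
import tensorflow as tf
import keras

print(tf.__version__)     # expected: 1.5.x
print(keras.__version__)  # expected: 2.2.4
```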
23 |
24 | ## Usage
25 | - Unzip the compressed [data file](https://drive.google.com/file/d/1fJKI9fzUhPM2dKq2zAFWLbtltv6PT2wh/view?usp=sharing) into the `data` directory.
26 | - Run `python3 ScriptWriter.py` (or `python3 ScriptWriter-CPre.py`).
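
ScriptWriter-CPre.py saves its best checkpoint under `./model/cpre/` and writes one candidate score per line to `./output/cpre/test.result.micro_session.txt`. If you want to re-score a saved run offline, the metrics in `Evaluate.py` can be applied directly; a minimal sketch (assuming the ScriptWriter-CPre pickle layout `utterance, response, narrative, gt_response, y_true`):

```python
import pickle
import Evaluate

# labels come from the test set; the score file has one float per line, in the same order
with open("./data/test.gr.pkl", "rb") as f:
    _, _, _, _, y_true = pickle.load(f)
with open("./output/cpre/test.result.micro_session.txt", "r") as f:
    scores = [float(line.strip()) for line in f]

acc, r2_1, r10_1, r10_2, r10_5, mrr = Evaluate.evaluate_all(scores, y_true)
print("R2@1 %.3f, R10@1 %.3f, R10@2 %.3f, R10@5 %.3f, MRR %.3f" % (r2_1, r10_1, r10_2, r10_5, mrr))
```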
27 |
28 | ## Results
29 | | Model | R2@1 | R10@1 | R10@2 | R10@5 | MRR | P_strict | P_weak |
30 | | ----------------- | ----- | ----- | ----- | ----- | ----- | -------- | ------ |
31 | | MVLSTM | 0.651 | 0.217 | 0.384 | 0.732 | 0.395 | 0.198 | 0.224 |
32 | | DL2R | 0.643 | 0.210 | 0.321 | 0.638 | 0.314 | 0.230 | 0.243 |
33 | | SMN | 0.641 | 0.176 | 0.333 | 0.696 | 0.392 | 0.197 | 0.236 |
34 | | DAM | 0.631 | 0.240 | 0.398 | 0.733 | 0.408 | 0.226 | 0.236 |
35 | | DUA | 0.654 | 0.237 | 0.403 | 0.736 | 0.396 | 0.223 | 0.251 |
36 | | IMN | 0.686 | 0.301 | 0.450 | 0.759 | 0.463 | 0.304 | 0.325 |
37 | | IOI | 0.710 | 0.341 | 0.491 | 0.774 | 0.464 | 0.324 | 0.337 |
38 | | MSN | 0.724 | 0.329 | 0.511 | 0.794 | 0.464 | 0.314 | 0.346 |
39 | | ScriptWriter | 0.730 | 0.365 | 0.537 | 0.814 | 0.503 | 0.373 | 0.383 |
40 | | ScriptWriter-CPre | 0.756 | 0.398 | 0.557 | 0.817 | 0.504 | 0.392 | 0.409 |
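
In the table, R_n@k and MRR follow the convention implemented in `Evaluate.py`: test candidates come in consecutive groups of 10 with the ground-truth response first, and R_n@k checks whether the positive response is ranked within the top k of the first n candidates of its group. P_strict and P_weak are computed over the reconstructed multi-turn scripts by `Evaluate.evaluate_multi_turn_result`. A toy example of the grouping convention (one group of 10 hypothetical scores):

```python
import Evaluate

# one candidate group: the ground-truth response is always the first of each block of 10
scores = [0.81, 0.90, 0.35, 0.20, 0.15, 0.12, 0.10, 0.08, 0.05, 0.02]
labels = [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]

print(Evaluate.compute_r_n_m(scores, labels, count=10, at=1))  # 0.0 (positive not ranked first)
print(Evaluate.compute_r_n_m(scores, labels, count=10, at=2))  # 1.0 (positive within the top 2)
print(Evaluate.compute_mrr(scores, labels))                    # 0.5 (positive ranked second)
```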
41 |
42 | ## Citations
43 | If you use the code and datasets, please cite the following papers:
44 | ```
45 | @inproceedings{zhu-etal-2020-scriptwriter,
46 | title = {{S}cript{W}riter: Narrative-Guided Script Generation},
47 | author = {Zhu, Yutao and
48 | Song, Ruihua and
49 | Dou, Zhicheng and
50 | Nie, Jian-Yun and
51 | Zhou, Jin},
52 | booktitle = {Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics},
53 | month = jul,
54 | year = {2020},
55 | address = {Online},
56 | publisher = {Association for Computational Linguistics},
57 | url = {https://www.aclweb.org/anthology/2020.acl-main.765},
58 | pages = {8647--8657}
59 | }
60 | @article{zhu-etal-2022-leverage,
61 | title = {Leveraging Narrative to Generate Movie Script},
62 | author = {Zhu, Yutao and
63 | Song, Ruihua and
64 | Nie, Jian-Yun and
65 | Du, Pan and
66 | Dou, Zhicheng and
67 | Zhou, Jin},
68 | year = {2022},
69 | issue_date = {October 2022},
70 | publisher = {Association for Computing Machinery},
71 | address = {New York, NY, USA},
72 | volume = {40},
73 | number = {4},
74 | issn = {1046-8188},
75 | url = {https://doi.org/10.1145/3507356},
76 | doi = {10.1145/3507356},
77 | journal = {ACM Trans. Inf. Syst.},
78 | month = {mar},
79 | articleno = {86},
80 | numpages = {32}
81 | }
82 |
83 | ```
84 |
--------------------------------------------------------------------------------
/ScriptWriter-CPre.py:
--------------------------------------------------------------------------------
1 | from modules import *
2 | from keras.preprocessing.sequence import pad_sequences
3 | import Utils
4 | import Evaluate
5 | import pickle
6 | import os
7 | from tqdm import tqdm
8 | import logging
9 |
10 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
11 |
12 | # for train
13 | embedding_file = "./data/embeddings.pkl"
14 | train_file = "./data/train.gr.pkl"
15 | val_file = "./data/dev.gr.pkl"
16 | evaluate_file = "./data/test.gr.pkl"
17 |
18 | save_path = "./model/cpre/"
19 | result_path = "./output/cpre/"
20 | log_path = "./model/cpre/"
21 |
22 | max_sentence_len = 50
23 | max_num_utterance = 11
24 | batch_size = 50
25 | eval_batch_size = 100
26 |
27 | class ScriptWriter_cpre():
28 | def __init__(self, eta=0.5):
29 | self.max_num_utterance = max_num_utterance
30 | self.negative_samples = 1
31 | self.max_sentence_len = max_sentence_len
32 | self.word_embedding_size = 200
33 | self.hidden_units = 200
34 | self.total_words = 43514
35 | self.batch_size = batch_size
36 | self.eval_batch_size = eval_batch_size
37 | self.learning_rate_ph = tf.placeholder(tf.float32, shape=[], name='learning_rate')
38 | self.dropout_rate = 0
39 | self.num_heads = 1
40 | self.num_blocks = 3
41 | self.eta = eta
42 | self.gamma = tf.get_variable('gamma', shape=1, dtype=tf.float32, trainable=True, initializer=tf.constant_initializer(0.5))
43 |
44 | self.embedding_ph = tf.placeholder(tf.float32, shape=(self.total_words, self.word_embedding_size))
45 | self.utterance_ph = tf.placeholder(tf.int32, shape=(None, max_num_utterance, max_sentence_len))
46 | self.response_ph = tf.placeholder(tf.int32, shape=(None, max_sentence_len))
47 | self.gt_response_ph = tf.placeholder(tf.int32, shape=(None, max_sentence_len))
48 | self.y_true_ph = tf.placeholder(tf.int32, shape=(None,))
49 | self.narrative_ph = tf.placeholder(tf.int32, shape=(None, max_sentence_len))
50 |
51 | self.word_embeddings = tf.get_variable('word_embeddings_v', shape=(self.total_words, self.word_embedding_size), dtype=tf.float32, trainable=False)
52 | self.embedding_init = self.word_embeddings.assign(self.embedding_ph)
53 | self.global_step = tf.Variable(0, name='global_step', trainable=False)
54 | self.is_training = True
55 | print("current eta: ", self.eta)
56 |
57 | def load(self, previous_modelpath):
58 | sess = tf.Session()
59 | latest_ckpt = tf.train.latest_checkpoint(previous_modelpath)
60 | # print("recover from checkpoint: " + latest_ckpt)
61 | variables = tf.contrib.framework.get_variables_to_restore()
62 | saver = tf.train.Saver(variables)
63 | saver.restore(sess, latest_ckpt)
64 | return sess
65 |
66 | def build(self):
67 | all_utterances = tf.unstack(self.utterance_ph, num=self.max_num_utterance, axis=1)
68 | reuse = None
69 | alpha_1, alpha_2 = None, None
70 |
71 | response_embeddings = embedding(self.response_ph, initializer=self.word_embeddings)
72 | Hr_stack = [response_embeddings]
73 | for i in range(self.num_blocks):
74 | with tf.variable_scope("num_blocks_{}".format(i)):
75 | response_embeddings, _ = multihead_attention(queries=response_embeddings, keys=response_embeddings, num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False, dropout_rate=self.dropout_rate)
76 | response_embeddings = feedforward(response_embeddings, num_units=[self.hidden_units, self.hidden_units])
77 | Hr_stack.append(response_embeddings)
78 |
79 | gt_response_embeddings = embedding(self.gt_response_ph, initializer=self.word_embeddings)
80 | Hgtr_stack = [gt_response_embeddings]
81 | for i in range(self.num_blocks):
82 | with tf.variable_scope("num_blocks_{}".format(i), reuse=True):
83 | gt_response_embeddings, _ = multihead_attention(queries=gt_response_embeddings, keys=gt_response_embeddings, num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False, dropout_rate=self.dropout_rate)
84 | gt_response_embeddings = feedforward(gt_response_embeddings, num_units=[self.hidden_units, self.hidden_units])
85 | Hgtr_stack.append(gt_response_embeddings)
86 |
87 | narrative_embeddings = embedding(self.narrative_ph, initializer=self.word_embeddings)
88 | Hn_stack = [narrative_embeddings]
89 | for i in range(self.num_blocks):
90 | with tf.variable_scope("num_blocks_{}".format(i), reuse=True):
91 | narrative_embeddings, _ = multihead_attention(queries=narrative_embeddings, keys=narrative_embeddings, num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False, dropout_rate=self.dropout_rate)
92 | narrative_embeddings = feedforward(narrative_embeddings, num_units=[self.hidden_units, self.hidden_units])
93 | Hn_stack.append(narrative_embeddings)
94 |
95 | Mur, Mun = [], []
96 | self.decay_factor = []
97 | last_u_reps = []
98 | turn_id = 0
99 | for utterance in all_utterances:
100 | utterance_embeddings = embedding(utterance, initializer=self.word_embeddings)
101 | Hu_stack = [utterance_embeddings]
102 | for i in range(self.num_blocks):
103 | with tf.variable_scope("num_blocks_{}".format(i), reuse=True):
104 | utterance_embeddings, _ = multihead_attention(queries=utterance_embeddings, keys=utterance_embeddings, num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False, dropout_rate=self.dropout_rate)
105 | utterance_embeddings = feedforward(utterance_embeddings, num_units=[self.hidden_units, self.hidden_units])
106 | Hu_stack.append(utterance_embeddings)
107 |
108 | if turn_id == self.max_num_utterance - 1:
109 | last_u_reps = Hu_stack
110 |
111 | r_a_u_stack = []
112 | u_a_r_stack = []
113 |
114 | for i in range(self.num_blocks + 1):
115 | with tf.variable_scope("utterance_attention_response_{}".format(i), reuse=reuse):
116 | u_a_r, _ = multihead_attention(queries=Hu_stack[i], keys=Hr_stack[i], num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False, dropout_rate=self.dropout_rate)
117 | u_a_r = feedforward(u_a_r, num_units=[self.hidden_units, self.hidden_units])
118 | u_a_r_stack.append(u_a_r)
119 | with tf.variable_scope("response_attention_utterance_{}".format(i), reuse=reuse):
120 | r_a_u, _ = multihead_attention(queries=Hr_stack[i], keys=Hu_stack[i], num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False, dropout_rate=self.dropout_rate)
121 | r_a_u = feedforward(r_a_u, num_units=[self.hidden_units, self.hidden_units])
122 | r_a_u_stack.append(r_a_u)
123 | u_a_r_stack.extend(Hu_stack)
124 | r_a_u_stack.extend(Hr_stack)
125 |
126 | n_a_u_stack = []
127 | u_a_n_stack = []
128 | for i in range(self.num_blocks + 1):
129 | with tf.variable_scope("narrative_attention_response_{}".format(i), reuse=reuse):
130 | n_a_u, _ = multihead_attention(queries=Hn_stack[i], keys=Hu_stack[i], num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False, dropout_rate=self.dropout_rate)
131 | n_a_u = feedforward(n_a_u, num_units=[self.hidden_units, self.hidden_units])
132 | n_a_u_stack.append(n_a_u)
133 | with tf.variable_scope("response_attention_narrative_{}".format(i), reuse=reuse):
134 | u_a_n, alpha_1 = multihead_attention(queries=Hu_stack[i], keys=Hn_stack[i], num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False, dropout_rate=self.dropout_rate)
135 | u_a_n = feedforward(u_a_n, num_units=[self.hidden_units, self.hidden_units])
136 | u_a_n_stack.append(u_a_n)
137 | n_a_u_stack.extend(Hn_stack)
138 | u_a_n_stack.extend(Hu_stack)
139 |
140 | u_a_r = tf.stack(u_a_r_stack, axis=-1)
141 | r_a_u = tf.stack(r_a_u_stack, axis=-1)
142 | u_a_n = tf.stack(u_a_n_stack, axis=-1)
143 | n_a_u = tf.stack(n_a_u_stack, axis=-1)
144 |
145 | with tf.variable_scope('similarity'):
146 | # sim shape [batch, max_sent_len, max_sent_len, 2 * (stack_num + 1)]
147 | sim_ur = tf.einsum('biks,bjks->bijs', u_a_r, r_a_u) / tf.sqrt(200.0) # for no rp and normal
148 | sim_un = tf.einsum('biks,bjks->bijs', u_a_n, n_a_u) / tf.sqrt(200.0) # for no rp and normal
149 |
150 | self_n = tf.nn.l2_normalize(tf.stack(Hn_stack, axis=-1)) # #for no rp
151 | self_u = tf.nn.l2_normalize(tf.stack(Hu_stack, axis=-1)) # #for no rp
152 | Hn_stack_tensor = tf.stack(Hn_stack, axis=-1) # [batch, o_len, embedding_size, stack]
153 | with tf.variable_scope('similarity'):
154 | self_sim = tf.einsum('biks,bjks->bijs', self_u, self_n) # [batch, u_len, o_len, stack]
155 | self_sim = 1 - self.gamma * tf.reduce_sum(self_sim, axis=1) # [batch, (1), o_len, stack]
156 | Hn_stack = tf.einsum('bjkl,bjl->bjkl', Hn_stack_tensor, self_sim)
157 | Hn_stack = tf.unstack(Hn_stack, axis=-1, num=self.num_blocks + 1)
158 |
159 | Mur.append(sim_ur)
160 | Mun.append(sim_un)
161 | turn_id += 1
162 | if not reuse:
163 | reuse = True
164 |
165 | Hn_stack_for_tracking = tf.layers.dense(tf.stack(Hn_stack, axis=2), self.hidden_units) # [batch, o_len, stack, embedding_size]
166 | Hn_stack_for_tracking = tf.transpose(Hn_stack_for_tracking, perm=[0, 1, 3, 2]) # [batch, o_len, embedding_size, stack]
167 | Hlastu_stack_for_tracking = tf.stack(last_u_reps, axis=-1) # [batch, u_len, embedding_size, stack]
168 | Hr_stack_for_tracking = tf.stack(Hgtr_stack, axis=-1) # [batch, r_len, embedding_size, stack]
169 | Hlastu = tf.transpose(Hlastu_stack_for_tracking, perm=[0, 2, 3, 1])
170 | Hlastu = tf.squeeze(tf.layers.dense(Hlastu, 1), axis=-1) # [batch, embedding_size, stack]
171 | p1_tensor = tf.nn.softmax(tf.einsum('bnds,bds->bns', Hn_stack_for_tracking, Hlastu), axis=1) # [batch, o_len, stack]
172 | Hlastur = tf.transpose(Hr_stack_for_tracking, perm=[0, 2, 3, 1])
173 | Hlastur = tf.squeeze(tf.layers.dense(Hlastur, 1), axis=-1) # [batch, embedding_size, stack]
174 | p2_tensor = tf.nn.softmax(tf.einsum('bnds,bds->bns', Hn_stack_for_tracking, Hlastur), axis=1) # [batch, o_len, stack]
175 | p1 = tf.unstack(p1_tensor, num=self.num_blocks + 1, axis=-1)
176 | p2 = tf.unstack(p2_tensor, num=self.num_blocks + 1, axis=-1)
177 | KL_loss = 0.0
178 | for i in range(self.num_blocks + 1):
179 | KL_loss += tf.reduce_mean(tf.keras.losses.kullback_leibler_divergence(p1[i], p2[i]))
180 | KL_loss /= (self.num_blocks + 1)
181 |
182 | r_a_n_stack = []
183 | n_a_r_stack = []
184 | for i in range(self.num_blocks + 1):
185 | with tf.variable_scope("narrative_attention_response_{}".format(i), reuse=True):
186 | n_a_r, _ = multihead_attention(queries=Hn_stack[i], keys=Hr_stack[i], num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False, dropout_rate=self.dropout_rate)
187 | n_a_r = feedforward(n_a_r, num_units=[self.hidden_units, self.hidden_units])
188 | n_a_r_stack.append(n_a_r)
189 | with tf.variable_scope("response_attention_narrative_{}".format(i), reuse=True):
190 | r_a_n, _ = multihead_attention(queries=Hr_stack[i], keys=Hn_stack[i], num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False, dropout_rate=self.dropout_rate)
191 | r_a_n = feedforward(r_a_n, num_units=[self.hidden_units, self.hidden_units])
192 | r_a_n_stack.append(r_a_n)
193 |
194 | n_a_r_stack.extend(Hn_stack)
195 | r_a_n_stack.extend(Hr_stack)
196 | n_a_r = tf.stack(n_a_r_stack, axis=-1)
197 | r_a_n = tf.stack(r_a_n_stack, axis=-1)
198 |
199 | with tf.variable_scope('similarity'):
200 | Mrn = tf.einsum('biks,bjks->bijs', n_a_r, r_a_n) / tf.sqrt(200.0)
201 | self.rosim = Mrn
202 | Mur = tf.stack(Mur, axis=1)
203 | Mun = tf.stack(Mun, axis=1)
204 | with tf.variable_scope('cnn_aggregation'):
205 | conv3d = tf.layers.conv3d(Mur, filters=32, kernel_size=[3, 3, 3], padding="SAME", activation=tf.nn.elu, kernel_initializer=tf.random_uniform_initializer(-0.01, 0.01), name="conv1")
206 | pool3d = tf.layers.max_pooling3d(conv3d, pool_size=[3, 3, 3], strides=[3, 3, 3], padding="SAME")
207 | conv3d2 = tf.layers.conv3d(pool3d, filters=32, kernel_size=[3, 3, 3], padding="SAME", activation=tf.nn.elu, kernel_initializer=tf.random_uniform_initializer(-0.01, 0.01), name="conv2")
208 | pool3d2 = tf.layers.max_pooling3d(conv3d2, pool_size=[3, 3, 3], strides=[3, 3, 3], padding="SAME")
209 | mur = tf.contrib.layers.flatten(pool3d2)
210 | with tf.variable_scope('cnn_aggregation', reuse=True):
211 | conv3d = tf.layers.conv3d(Mun, filters=32, kernel_size=[3, 3, 3], padding="SAME", activation=tf.nn.elu, kernel_initializer=tf.random_uniform_initializer(-0.01, 0.01), name="conv1")
212 | pool3d = tf.layers.max_pooling3d(conv3d, pool_size=[3, 3, 3], strides=[3, 3, 3], padding="SAME")
213 | conv3d2 = tf.layers.conv3d(pool3d, filters=32, kernel_size=[3, 3, 3], padding="SAME", activation=tf.nn.elu, kernel_initializer=tf.random_uniform_initializer(-0.01, 0.01), name="conv2")
214 | pool3d2 = tf.layers.max_pooling3d(conv3d2, pool_size=[3, 3, 3], strides=[3, 3, 3], padding="SAME")
215 | mun = tf.contrib.layers.flatten(pool3d2)
216 | with tf.variable_scope('cnn_aggregation'):
217 | conv2d = tf.layers.conv2d(Mrn, filters=32, kernel_size=[3, 3], padding="SAME", activation=tf.nn.elu, kernel_initializer=tf.random_uniform_initializer(-0.01, 0.01), name="conv2d")
218 | pool2d = tf.layers.max_pooling2d(conv2d, pool_size=[3, 3], strides=[3, 3], padding="SAME")
219 | conv2d2 = tf.layers.conv2d(pool2d, filters=32, kernel_size=[3, 3], padding="SAME", activation=tf.nn.elu, kernel_initializer=tf.random_uniform_initializer(-0.01, 0.01), name="conv2d2")
220 | pool2d2 = tf.layers.max_pooling2d(conv2d2, pool_size=[3, 3], strides=[3, 3], padding="SAME")
221 | mrn = tf.contrib.layers.flatten(pool2d2)
222 |
223 | all_vector = tf.concat([mur, mun, mrn], axis=-1)
224 | logits = tf.reshape(tf.layers.dense(all_vector, 1, kernel_initializer=tf.orthogonal_initializer()), [-1])
225 |
226 | self.y_pred = tf.sigmoid(logits)
227 | optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate_ph, beta1=0.9, beta2=0.98, epsilon=1e-8)
228 | RS_loss = tf.reduce_mean(tf.clip_by_value(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.cast(self.y_true_ph, tf.float32), logits=logits), -10, 10))
229 | self.loss = self.eta * RS_loss + (1 - self.eta) * KL_loss
230 | self.all_variables = tf.global_variables()
231 | self.grads_and_vars = optimizer.compute_gradients(self.loss)
232 |
233 | for grad, var in self.grads_and_vars:
234 | if grad is None:
235 | print(var)
236 |
237 | self.capped_gvs = [(tf.clip_by_value(grad, -5, 5), var) for grad, var in self.grads_and_vars]
238 | self.train_op = optimizer.apply_gradients(self.capped_gvs, global_step=self.global_step)
239 | self.saver = tf.train.Saver(max_to_keep=10)
240 | self.alpha_1 = alpha_1
241 | # self.alpha_2 = alpha_2
242 | # self.train_op = self.optimizer.minimize(self.loss, global_step=self.global_step)
243 |
244 |
245 | def evaluate(model_path, eval_file, output_path, eta):
246 | with open(eval_file, 'rb') as f:
247 | utterance, response, narrative, gt_response, y_true = pickle.load(f)
248 |
249 | current_lr = 1e-3
250 | all_candidate_scores = []
251 | dataset = tf.data.Dataset.from_tensor_slices((utterance, narrative, response, gt_response, y_true)).batch(eval_batch_size)
252 | iterator = dataset.make_initializable_iterator()
253 | data_iterator = iterator.get_next()
254 |
255 | with open(embedding_file, 'rb') as f:
256 | embeddings = pickle.load(f)
257 |
258 | model = ScriptWriter_cpre(eta)
259 | model.build()
260 | sess = model.load(model_path)
261 | sess.run(iterator.initializer)
262 | sess.run(model.embedding_init, feed_dict={model.embedding_ph: embeddings})
263 |
264 | test_loss = 0.0
265 | step = 0
266 | try:
267 | with tqdm(total=len(y_true), ncols=100) as pbar:
268 | while True:
269 | bu, bn, br, bgtr, by = data_iterator
270 | bu, bn, br, bgtr, by = sess.run([bu, bn, br, bgtr, by])
271 | candidate_scores, loss = sess.run([model.y_pred, model.loss], feed_dict={
272 | model.utterance_ph: bu,
273 | model.narrative_ph: bn,
274 | model.response_ph: br,
275 | model.gt_response_ph: bgtr,
276 | model.y_true_ph: by,
277 | model.learning_rate_ph: current_lr
278 | })
279 | all_candidate_scores.append(candidate_scores)
280 | test_loss += loss
281 | pbar.update(model.eval_batch_size)
282 | step += 1
283 | except tf.errors.OutOfRangeError:
284 | pass
285 |
286 | sess.close()
287 | tf.reset_default_graph()
288 |
289 | all_candidate_scores = np.concatenate(all_candidate_scores, axis=0)
290 | with open(output_path + "test.result.micro_session.txt", "w") as fw:
291 | for sc in all_candidate_scores.tolist():
292 | fw.write(str(sc) + "\n")
293 | return Evaluate.evaluate_all(all_candidate_scores, y_true), test_loss / step, all_candidate_scores.tolist()
294 |
295 |
296 | def simple_evaluate(sess, model, eval_file):
297 | with open(eval_file, 'rb') as f:
298 | utterance, response, narrative, y_true = pickle.load(f)
299 | utterance, utterance_len = Utils.multi_sequences_padding(utterance, max_sentence_len, max_num_utterance=max_num_utterance)
300 | utterance = np.array(utterance)
301 | narrative = np.array(pad_sequences(narrative, padding='post', maxlen=max_sentence_len))
302 | response = np.array(pad_sequences(response, padding='post', maxlen=max_sentence_len))
303 | y_true = np.array(y_true)
304 | all_candidate_scores = []
305 | dataset = tf.data.Dataset.from_tensor_slices((utterance, narrative, response, y_true)).batch(eval_batch_size)
306 | iterator = dataset.make_initializable_iterator()
307 | data_iterator = iterator.get_next()
308 | sess.run(iterator.initializer)
309 | current_lr = 1e-3
310 | test_loss = 0.0
311 | step = 0
312 | try:
313 | with tqdm(total=len(y_true), ncols=100) as pbar:
314 | while True:
315 | bu, bn, br, by = data_iterator
316 | bu, bn, br, by = sess.run([bu, bn, br, by])
317 | candidate_scores, loss = sess.run([model.y_pred, model.loss], feed_dict={
318 | model.utterance_ph: bu,
319 | model.narrative_ph: bn,
320 | model.response_ph: br,
321 | model.y_true_ph: by,
322 | model.gt_response_ph: br,
323 | model.learning_rate_ph: current_lr
324 | })
325 | all_candidate_scores.append(candidate_scores)
326 | test_loss += loss
327 | pbar.update(eval_batch_size)
328 | step += 1
329 | except tf.errors.OutOfRangeError:
330 | pass
331 | all_candidate_scores = np.concatenate(all_candidate_scores, axis=0)
332 | return Evaluate.evaluate_all(all_candidate_scores, y_true), test_loss / step, all_candidate_scores.tolist()
333 |
334 |
335 | def evaluate_multi_turns(test_file, model_path, output_path):
336 | vocab = {}
337 | vocab_id2word = {}
338 |
339 | with open(embedding_file, 'rb') as f:
340 | embeddings = pickle.load(f)
341 |
342 | model = ScriptWriter_cpre()
343 | model.build()
344 | sess = model.load(model_path)
345 | sess.run(model.embedding_init, feed_dict={model.embedding_ph: embeddings})
346 |
347 | with open("./data/vocab.txt", "r", encoding="utf-8") as fr:
348 | for idx, line in enumerate(fr):
349 | line = line.strip().split("\t")
350 | vocab[line[0]] = idx + 1
351 | vocab_id2word[idx + 1] = line[0]
352 | vocab["_PAD_"] = 0
353 | vocab_id2word[0] = "_PAD_"
354 |
355 | def initialize(test_file):
356 | initial_file = output_path + "test.multi.0.pkl"
357 | max_turn = 0
358 | narrative_dict = {}
359 | narrative_dict_score = {}
360 |
361 | with open(test_file, 'rb') as f:
362 | utterance, response, outline, labels = pickle.load(f)
363 | new_utterance, new_response, new_narrative, new_labels = [], [], [], []
364 | for i in range(len(response)):
365 | ut = utterance[i]
366 | if len(ut) == 1:
367 | o = outline[i]
368 | r = response[i]
369 | l = labels[i]
370 | new_utterance.append(ut)
371 | new_response.append(r)
372 | new_narrative.append(o)
373 | new_labels.append(l)
374 | if len(ut) > max_turn:
375 | max_turn = len(ut)
376 | o = "".join([vocab_id2word[x] for x in outline[i]])
377 | if o not in narrative_dict:
378 | narrative_dict[o] = {0: outline[i]}
379 | narrative_dict_score[o] = {0: [-1]}
380 | r = response[i]
381 | l = labels[i]
382 | if len(ut) in narrative_dict[o]:
383 | narrative_dict[o][len(ut)].append(r)
384 | narrative_dict_score[o][len(ut)].append(l)
385 | else:
386 | narrative_dict[o][len(ut)] = [r]
387 | narrative_dict_score[o][len(ut)] = [l]
388 |
389 | pickle.dump(narrative_dict, open(output_path + "response_candidate.pkl", "wb"))
390 |
391 | new_data = [new_utterance, new_response, new_narrative, new_labels]
392 | pickle.dump(new_data, open(initial_file, "wb"))
393 |
394 | (acc, r2_1, r10_1, r10_2, r10_5, mrr), eva_loss, result = simple_evaluate(sess, model, initial_file)
395 | with open(output_path + "test.result.multi.0.txt", "w") as fw:
396 | fw.write("R2@1: %f, R10@1: %f, R10@2: %f, R10@5: %f, MRR: %f\n" % (r2_1, r10_1, r10_2, r10_5, mrr))
397 | for r in result:
398 | fw.write(str(r) + "\n")
399 |
400 | return max_turn, narrative_dict, narrative_dict_score
401 |
402 | max_turn, narrative_dict, narrative_dict_score = initialize(test_file)
403 | for turn in range(1, max_turn):
404 | score = []
405 | with open(output_path + "test.result.multi." + str(turn - 1) + ".txt", "r") as fr:
406 | for idx, line in enumerate(fr):
407 | if idx == 0:
408 | continue
409 | score.append(float(line.strip()))
410 | with open(output_path + "test.multi." + str(turn - 1) + ".pkl", "rb") as fr:
411 | utterance, response, narrative, y_true = pickle.load(fr)
412 |
413 | new_utterance = []
414 | new_response = []
415 | new_narrative = []
416 | new_labels = []
417 |
418 | for i, o in enumerate(narrative):
419 | if i % 10 == 0:
420 | sent_o = "".join([vocab_id2word[x] for x in o])
421 | if turn + 1 in narrative_dict[sent_o]:
422 | new_response.extend(narrative_dict[sent_o][turn + 1])
423 | score_sub_list = score[i:i + 10]
424 | response_sub_list = response[i:i + 10]
425 | max_idx = score_sub_list.index(max(score_sub_list))
426 | selected_response = response_sub_list[max_idx]
427 | for ut in utterance[i:i + 10]:
428 | tmp = ut + [selected_response]
429 | new_utterance.append(tmp)
430 | new_narrative.extend([o] * 10)
431 | new_labels.extend(narrative_dict_score[sent_o][turn + 1])
432 |
433 | new_data = [new_utterance, new_response, new_narrative, new_labels]
434 | new_file = output_path + "test.multi." + str(turn) + ".pkl"
435 | pickle.dump(new_data, open(new_file, "wb"))
436 |
437 | (acc, r2_1, r10_1, r10_2, r10_5, mrr), eva_loss, result = simple_evaluate(sess, model, new_file)
438 | with open(output_path + "test.result.multi." + str(turn) + ".txt", "w") as fw:
439 | fw.write("R2@1: %f, R10@1: %f, R10@2: %f, R10@5: %f, MRR: %f\n" % (r2_1, r10_1, r10_2, r10_5, mrr))
440 | for r in result:
441 | fw.write(str(r) + "\n")
442 |
443 |
444 | def train(eta=0.5, load=False, model_path=None, logger=None):
445 | config = tf.ConfigProto(allow_soft_placement=True)
446 | config.gpu_options.allow_growth = True
447 | epoch, patience = 0, 0  # patience counts validations without improvement
448 | best_result = [0.0, 0.0, 0.0, 0.0, 0.0]
449 | with tf.Session(config=config) as sess:
450 | with open(embedding_file, 'rb') as f:
451 | embeddings = pickle.load(f, encoding="bytes")
452 | with open(train_file, 'rb') as f:
453 | utterance_train, response_train, narrative_train, gt_response_train, y_true_train = pickle.load(f)
454 | with open(val_file, "rb") as f:
455 | utterance_val, response_val, narrative_val, gt_response_val, y_true_val = pickle.load(f)
456 |
457 | train_dataset = tf.data.Dataset.from_tensor_slices((utterance_train, narrative_train, response_train, gt_response_train, y_true_train)).shuffle(1024).batch(batch_size)
458 | train_iterator = train_dataset.make_initializable_iterator()
459 | train_data_iterator = train_iterator.get_next()
460 |
461 | val_dataset = tf.data.Dataset.from_tensor_slices((utterance_val, narrative_val, response_val, gt_response_val, y_true_val)).batch(batch_size)
462 | val_iterator = val_dataset.make_initializable_iterator()
463 | val_data_iterator = val_iterator.get_next()
464 |
465 | model = ScriptWriter_cpre(eta=eta)
466 | model.build()
467 |
468 | if load:
469 | sess = model.load(model_path)
470 |
471 | sess.run(tf.global_variables_initializer())
472 | sess.run(model.embedding_init, feed_dict={model.embedding_ph: embeddings})
473 | current_lr = 1e-3
474 |
475 | while epoch < 4:
476 | print("\nEpoch ", epoch + 1, "/ 4")
477 | train_loss = 0.0
478 | sess.run(train_iterator.initializer)
479 | step = 0
480 | try:
481 | with tqdm(total=len(y_true_train), ncols=100) as pbar:
482 | while True:
483 | bu, bn, br, bgtr, by = train_data_iterator
484 | bu, bn, br, bgtr, by = sess.run([bu, bn, br, bgtr, by])
485 | _, loss = sess.run([model.train_op, model.loss], feed_dict={
486 | model.utterance_ph: bu,
487 | model.narrative_ph: bn,
488 | model.response_ph: br,
489 | model.gt_response_ph: bgtr,
490 | model.y_true_ph: by,
491 | model.learning_rate_ph: current_lr
492 | })
493 | train_loss += loss
494 | pbar.set_postfix(learning_rate=current_lr, loss=loss)
495 | pbar.update(model.batch_size)
496 | step += 1
497 | if step % 500 == 0:
498 | val_loss = 0.0
499 | val_step = 0
500 | sess.run(val_iterator.initializer)
501 | all_candidate_scores = []
502 | try:
503 | while True:
504 | bu, bn, br, bgtr, by = val_data_iterator
505 | bu, bn, br, bgtr, by = sess.run([bu, bn, br, bgtr, by])
506 | candidate_scores, loss = sess.run([model.y_pred, model.loss], feed_dict={
507 | model.utterance_ph: bu,
508 | model.narrative_ph: bn,
509 | model.response_ph: br,
510 | model.gt_response_ph: bgtr,
511 | model.y_true_ph: by,
512 | })
513 | all_candidate_scores.append(candidate_scores)
514 | val_loss += loss
515 | val_step += 1
516 | except tf.errors.OutOfRangeError:
517 | pass
518 | all_candidate_scores = np.concatenate(all_candidate_scores, axis=0)
519 | result = Evaluate.evaluate_all(all_candidate_scores, y_true_val)
520 | if result[0] + result[1] + result[2] + result[3] + result[4] > best_result[0] + best_result[1] + best_result[2] + best_result[3] + best_result[4]:
521 | best_result = result
522 | tqdm.write("Current best result on validation set: acc %.3f, r2@1 %.3f, r10@1 %.3f, r10@2 %.3f, r10@5 %.3f" % (best_result[0], best_result[1], best_result[2], best_result[3], best_result[4]))
523 | logger.info("Current best result on validation set: acc %.3f, r2@1 %.3f, r10@1 %.3f, r10@2 %.3f, r10@5 %.3f" % (best_result[0], best_result[1], best_result[2], best_result[3], best_result[4]))
524 | model.saver.save(sess, save_path + "model")
525 | patience = 0
526 | else:
527 | patience += 1
528 | if patience >= 3:
529 | current_lr *= 0.5
530 | except tf.errors.OutOfRangeError:
531 | pass
532 |
533 | val_loss = 0.0
534 | val_step = 0
535 | sess.run(val_iterator.initializer)
536 | all_candidate_scores = []
537 | try:
538 | while True:
539 | bu, bn, br, bgtr, by = val_data_iterator
540 | bu, bn, br, bgtr, by = sess.run([bu, bn, br, bgtr, by])
541 | candidate_scores, loss = sess.run([model.y_pred, model.loss], feed_dict={
542 | model.utterance_ph: bu,
543 | model.narrative_ph: bn,
544 | model.response_ph: br,
545 | model.gt_response_ph: bgtr,
546 | model.y_true_ph: by
547 | })
548 | all_candidate_scores.append(candidate_scores)
549 | val_loss += loss
550 | val_step += 1
551 | except tf.errors.OutOfRangeError:
552 | pass
553 | all_candidate_scores = np.concatenate(all_candidate_scores, axis=0)
554 | result = Evaluate.evaluate_all(all_candidate_scores, y_true_val)
555 | if result[0] + result[1] + result[2] + result[3] + result[4] > best_result[0] + best_result[1] + best_result[2] + best_result[3] + best_result[4]:
556 | best_result = result
557 | tqdm.write("Current best result on validation set: acc %.3f, r2@1 %.3f, r10@1 %.3f, r10@2 %.3f, r10@5 %.3f" % (best_result[0], best_result[1], best_result[2], best_result[3], best_result[4]))
558 | logger.info("Current best result on validation set: acc %.3f, r2@1 %.3f, r10@1 %.3f, r10@2 %.3f, r10@5 %.3f" % (best_result[0], best_result[1], best_result[2], best_result[3], best_result[4]))
559 | model.saver.save(sess, save_path + "model")
560 | tqdm.write('Epoch No: %d, the train loss is %f, the dev loss is %f' % (epoch + 1, train_loss / step, val_loss / val_step))
561 | epoch += 1
562 | sess.close()
563 | tf.reset_default_graph()
564 |
565 |
566 | if __name__ == "__main__":
567 | os.environ['CUDA_VISIBLE_DEVICES'] = "0"
568 | log_path = "./model/cpre/all_log"
569 | logging.basicConfig(filename=log_path, level=logging.INFO)
570 | logger = logging.getLogger(__name__)
571 | eta = 0.7
572 | save_path = "./model/cpre/"
573 | result_path = "./output/cpre/"
574 | if not os.path.exists(save_path):
575 | os.mkdir(save_path)
576 | if not os.path.exists(result_path):
577 | os.mkdir(result_path)
578 | logger.info("Current Eta: %.2f" % eta)
579 | train(eta=eta, logger=logger)
580 |
581 | (acc, r2_1, r10_1, r10_2, r10_5, mrr), eva_loss, _ = evaluate(save_path, evaluate_file, output_path=result_path, eta=eta)
582 | print("Loss on test set: %f, R2@1: %f, R10@1: %f, R10@2: %f, R10@5: %f, MRR: %f" % (eva_loss, r2_1, r10_1, r10_2, r10_5, mrr))
583 |
584 | # to evaluate multi-turn results, the vocab file is needed
585 | # evaluate_multi_turns(test_file=evaluate_file, model_path=save_path, output_path=result_path)
586 | # Evaluate.recover_and_show(result_path)
587 | # test_file = result_path + "test.result.multi.txt"
588 | # gold_file = "../data/ground_truth.result.mul.txt"
589 | # Evaluate.evaluate_multi_turn_result(test_file, gold_file)
590 |
591 |
--------------------------------------------------------------------------------
/ScriptWriter.py:
--------------------------------------------------------------------------------
1 | from modules import *
2 | from keras.preprocessing.sequence import pad_sequences
3 | import Utils
4 | import Evaluate
5 | import pickle
6 | import os
7 | from tqdm import tqdm
8 |
9 | embedding_file = "./data/embeddings.pkl"
10 | train_file = "./data/train.pkl"
11 | val_file = "./data/dev.pkl"
12 | evaluate_file = "./data/test.pkl"
13 | evaluate_embedding_file = "./data/embeddings.pkl"
14 |
15 | max_sentence_len = 50
16 | max_num_utterance = 11
17 | batch_size = 64
18 | eval_batch_size = 64
19 |
20 |
21 | class ScripteWriter():
22 | def __init__(self, data_iterator):
23 | self.max_num_utterance = max_num_utterance
24 | self.negative_samples = 1
25 | self.max_sentence_len = max_sentence_len
26 | self.word_embedding_size = 200
27 | self.hidden_units = 200
28 | self.total_words = 43514
29 | self.batch_size = batch_size
30 | self.eval_batch_size = eval_batch_size
31 | self.initial_learning_rate = 1e-3
32 | self.dropout_rate = 0
33 | self.num_heads = 1
34 | self.num_blocks = 3
35 | # self.gamma = 0.1
36 | self.gamma = tf.get_variable('gamma', shape=1, dtype=tf.float32, trainable=True, initializer=tf.constant_initializer(0.5))
37 |
38 | self.utterance_ph = data_iterator[0]
39 | self.response_ph = data_iterator[4]
40 | self.y_true = data_iterator[6]
41 | self.embedding_ph = tf.placeholder(tf.float32, shape=(self.total_words, self.word_embedding_size))
42 | self.response_len = data_iterator[5]
43 | self.all_utterance_len_ph = data_iterator[1]
44 | self.narrative_ph = data_iterator[2]
45 | self.narrative_len = data_iterator[3]
46 | self.word_embeddings = tf.get_variable('word_embeddings_v', shape=(self.total_words, self.word_embedding_size), dtype=tf.float32, trainable=False)
47 | self.embedding_init = self.word_embeddings.assign(self.embedding_ph)
48 | self.global_step = tf.Variable(0, name='global_step', trainable=False)
49 | self.is_training = True
50 |
51 | def load(self, previous_modelpath):
52 | sess = tf.Session()
53 | latest_ckpt = tf.train.latest_checkpoint(previous_modelpath)
54 | # latest_ckpt = previous_modelpath + "model.4"
55 | print("recover from checkpoint: " + latest_ckpt)
56 | variables = tf.contrib.framework.get_variables_to_restore()
57 | saver = tf.train.Saver(variables)
58 | saver.restore(sess, latest_ckpt)
59 | return sess
60 |
61 | def build(self):
62 | all_utterances = tf.unstack(self.utterance_ph, num=self.max_num_utterance, axis=1)
63 | all_utterance_len = tf.unstack(self.all_utterance_len_ph, num=self.max_num_utterance, axis=1)
64 | reuse = None
65 | alpha_1 = None
66 |
67 | response_embeddings = embedding(self.response_ph, initializer=self.word_embeddings)
68 | response_embeddings = tf.layers.dropout(response_embeddings, rate=self.dropout_rate, training=tf.convert_to_tensor(self.is_training))
69 | Hr_stack = [response_embeddings]
70 | for i in range(self.num_blocks):
71 | with tf.variable_scope("num_blocks_{}".format(i)):
72 | response_embeddings, _ = multihead_attention(queries=response_embeddings, keys=response_embeddings, num_units=self.hidden_units, num_heads=self.num_heads, dropout_rate=self.dropout_rate, is_training=self.is_training, causality=False)
73 | response_embeddings = feedforward(response_embeddings, num_units=[self.hidden_units, self.hidden_units])
74 | Hr_stack.append(response_embeddings)
75 |
76 | narrative_embeddings = embedding(self.narrative_ph, initializer=self.word_embeddings)
77 | narrative_embeddings = tf.layers.dropout(narrative_embeddings, rate=self.dropout_rate, training=tf.convert_to_tensor(self.is_training))
78 | Hn_stack = [narrative_embeddings]
79 | for i in range(self.num_blocks):
80 | with tf.variable_scope("num_blocks_{}".format(i), reuse=True):
81 | narrative_embeddings, _ = multihead_attention(queries=narrative_embeddings, keys=narrative_embeddings, num_units=self.hidden_units, num_heads=self.num_heads, dropout_rate=self.dropout_rate, is_training=self.is_training, causality=False)
82 | narrative_embeddings = feedforward(narrative_embeddings, num_units=[self.hidden_units, self.hidden_units])
83 | Hn_stack.append(narrative_embeddings)
84 |
85 | Mur, Mun = [], []
86 | self.decay_factor = []
87 |
88 | for utterance, utterance_len in zip(all_utterances, all_utterance_len):
89 | utterance_embeddings = embedding(utterance, initializer=self.word_embeddings)
90 | Hu_stack = [utterance_embeddings]
91 | for i in range(self.num_blocks):
92 | with tf.variable_scope("num_blocks_{}".format(i), reuse=True):
93 | utterance_embeddings, _ = multihead_attention(queries=utterance_embeddings, keys=utterance_embeddings, num_units=self.hidden_units, num_heads=self.num_heads, dropout_rate=self.dropout_rate, is_training=self.is_training, causality=False)
94 | utterance_embeddings = feedforward(utterance_embeddings, num_units=[self.hidden_units, self.hidden_units])
95 | Hu_stack.append(utterance_embeddings)
96 |
97 | r_a_u_stack = []
98 | u_a_r_stack = []
99 |
100 | for i in range(self.num_blocks + 1):
101 | with tf.variable_scope("utterance_attention_response_{}".format(i), reuse=reuse):
102 | u_a_r, _ = multihead_attention(queries=Hu_stack[i], keys=Hr_stack[i], num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False)
103 | u_a_r = feedforward(u_a_r, num_units=[self.hidden_units, self.hidden_units])
104 | u_a_r_stack.append(u_a_r)
105 | with tf.variable_scope("response_attention_utterance_{}".format(i), reuse=reuse):
106 | r_a_u, _ = multihead_attention(queries=Hr_stack[i], keys=Hu_stack[i], num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False)
107 | r_a_u = feedforward(r_a_u, num_units=[self.hidden_units, self.hidden_units])
108 | r_a_u_stack.append(r_a_u)
109 | u_a_r_stack.extend(Hu_stack)
110 | r_a_u_stack.extend(Hr_stack)
111 |
112 | n_a_u_stack = []
113 | u_a_n_stack = []
114 | for i in range(self.num_blocks + 1):
115 | with tf.variable_scope("narrative_attention_utterance_{}".format(i), reuse=reuse):
116 | n_a_u, _ = multihead_attention(queries=Hn_stack[i], keys=Hu_stack[i], num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False)
117 | n_a_u = feedforward(n_a_u, num_units=[self.hidden_units, self.hidden_units])
118 | n_a_u_stack.append(n_a_u)
119 | with tf.variable_scope("utterance_attention_narrative_{}".format(i), reuse=reuse):
120 | u_a_n, alpha_1 = multihead_attention(queries=Hu_stack[i], keys=Hn_stack[i], num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False)
121 | u_a_n = feedforward(u_a_n, num_units=[self.hidden_units, self.hidden_units])
122 | u_a_n_stack.append(u_a_n)
123 |
124 | n_a_u_stack.extend(Hn_stack)
125 | u_a_n_stack.extend(Hu_stack)
126 |
127 | u_a_r = tf.stack(u_a_r_stack, axis=-1)
128 | r_a_u = tf.stack(r_a_u_stack, axis=-1)
129 | u_a_n = tf.stack(u_a_n_stack, axis=-1)
130 | n_a_u = tf.stack(n_a_u_stack, axis=-1)
131 |
132 | with tf.variable_scope('similarity'):
133 | # sim shape [batch, max_sent_len, max_sent_len, 2 * (stack_num + 1)]
134 | sim_ur = tf.einsum('biks,bjks->bijs', u_a_r, r_a_u) / tf.sqrt(200.0)
135 | sim_un = tf.einsum('biks,bjks->bijs', u_a_n, n_a_u) / tf.sqrt(200.0)
136 |
137 | self_n = tf.nn.l2_normalize(tf.stack(Hn_stack, axis=-1))
138 | self_u = tf.nn.l2_normalize(tf.stack(Hu_stack, axis=-1))
139 | with tf.variable_scope('similarity'):
140 | self_sim = tf.einsum('biks,bjks->bijs', self_u, self_n) # [batch * len * len * stack]
141 | self_sim = tf.unstack(self_sim, axis=-1, num=self.num_blocks + 1)
142 | reuse2 = reuse
143 | for i in range(self.num_blocks + 1):
144 | tmp_self_sim = tf.expand_dims(self_sim[i], axis=-1)
145 | tmp_self_sim = 1 - self.gamma * tf.layers.conv2d(tmp_self_sim, filters=1, kernel_size=[max_sentence_len, 1], padding="valid", kernel_initializer=tf.ones_initializer, use_bias=False, trainable=False, reuse=reuse2) # for auto2
146 | tmp_self_sim = tf.squeeze(tmp_self_sim, axis=1)
147 | tmp_self_sim = tf.squeeze(tmp_self_sim, axis=-1)
148 | Hn_stack[i] = tf.einsum('bik,bi->bik', Hn_stack[i], tmp_self_sim)
149 | reuse2 = True
150 |
151 | Mur.append(sim_ur)
152 | Mun.append(sim_un)
153 |
154 | if not reuse:
155 | reuse = True
156 |
157 | r_a_n_stack = []
158 | n_a_r_stack = []
159 | reuse2 = False
160 | for i in range(self.num_blocks + 1):
161 | with tf.variable_scope("narrative_attention_response_{}".format(i), reuse=reuse2):
162 | n_a_r, _ = multihead_attention(queries=Hn_stack[i], keys=Hr_stack[i], num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False)
163 | n_a_r = feedforward(n_a_r, num_units=[self.hidden_units, self.hidden_units])
164 | n_a_r_stack.append(n_a_r)
165 | with tf.variable_scope("response_attention_narrative_{}".format(i), reuse=reuse2):
166 | r_a_n, _ = multihead_attention(queries=Hr_stack[i], keys=Hn_stack[i], num_units=self.hidden_units, num_heads=self.num_heads, is_training=self.is_training, causality=False)
167 | r_a_n = feedforward(r_a_n, num_units=[self.hidden_units, self.hidden_units])
168 | r_a_n_stack.append(r_a_n)
169 |
170 | n_a_r_stack.extend(Hn_stack)
171 | r_a_n_stack.extend(Hr_stack)
172 | n_a_r = tf.stack(n_a_r_stack, axis=-1)
173 | r_a_n = tf.stack(r_a_n_stack, axis=-1)
174 |
175 | with tf.variable_scope('similarity'):
176 | Mrn = tf.einsum('biks,bjks->bijs', n_a_r, r_a_n) / tf.sqrt(200.0)
177 |
178 | Mur = tf.stack(Mur, axis=1)
179 | Mun = tf.stack(Mun, axis=1)
180 | with tf.variable_scope('cnn_aggregation'):
181 | conv3d = tf.layers.conv3d(Mur, filters=32, kernel_size=[3, 3, 3], padding="SAME", activation=tf.nn.elu, kernel_initializer=tf.random_uniform_initializer(-0.01, 0.01), name="conv1")
182 | pool3d = tf.layers.max_pooling3d(conv3d, pool_size=[3, 3, 3], strides=[3, 3, 3], padding="SAME")
183 | conv3d2 = tf.layers.conv3d(pool3d, filters=16, kernel_size=[3, 3, 3], padding="SAME", activation=tf.nn.elu, kernel_initializer=tf.random_uniform_initializer(-0.01, 0.01), name="conv2")
184 | pool3d2 = tf.layers.max_pooling3d(conv3d2, pool_size=[3, 3, 3], strides=[3, 3, 3], padding="SAME")
185 | mur = tf.contrib.layers.flatten(pool3d2)
186 | with tf.variable_scope('cnn_aggregation', reuse=True):
187 | conv3d = tf.layers.conv3d(Mun, filters=32, kernel_size=[3, 3, 3], padding="SAME", activation=tf.nn.elu,
188 | kernel_initializer=tf.random_uniform_initializer(-0.01, 0.01), name="conv1")
189 | pool3d = tf.layers.max_pooling3d(conv3d, pool_size=[3, 3, 3], strides=[3, 3, 3], padding="SAME")
190 | conv3d2 = tf.layers.conv3d(pool3d, filters=16, kernel_size=[3, 3, 3], padding="SAME", activation=tf.nn.elu,
191 | kernel_initializer=tf.random_uniform_initializer(-0.01, 0.01), name="conv2")
192 | pool3d2 = tf.layers.max_pooling3d(conv3d2, pool_size=[3, 3, 3], strides=[3, 3, 3], padding="SAME")
193 | mun = tf.contrib.layers.flatten(pool3d2)
194 |
195 | with tf.variable_scope('cnn_aggregation'):
196 | conv2d = tf.layers.conv2d(Mrn, filters=32, kernel_size=[3, 3], padding="SAME", activation=tf.nn.elu, kernel_initializer=tf.random_uniform_initializer(-0.01, 0.01), name="conv2d")
197 | pool2d = tf.layers.max_pooling2d(conv2d, pool_size=[3, 3], strides=[3, 3], padding="SAME")
198 | conv2d2 = tf.layers.conv2d(pool2d, filters=16, kernel_size=[3, 3], padding="SAME", activation=tf.nn.elu, kernel_initializer=tf.random_uniform_initializer(-0.01, 0.01), name="conv2d2")
199 | pool2d2 = tf.layers.max_pooling2d(conv2d2, pool_size=[3, 3], strides=[3, 3], padding="SAME")
200 | mrn = tf.contrib.layers.flatten(pool2d2)
201 |
202 | all_vector = tf.concat([mur, mun, mrn], axis=-1)
203 | logits = tf.reshape(tf.layers.dense(all_vector, 1, kernel_initializer=tf.orthogonal_initializer()), [-1])
204 |
205 | self.y_pred = tf.sigmoid(logits)
206 | self.learning_rate = tf.train.exponential_decay(self.initial_learning_rate, global_step=self.global_step, decay_steps=1000, decay_rate=0.9, staircase=True)
207 | optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate, beta1=0.9, beta2=0.98, epsilon=1e-8)
208 | self.loss = tf.reduce_mean(tf.clip_by_value(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.cast(self.y_true, tf.float32), logits=logits), -10, 10))
209 | self.all_variables = tf.global_variables()
210 | self.grads_and_vars = optimizer.compute_gradients(self.loss)
211 |
212 | for grad, var in self.grads_and_vars:
213 | if grad is None:
214 | print(var)
215 |
216 | self.capped_gvs = [(tf.clip_by_value(grad, -5, 5), var) for grad, var in self.grads_and_vars]
217 | self.train_op = optimizer.apply_gradients(self.capped_gvs, global_step=self.global_step)
218 | self.saver = tf.train.Saver(max_to_keep=10)
219 | self.alpha_1 = alpha_1
220 |
221 |
222 | def evaluate(model_path, eval_file, output_path):
223 | with open(eval_file, 'rb') as f:
224 | utterance, response, narrative, labels = pickle.load(f)
225 |
226 | all_candidate_scores = []
227 | utterance, utterance_len = Utils.multi_sequences_padding(utterance, max_sentence_len, max_num_utterance=max_num_utterance)
228 | utterance, utterance_len = np.array(utterance), np.array(utterance_len)
229 | narrative_len = np.array(Utils.get_sequences_length(narrative, maxlen=max_sentence_len))
230 | narrative = np.array(pad_sequences(narrative, padding='post', maxlen=max_sentence_len))
231 | response_len = np.array(Utils.get_sequences_length(response, maxlen=max_sentence_len))
232 | response = np.array(pad_sequences(response, padding='post', maxlen=max_sentence_len))
233 | y_true = np.array(labels)
234 |
235 | dataset = tf.data.Dataset.from_tensor_slices((utterance, utterance_len, narrative, narrative_len, response, response_len, y_true))
236 | dataset = dataset.batch(eval_batch_size)
237 | iterator = dataset.make_initializable_iterator()
238 |
239 | data_iterator = iterator.get_next()
240 |
241 | with open(evaluate_embedding_file, 'rb') as f:
242 | embeddings = pickle.load(f)
243 |
244 | model = ScripteWriter(data_iterator)
245 | model.build()
246 | sess = model.load(model_path)
247 | sess.run(iterator.initializer)
248 | sess.run(model.embedding_init, feed_dict={model.embedding_ph: embeddings})
249 |
250 | test_loss = 0.0
251 | step = 0
252 | try:
253 | with tqdm(total=len(y_true)) as pbar:
254 | while True:
255 | candidate_scores, loss = sess.run([model.y_pred, model.loss])
256 | all_candidate_scores.append(candidate_scores)
257 | test_loss += loss
258 | pbar.update(model.eval_batch_size)
259 | step += 1
260 | except tf.errors.OutOfRangeError:
261 | pass
262 |
263 | all_candidate_scores = np.concatenate(all_candidate_scores, axis=0)
264 | with open(output_path + "test.result.micro_session.txt", "w") as fw:
265 | for sc in all_candidate_scores.tolist():
266 | fw.write(str(sc) + "\n")
267 | return Evaluate.evaluate_all(all_candidate_scores, labels), test_loss / step, all_candidate_scores.tolist()
268 |
269 |
270 | def simple_evaluate(sess, model, iterator, utterance_ph, utterance_len_ph, narrative_ph, narrative_len_ph, response_ph, response_len_ph, y_true_ph, eval_file):
271 | with open(eval_file, 'rb') as f:
272 | utterance, response, narrative, labels = pickle.load(f)
273 |
274 | all_candidate_scores = []
275 | utterance, utterance_len = Utils.multi_sequences_padding(utterance, max_sentence_len, max_num_utterance=max_num_utterance)
276 | utterance, utterance_len = np.array(utterance), np.array(utterance_len)
277 | narrative_len = np.array(Utils.get_sequences_length(narrative, maxlen=max_sentence_len))
278 | narrative = np.array(pad_sequences(narrative, padding='post', maxlen=max_sentence_len))
279 | response_len = np.array(Utils.get_sequences_length(response, maxlen=max_sentence_len))
280 | response = np.array(pad_sequences(response, padding='post', maxlen=max_sentence_len))
281 | y_true = np.array(labels)
282 |
283 | sess.run(iterator.initializer, feed_dict={utterance_ph: utterance,
284 | utterance_len_ph: utterance_len,
285 | narrative_ph: narrative,
286 | narrative_len_ph: narrative_len,
287 | response_ph: response,
288 | response_len_ph: response_len,
289 | y_true_ph: y_true})
290 |
291 | test_loss = 0.0
292 | step = 0
293 | try:
294 | with tqdm(total=len(y_true)) as pbar:
295 | while True:
296 | candidate_scores, loss = sess.run([model.y_pred, model.loss])
297 | all_candidate_scores.append(candidate_scores)
298 | test_loss += loss
299 | pbar.update(eval_batch_size)
300 | step += 1
301 | except tf.errors.OutOfRangeError:
302 | pass
303 |
304 | all_candidate_scores = np.concatenate(all_candidate_scores, axis=0)
305 | return Evaluate.evaluate_all(all_candidate_scores, labels), test_loss / step, all_candidate_scores.tolist()
306 |
307 |
308 | def evaluate_multi_turns(test_file, model_path, output_path):
309 | vocab = {}
310 | vocab_id2word = {}
311 |
312 | utterance_ph = tf.placeholder(tf.int32, shape=(None, max_num_utterance, max_sentence_len))
313 | response_ph = tf.placeholder(tf.int32, shape=(None, max_sentence_len))
314 | y_true_ph = tf.placeholder(tf.int32, shape=(None,))
315 | response_len_ph = tf.placeholder(tf.int32, shape=(None,))
316 | utterance_len_ph = tf.placeholder(tf.int32, shape=(None, max_num_utterance))
317 | narrative_ph = tf.placeholder(tf.int32, shape=(None, max_sentence_len))
318 | narrative_len_ph = tf.placeholder(tf.int32, shape=(None,))
319 |
320 | with open(evaluate_embedding_file, 'rb') as f:
321 | embeddings = pickle.load(f)
322 |
323 | dataset = tf.data.Dataset.from_tensor_slices((utterance_ph, utterance_len_ph, narrative_ph, narrative_len_ph, response_ph, response_len_ph, y_true_ph))
324 | dataset = dataset.batch(eval_batch_size)
325 | iterator = dataset.make_initializable_iterator()
326 |
327 | data_iterator = iterator.get_next()
328 |
329 | model = ScripteWriter(data_iterator)
330 | model.build()
331 | sess = model.load(model_path)
332 | sess.run(model.embedding_init, feed_dict={model.embedding_ph: embeddings})
333 |
334 | with open("./data/vocab.txt", "r", encoding="utf-8") as fr:
335 | for idx, line in enumerate(fr):
336 | line = line.strip().split("\t")
337 | vocab[line[0]] = idx + 1
338 | vocab_id2word[idx + 1] = line[0]
339 | vocab["_PAD_"] = 0
340 | vocab_id2word[0] = "_PAD_"
341 |
342 | def initialize(test_file):
343 | initial_file = output_path + "test.multi.0.pkl"
344 | max_turn = 0
345 | narrative_dict = {}
346 | narrative_dict_score = {}
347 |
348 | with open(test_file, 'rb') as f:
349 | utterance, response, narrative, labels = pickle.load(f)
350 | new_utterance, new_response, new_narrative, new_labels = [], [], [], []
351 | for i in range(len(response)):
352 | ut = utterance[i]
353 | if len(ut) == 1:
354 | o = narrative[i]
355 | r = response[i]
356 | l = labels[i]
357 | new_utterance.append(ut)
358 | new_response.append(r)
359 | new_narrative.append(o)
360 | new_labels.append(l)
361 | if len(ut) > max_turn:
362 | max_turn = len(ut)
363 | o = "".join([vocab_id2word[x] for x in narrative[i]])
364 | if o not in narrative_dict:
365 | narrative_dict[o] = {0: narrative[i]}
366 | narrative_dict_score[o] = {0: [-1]}
367 | r = response[i]
368 | l = labels[i]
369 | if len(ut) in narrative_dict[o]:
370 | narrative_dict[o][len(ut)].append(r)
371 | narrative_dict_score[o][len(ut)].append(l)
372 | else:
373 | narrative_dict[o][len(ut)] = [r]
374 | narrative_dict_score[o][len(ut)] = [l]
375 |
376 | pickle.dump(narrative_dict, open(output_path + "response_candidate.pkl", "wb"))
377 |
378 | new_data = [new_utterance, new_response, new_narrative, new_labels]
379 | pickle.dump(new_data, open(initial_file, "wb"))
380 |
381 | (acc, r2_1, r10_1, r10_2, r10_5, mrr), eva_loss, result = simple_evaluate(sess, model, iterator, utterance_ph, utterance_len_ph, narrative_ph, narrative_len_ph, response_ph, response_len_ph, y_true_ph, initial_file)
382 | with open(output_path + "test.result.multi.0.txt", "w") as fw:
383 | fw.write("R2@1: %f, R10@1: %f, R10@2: %f, R10@5: %f, MRR: %f\n" % (r2_1, r10_1, r10_2, r10_5, mrr))
384 | for r in result:
385 | fw.write(str(r) + "\n")
386 |
387 | return max_turn, narrative_dict, narrative_dict_score
388 |
389 | max_turn, narrative_dict, narrative_dict_score = initialize(test_file)
390 |
391 | for turn in range(1, max_turn):
392 | # for turn in range(1, 2):
393 | score = []
394 | with open(output_path + "test.result.multi." + str(turn - 1) + ".txt", "r") as fr:
395 | for idx, line in enumerate(fr):
396 | if idx == 0:
397 | continue
398 | score.append(float(line.strip()))
399 | with open(output_path + "test.multi." + str(turn - 1) + ".pkl", "rb") as fr:
400 | utterance, response, narrative, labels = pickle.load(fr)
401 |
402 | new_utterance = []
403 | new_response = []
404 | new_narrative = []
405 | new_labels = []
406 |
407 | for i, o in enumerate(narrative):
408 | if i % 10 == 0:
409 | sent_o = "".join([vocab_id2word[x] for x in o])
410 | if turn + 1 in narrative_dict[sent_o]:
411 | new_response.extend(narrative_dict[sent_o][turn + 1])
412 | score_sub_list = score[i:i + 10]
413 | response_sub_list = response[i:i + 10]
414 | max_idx = score_sub_list.index(max(score_sub_list))
415 | selected_response = response_sub_list[max_idx]
416 | for ut in utterance[i:i + 10]:
417 | tmp = ut + [selected_response]
418 | new_utterance.append(tmp)
419 | new_narrative.extend([o] * 10)
420 | new_labels.extend(narrative_dict_score[sent_o][turn + 1])
421 |
422 | new_data = [new_utterance, new_response, new_narrative, new_labels]
423 | new_file = output_path + "test.multi." + str(turn) + ".pkl"
424 | pickle.dump(new_data, open(new_file, "wb"))
425 |
426 | (acc, r2_1, r10_1, r10_2, r10_5, mrr), eva_loss, result = simple_evaluate(sess, model, iterator, utterance_ph, utterance_len_ph, narrative_ph, narrative_len_ph, response_ph, response_len_ph, y_true_ph, new_file)
427 | with open(output_path + "test.result.multi." + str(turn) + ".txt", "w") as fw:
428 | fw.write("R2@1: %f, R10@1: %f, R10@2: %f, R10@5: %f, MRR: %f\n" % (r2_1, r10_1, r10_2, r10_5, mrr))
429 | for r in result:
430 | fw.write(str(r) + "\n")
431 |
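A comment sketch of one iteration of the greedy multi-turn rollout above (descriptive only, mirroring the loop body):

# turn t: read the scores written for turn t-1, then for every group of 10 candidate rows
#   1. pick the response with the highest score and append it to all 10 contexts of that group
#      (narratives that have no candidates for turn t+1 in narrative_dict are skipped)
#   2. fetch the next-turn candidates and their labels from
#      narrative_dict[sent_o][turn + 1] / narrative_dict_score[sent_o][turn + 1]
#   3. dump the new (utterance, response, narrative, labels) lists to test.multi.<turn>.pkl,
#      score them with simple_evaluate, and write test.result.multi.<turn>.txt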
432 |
433 | def train(load=False, model_path=None):
434 | best_val_loss = 100000.0
435 | config = tf.ConfigProto(allow_soft_placement=True)
436 | config.gpu_options.allow_growth = True
437 | epoch = 0
438 | with tf.Session(config=config) as sess:
439 | with open(embedding_file, 'rb') as f:
440 | embeddings = pickle.load(f, encoding="bytes")
441 | with open(train_file, 'rb') as f:
442 | utterance, response, narrative, labels = pickle.load(f)
443 | with open(val_file, "rb") as f:
444 | utterance_val, response_val, narrative_val, labels_val = pickle.load(f)
445 |
446 |         state = np.random.get_state()  # reuse the same RNG state so the four parallel lists are shuffled in the same order
447 | np.random.shuffle(utterance)
448 | np.random.set_state(state)
449 | np.random.shuffle(response)
450 | np.random.set_state(state)
451 | np.random.shuffle(labels)
452 | np.random.set_state(state)
453 | np.random.shuffle(narrative)
454 |
455 | utterance_ph = tf.placeholder(tf.int32, shape=(None, max_num_utterance, max_sentence_len))
456 | response_ph = tf.placeholder(tf.int32, shape=(None, max_sentence_len))
457 | y_true_ph = tf.placeholder(tf.int32, shape=(None,))
458 | response_len_ph = tf.placeholder(tf.int32, shape=(None,))
459 | utterance_len_ph = tf.placeholder(tf.int32, shape=(None, max_num_utterance))
460 | narrative_ph = tf.placeholder(tf.int32, shape=(None, max_sentence_len))
461 | narrative_len_ph = tf.placeholder(tf.int32, shape=(None,))
462 |
463 | utterance_train, utterance_len_train = Utils.multi_sequences_padding(utterance, max_sentence_len, max_num_utterance=max_num_utterance)
464 | utterance_train, utterance_len_train = np.array(utterance_train), np.array(utterance_len_train)
465 | response_len_train = np.array(Utils.get_sequences_length(response, maxlen=max_sentence_len))
466 | response_train = np.array(pad_sequences(response, padding='post', maxlen=max_sentence_len))
467 | narrative_len_train = np.array(Utils.get_sequences_length(narrative, maxlen=max_sentence_len))
468 | narrative_train = np.array(pad_sequences(narrative, padding='post', maxlen=max_sentence_len))
469 | y_true_train = np.array(labels)
470 |
471 | utterance_val, utterance_len_val = Utils.multi_sequences_padding(utterance_val, max_sentence_len, max_num_utterance=max_num_utterance)
472 | utterance_val, utterance_len_val = np.array(utterance_val), np.array(utterance_len_val)
473 | response_len_val = np.array(Utils.get_sequences_length(response_val, maxlen=max_sentence_len))
474 | response_val = np.array(pad_sequences(response_val, padding='post', maxlen=max_sentence_len))
475 | narrative_len_val = np.array(Utils.get_sequences_length(narrative_val, maxlen=max_sentence_len))
476 | narrative_val = np.array(pad_sequences(narrative_val, padding='post', maxlen=max_sentence_len))
477 | y_true_val = np.array(labels_val)
478 |
479 | dataset = tf.data.Dataset.from_tensor_slices((utterance_ph, utterance_len_ph, narrative_ph, narrative_len_ph, response_ph, response_len_ph,
480 | y_true_ph)).shuffle(1000)
481 | dataset = dataset.batch(batch_size)
482 | iterator = dataset.make_initializable_iterator()
483 |
484 | data_iterator = iterator.get_next()
485 |
486 | model = ScripteWriter(data_iterator)
487 | model.build()
488 |
489 | if load:
490 | sess = model.load(model_path)
491 |         else:
492 |             sess.run(tf.global_variables_initializer())
493 | sess.run(model.embedding_init, feed_dict={model.embedding_ph: embeddings})
494 |
495 | while epoch < 8:
496 | train_loss = 0.0
497 | sess.run(iterator.initializer, feed_dict={utterance_ph: utterance_train,
498 | utterance_len_ph: utterance_len_train,
499 | narrative_ph: narrative_train,
500 | narrative_len_ph: narrative_len_train,
501 | response_ph: response_train,
502 | response_len_ph: response_len_train,
503 | y_true_ph: y_true_train})
504 | step = 0
505 | try:
506 | with tqdm(total=len(y_true_train)) as pbar:
507 | while True:
508 | _, loss, lr = sess.run([model.train_op, model.loss, model.learning_rate])
509 | train_loss += loss
510 | pbar.set_postfix(learning_rate=lr, loss=loss)
511 | pbar.update(model.batch_size)
512 | step += 1
513 | except tf.errors.OutOfRangeError:
514 | pass
515 |
516 | val_loss = 0.0
517 | val_step = 0
518 | sess.run(iterator.initializer, feed_dict={utterance_ph: utterance_val,
519 | utterance_len_ph: utterance_len_val,
520 | narrative_ph: narrative_val,
521 | narrative_len_ph: narrative_len_val,
522 | response_ph: response_val,
523 | response_len_ph: response_len_val,
524 | y_true_ph: y_true_val})
525 | try:
526 | while True:
527 | loss = sess.run(model.loss)
528 | val_loss += loss
529 | val_step += 1
530 | except tf.errors.OutOfRangeError:
531 | pass
532 |
533 | print('Epoch No: %d, the train loss is %f, the dev loss is %f' % (epoch + 1, train_loss / step, val_loss / val_step))
534 | if val_loss / val_step < best_val_loss:
535 | best_val_loss = val_loss / val_step
536 | model.saver.save(sess, "./model/model.{0}".format(epoch + 1))
537 | print("Save model.{}".format(epoch + 1))
538 | epoch += 1
539 |
540 | if __name__ == "__main__":
541 | os.environ['CUDA_VISIBLE_DEVICES'] = "0"
542 | is_train = True
543 | previous_train_modelpath = "./model/"
544 | if is_train:
545 | train(False, previous_train_modelpath)
546 | else:
547 |         # check the validation losses reported during training and evaluate with the saved model that has the smallest validation loss
548 | (acc, r2_1, r10_1, r10_2, r10_5, mrr), eva_loss, _ = evaluate(previous_train_modelpath, evaluate_file, output_path="./output/")
549 | print("Loss on test set: %f, Accuracy: %f, R2@1: %f, R10@1: %f, R10@2: %f, R10@5: %f, MRR: %f" % (eva_loss, acc, r2_1, r10_1, r10_2, r10_5, mrr))
550 |
551 | # to evaluate multi-turn results, the vocab file is needed
552 | # evaluate_multi_turns(test_file=evaluate_file, model_path=previous_train_modelpath, output_path="./output/")
553 | # Evaluate.recover_and_show(basic_directory="./output/")
554 | # test_file = basic_directory + "test.result.multi.txt"
555 | # gold_file = "./data/ground_truth.result.mul.txt"
556 | # Evaluate.evaluate_multi_turn_result(test_file, gold_file)
557 |
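For reference, the pickled data files loaded in this script (train_file, val_file, and the evaluation files) all store four parallel lists, matching what the generators in Utils.py dump:

# utterance : one context per row, i.e. a list of utterances (each utterance a list of word ids)
# response  : one candidate response (list of word ids) per row
# narrative : the narrative (list of word ids), repeated for every candidate of a context
# labels    : one 0/1 relevance label per row
# so, for example:
# with open(train_file, "rb") as f:
#     utterance, response, narrative, labels = pickle.load(f)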
--------------------------------------------------------------------------------
/Utils.py:
--------------------------------------------------------------------------------
1 | import concurrent.futures
2 | import pickle
3 | import numpy as np
4 | from keras.preprocessing.sequence import pad_sequences
5 | from keras.preprocessing.text import text_to_word_sequence
6 |
7 |
8 | def multi_sequences_padding(all_sequences, max_sentence_len=50, max_num_utterance=10):
9 | PAD_SEQUENCE = [0] * max_sentence_len
10 | padded_sequences = []
11 | sequences_length = []
12 | for sequences in all_sequences:
13 | sequences_len = len(sequences)
14 | sequences_length.append(get_sequences_length(sequences, maxlen=max_sentence_len))
15 | if sequences_len < max_num_utterance:
16 | sequences += [PAD_SEQUENCE] * (max_num_utterance - sequences_len)
17 | sequences_length[-1] += [0] * (max_num_utterance - sequences_len)
18 | else:
19 | sequences = sequences[-max_num_utterance:]
20 | sequences_length[-1] = sequences_length[-1][-max_num_utterance:]
21 | sequences = pad_sequences(sequences, padding='post', maxlen=max_sentence_len)
22 | padded_sequences.append(sequences)
23 | return padded_sequences, sequences_length
24 |
25 |
26 | def get_sequences_length(sequences, maxlen):
27 | sequences_length = [min(len(sequence), maxlen) for sequence in sequences]
28 | return sequences_length
29 |
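A minimal usage sketch of the two helpers above, with illustrative word ids only (called here as they are inside Utils.py):

# Two contexts: the first has one utterance, the second has two.
contexts = [[[4, 8, 15]], [[16, 23], [42, 7, 9, 1]]]
padded, lengths = multi_sequences_padding(contexts, max_sentence_len=5, max_num_utterance=3)
# padded[0] is a 3 x 5 array: the real utterance padded with zeros plus two all-zero utterances
# lengths == [[3, 0, 0], [2, 4, 0]]  -- true token counts per utterance, capped at max_sentence_len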
30 |
31 | def generate_data_with_random_samples():
32 | # generate negative samples randomly
33 |     # In the training set, each positive sample gets one randomly sampled response as its negative candidate
34 |     # In the development and test sets, each positive sample gets nine randomly sampled responses as negative candidates, plus an "EOS" candidate so the model can learn when to stop (see the layout sketch after this function)
35 | import random
36 | import pickle
37 | vocab = {}
38 | positive_data = []
39 | EOS_ID = 7
40 | with open("./data/sample_vocab.txt", "r", encoding="utf-8") as fr:
41 | for idx, line in enumerate(fr):
42 | line = line.strip().split("\t")
43 | vocab[line[0]] = idx + 1
44 | with open("./data/sample_data.txt", "r", encoding="utf-8") as fr:
45 | tmp = []
46 | for line in fr:
47 | line = line.strip()
48 | if len(line) > 0:
49 | line = line.split("\t")
50 | if line[0] == "narrative":
51 | tmp.append(line[1])
52 |                     elif line[0] == "script":
53 |                         tmp.append(line[1])
54 | else:
55 | narrative = tmp[0]
56 | context = tmp[1:]
57 | narrative_id = [vocab.get(word, 0) for word in narrative.split()]
58 | context_id = [[vocab.get(word, 0) for word in sent.split()] for sent in context]
59 | if len(narrative_id) == 0 or len(context_id) == 0:
60 | continue
61 | data = [context_id, narrative_id, 1]
62 | positive_data.append(data)
63 | tmp = []
64 | random.shuffle(positive_data)
65 | print("all suitable sessions: ", len(positive_data))
66 | train_num = int(len(positive_data) * 0.9)
67 | dev_test_num = int(len(positive_data) * 0.05)
68 | train, dev, test = positive_data[:train_num], positive_data[train_num: train_num + dev_test_num], positive_data[train_num + dev_test_num:]
69 | train_all, dev_all, test_all = [], [], []
70 | for context_id, narrative_id, _ in train:
71 | num_context = len(context_id)
72 | for i in range(1, num_context):
73 | context = context_id[:i]
74 | response = context_id[i]
75 | train_all.append([context, response, narrative_id, 1])
76 | flag = True
77 | while flag:
78 | random_idx = random.randint(0, len(positive_data) - 1)
79 | random_context = positive_data[random_idx][0]
80 | random_idx_2 = random.randint(0, len(random_context) - 1)
81 | random_response = random_context[random_idx_2]
82 | if len(response) != len(random_response):
83 | flag = False
84 | train_all.append([context, random_response, narrative_id, 0])
85 | else:
86 | for idx, wid in enumerate(response):
87 | if wid != random_response[idx]:
88 | flag = False
89 | train_all.append([context, random_response, narrative_id, 0])
90 | break
91 | print(train_all[0], train_all[1])
92 | for context_id, narrative_id, _ in dev:
93 | num_context = len(context_id)
94 | for i in range(1, num_context):
95 | context = context_id[:i]
96 | response = context_id[i]
97 | dev_all.append([context, response, narrative_id, 1])
98 | count = 0
99 | negative_samples = []
100 | while count < 9:
101 | random_idx = random.randint(0, len(positive_data) - 1)
102 | random_context = positive_data[random_idx][0]
103 | random_idx_2 = random.randint(0, len(random_context) - 1)
104 | random_response = random_context[random_idx_2]
105 | if random_response not in negative_samples and random_response != [EOS_ID]:
106 | if len(response) != len(random_response):
107 | dev_all.append([context, random_response, narrative_id, 0])
108 | count += 1
109 | negative_samples.append(random_response)
110 | else:
111 | for idx, wid in enumerate(response):
112 | if wid != random_response[idx]:
113 | dev_all.append([context, random_response, narrative_id, 0])
114 | count += 1
115 | negative_samples.append(random_response)
116 | break
117 | if response == [EOS_ID]:
118 | dev_all.append([context, [EOS_ID], narrative_id, 1])
119 | else:
120 | dev_all.append([context, [EOS_ID], narrative_id, 0])
121 | print(dev_all[0], dev_all[1], dev_all[2])
122 | for context_id, narrative_id, _ in test:
123 | num_context = len(context_id)
124 | for i in range(1, num_context):
125 | context = context_id[:i]
126 | response = context_id[i]
127 | test_all.append([context, response, narrative_id, 1])
128 | count = 0
129 | negative_samples = []
130 | while count < 9:
131 | random_idx = random.randint(0, len(positive_data) - 1)
132 | random_context = positive_data[random_idx][0]
133 | random_idx_2 = random.randint(0, len(random_context) - 1)
134 | random_response = random_context[random_idx_2]
135 | if random_response not in negative_samples and random_response != [EOS_ID]:
136 | if len(response) != len(random_response):
137 | test_all.append([context, random_response, narrative_id, 0])
138 | negative_samples.append(random_response)
139 | count += 1
140 | else:
141 | for idx, id in enumerate(response):
142 | if id != random_response[idx]:
143 | test_all.append([context, random_response, narrative_id, 0])
144 | negative_samples.append(random_response)
145 | count += 1
146 | break
147 | if response == [EOS_ID]:
148 | test_all.append([context, [EOS_ID], narrative_id, 1])
149 | else:
150 | test_all.append([context, [EOS_ID], narrative_id, 0])
151 | print(test_all[0], test_all[1], test_all[2])
152 | context, response, narrative, label = [], [], [], []
153 | print("train size: ", len(train_all))
154 | for data in train_all:
155 | context.append(data[0])
156 | response.append(data[1])
157 | narrative.append(data[2])
158 | label.append(data[3])
159 | train = [context, response, narrative, label]
160 | pickle.dump(train, open("./data/train.multi.pkl", "wb"))
161 | context, response, narrative, label = [], [], [], []
162 | print("dev size: ", len(dev_all))
163 | for data in dev_all:
164 | context.append(data[0])
165 | response.append(data[1])
166 | narrative.append(data[2])
167 | label.append(data[3])
168 | dev = [context, response, narrative, label]
169 | pickle.dump(dev, open("./data/dev.multi.pkl", "wb"))
170 | context, response, narrative, label = [], [], [], []
171 | print("test size: ", len(test_all))
172 | for data in test_all:
173 | context.append(data[0])
174 | response.append(data[1])
175 | narrative.append(data[2])
176 | label.append(data[3])
177 | test = [context, response, narrative, label]
178 | pickle.dump(test, open("./data/test.multi.pkl", "wb"))
179 |
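A sketch of the group layout produced above (descriptive only):

# train.multi.pkl : rows come in pairs per context step
#   [context, gold_response,   narrative, 1]
#   [context, random_response, narrative, 0]
# dev.multi.pkl / test.multi.pkl : rows come in groups of 11 per context step
#   row 0     [context, gold_response,     narrative, 1]
#   rows 1-9  [context, random_response_k, narrative, 0]   # nine distinct sampled negatives
#   row 10    [context, [EOS_ID],          narrative, 1 if the gold response is EOS else 0]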
180 |
181 | def generate_data_with_solr_samples():
182 |     # generate negative samples retrieved from Solr
183 |     # this is only for the development and test sets, since the training set uses a single random negative per sample
184 | import pickle
185 | import pysolr
186 | import jieba
187 |
188 | EOS_ID = 7
189 |
190 | def query_comt(post, num):
191 |         # format of the indexed Solr documents: "ut1: xxxxx, ut2: xxxxx", where ut1 is the queried line and ut2 is the paired line used as a negative candidate
192 |         solr = pysolr.Solr('xxxxxx', timeout=10)  # write your Solr address
193 |         post = "ut1:(" + post + ")"
194 |         results = solr.search(q=post, **{'rows': num})  # 'rows' is the number of pairs to retrieve
195 | return results
196 |
197 | vocab = {}
198 | vocab_id2word = {}
199 |
200 | with open("./data/vocab.txt", "r", encoding="utf-8") as fr:
201 | for idx, line in enumerate(fr):
202 | line = line.strip().split("\t")
203 | vocab[line[0]] = idx + 1
204 | vocab_id2word[idx + 1] = line[0]
205 |
206 | dev = pickle.load(open("./data/dev.multi.pkl", "rb"))
207 | context, response, narrative, label = dev[0], dev[1], dev[2], dev[3]
208 | num = len(response)
209 | dev_all = []
210 | for i in range(num):
211 |         # Each group in dev.multi.pkl holds one positive sample, nine negative samples, and an "EOS" sample: 1 + 9 + 1 = 11 (see the layout sketch after this function)
212 | if i % 11 == 0 and int(label[i]) == 1:
213 | count = 0
214 | context_ = context[i]
215 | pos_response = "".join([vocab_id2word[x] for x in response[i]])
216 | last_ut = "".join([vocab_id2word[x] for x in context_[-1]]).replace(".", "").replace("?", "").replace("\"", "").replace(":", "")
217 | dev_all.append([context[i], response[i], narrative[i], 1])
218 | negative_samples = query_comt(last_ut, 15)
219 | for result in negative_samples:
220 | if result['ut2'] != pos_response:
221 | negtive_sample = [vocab[x] for x in jieba.lcut(result['ut2']) if x != ' ' and x != '\xa0' and x != '\u3000']
222 | dev_all.append([context[i], negtive_sample, narrative[i], 0])
223 | count += 1
224 | if count == 8:
225 | break
226 | if count != 8:
227 | last = 8 - count
228 | for j in range(last):
229 | negtive_sample = response[i + 1 + j]
230 | dev_all.append([context[i], negtive_sample, narrative[i], 0])
231 | if response[i] == [EOS_ID]:
232 | dev_all.append([context[i], [EOS_ID], narrative[i], 1])
233 | else:
234 | dev_all.append([context[i], [EOS_ID], narrative[i], 0])
235 |
236 | test = pickle.load(open("./data/test.multi.pkl", "rb"))
237 | context, response, narrative, label = test[0], test[1], test[2], test[3]
238 | num = len(response)
239 | test_all = []
240 | print("start test")
241 | for i in range(num):
242 | if i % 11 == 0 and int(label[i]) == 1:
243 | count = 0
244 | context_ = context[i]
245 | pos_response = "".join([vocab_id2word[x] for x in response[i]])
246 | last_ut = "".join([vocab_id2word[x] for x in context_[-1]]).replace(".", "").replace("?", "").replace("\"", "").replace(":", "")
247 | test_all.append([context[i], response[i], narrative[i], 1])
248 | negative_samples = query_comt(last_ut, 15)
249 | for result in negative_samples:
250 | if result['ut2'] != pos_response:
251 | negtive_sample = [vocab[x] for x in jieba.lcut(result['ut2']) if x != ' ' and x != '\xa0' and x != '\u3000']
252 | test_all.append([context[i], negtive_sample, narrative[i], 0])
253 | count += 1
254 | if count == 8:
255 | break
256 | if count != 8:
257 | last = 8 - count
258 | for j in range(last):
259 | negtive_sample = response[i + 1 + j]
260 | test_all.append([context[i], negtive_sample, narrative[i], 0])
261 | if response[i] == [EOS_ID]:
262 | test_all.append([context[i], [EOS_ID], narrative[i], 1])
263 | else:
264 | test_all.append([context[i], [EOS_ID], narrative[i], 0])
265 |
266 | context, response, narrative, label = [], [], [], []
267 | print("dev size: ", len(dev_all))
268 | for data in dev_all:
269 | context.append(data[0])
270 | response.append(data[1])
271 | narrative.append(data[2])
272 | label.append(data[3])
273 | dev = [context, response, narrative, label]
274 | pickle.dump(dev, open("./data/dev.pkl", "wb"))
275 | context, response, narrative, label = [], [], [], []
276 | print("test size: ", len(test_all))
277 | for data in test_all:
278 | context.append(data[0])
279 | response.append(data[1])
280 | narrative.append(data[2])
281 | label.append(data[3])
282 | test = [context, response, narrative, label]
283 | pickle.dump(test, open("./data/test.pkl", "wb"))
284 |
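And the corresponding layout written by generate_data_with_solr_samples (descriptive only):

# dev.pkl / test.pkl : rows come in groups of 10 per context step
#   row 0     the positive response                                            label 1
#   rows 1-8  negatives retrieved via Solr, topped up with the input file's    label 0
#             random negatives when fewer than 8 usable results are returned
#   row 9     the [EOS_ID] candidate                                           label 1 iff the positive response is EOS
# This stride of 10 matches the groups of 10 assumed in Evaluate.py and in evaluate_multi_turns.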
--------------------------------------------------------------------------------
/data/readme.md:
--------------------------------------------------------------------------------
1 | Unzip data.zip to this directory.
2 |
--------------------------------------------------------------------------------
/data/sample_data.txt:
--------------------------------------------------------------------------------
1 | narrative 珍妮不喜欢回家。为了陪珍妮,甘决定晚点回家。珍妮是甘最好的朋友
2 | script 妈会担心我的
3 | script 再坐一会!
4 | script 珍妮为了某种理由不想回家
5 | script 好,珍妮,我留下来
6 | script 她是我最特别的朋友
7 | script EOS
8 |
9 | narrative 珍妮不喜欢回家。为了陪珍妮,甘决定晚点回家。珍妮是甘最好的朋友
10 | script 妈会担心我的
11 | script 再坐一会!
12 | script 珍妮为了某种理由不想回家
13 | script 好,珍妮,我留下来
14 | script 她是我最特别的朋友
15 | script EOS
16 |
17 | narrative 珍妮不喜欢回家。为了陪珍妮,甘决定晚点回家。珍妮是甘最好的朋友
18 | script 妈会担心我的
19 | script 再坐一会!
20 | script 珍妮为了某种理由不想回家
21 | script 好,珍妮,我留下来
22 | script 她是我最特别的朋友
23 | script EOS
24 |
25 | narrative 珍妮不喜欢回家。为了陪珍妮,甘决定晚点回家。珍妮是甘最好的朋友
26 | script 妈会担心我的
27 | script 再坐一会!
28 | script 珍妮为了某种理由不想回家
29 | script 好,珍妮,我留下来
30 | script 她是我最特别的朋友
31 | script EOS
32 |
33 | narrative 珍妮不喜欢回家。为了陪珍妮,甘决定晚点回家。珍妮是甘最好的朋友
34 | script 妈会担心我的
35 | script 再坐一会!
36 | script 珍妮为了某种理由不想回家
37 | script 好,珍妮,我留下来
38 | script 她是我最特别的朋友
39 | script EOS
40 |
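For readers who do not read Chinese, an approximate English gloss of the session repeated above (a reference translation only, not part of the dataset):

narrative  Jenny doesn't like going home. To keep Jenny company, Gump decides to go home later. Jenny is Gump's best friend.
script     Mom will worry about me
script     Stay a little longer!
script     For some reason Jenny didn't want to go home
script     Okay, Jenny, I'll stay
script     She was my most special friend
script     EOS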
--------------------------------------------------------------------------------
/data/sample_data_english.txt:
--------------------------------------------------------------------------------
1 | narrative xxxxxxxxxxxxxxx
2 | script xxxxxxxxxxxx
3 | script xxxxxxxxxxxx
4 | script xxxxxxxxxxxx
5 | script xxxxxxxxxxxx
6 | script xxxxxxxxxxxx
7 | script xxxxxxxxxxxx
8 |
9 | narrative xxxxxxxxxxxxxxx
10 | script xxxxxxxxxxxx
11 | script xxxxxxxxxxxx
12 | script xxxxxxxxxxxx
13 | script xxxxxxxxxxxx
14 | script xxxxxxxxxxxx
15 | script xxxxxxxxxxxx
16 |
17 | narrative xxxxxxxxxxxxxxx
18 | script xxxxxxxxxxxx
19 | script xxxxxxxxxxxx
20 | script xxxxxxxxxxxx
21 | script xxxxxxxxxxxx
22 | script xxxxxxxxxxxx
23 | script xxxxxxxxxxxx
24 |
25 | narrative xxxxxxxxxxxxxxx
26 | script xxxxxxxxxxxx
27 | script xxxxxxxxxxxx
28 | script xxxxxxxxxxxx
29 | script xxxxxxxxxxxx
30 | script xxxxxxxxxxxx
31 | script xxxxxxxxxxxx
--------------------------------------------------------------------------------
/data/sample_vocab.txt:
--------------------------------------------------------------------------------
1 | , 50435
2 | 的 37684
3 | 。 27141
4 | 了 22941
5 | 我 21728
6 | 你 20009
7 | EOS 16109
8 | 他 12878
9 | 是 12859
10 | ? 12549
11 |
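For reference, the vocabulary loaders in ScriptWriter.py and Utils.py assign ids by line order, with id 0 reserved for the padding token, so for the sample file above:

# line 1 -> id 1, line 2 -> id 2, ... ; id 0 maps to "_PAD_"
# e.g. vocab[","] == 1, vocab["的"] == 2, vocab["EOS"] == 7, consistent with EOS_ID = 7 in Utils.py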
--------------------------------------------------------------------------------
/data/sample_vocab_english.txt:
--------------------------------------------------------------------------------
1 | , 50435
2 | I 37684
3 | . 27141
4 | you 22941
5 | my 21728
6 | yours 20009
7 | EOS 16109
8 | he 12878
9 | is 12859
10 | ? 12549
--------------------------------------------------------------------------------
/image/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DaoD/ScriptWriter/89808c9c8d3c7e4114643209c44c5e37602781d1/image/example.png
--------------------------------------------------------------------------------
/image/matching.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DaoD/ScriptWriter/89808c9c8d3c7e4114643209c44c5e37602781d1/image/matching.png
--------------------------------------------------------------------------------
/image/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DaoD/ScriptWriter/89808c9c8d3c7e4114643209c44c5e37602781d1/image/result.png
--------------------------------------------------------------------------------
/image/statistics.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DaoD/ScriptWriter/89808c9c8d3c7e4114643209c44c5e37602781d1/image/statistics.png
--------------------------------------------------------------------------------
/image/update.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DaoD/ScriptWriter/89808c9c8d3c7e4114643209c44c5e37602781d1/image/update.png
--------------------------------------------------------------------------------
/model/readme.md:
--------------------------------------------------------------------------------
1 | The trained model is saved here.
2 |
--------------------------------------------------------------------------------
/modules.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 |
5 | def normalize(inputs, epsilon=1e-8, scope="normalize", reuse=None):
6 | with tf.variable_scope(scope, reuse=reuse):
7 | inputs_shape = inputs.get_shape()
8 | params_shape = inputs_shape[-1:]
9 | mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True)
10 | beta = tf.get_variable(name='beta', shape=params_shape, dtype=tf.float32, initializer=tf.zeros_initializer())
11 | gamma = tf.get_variable(name='scale', shape=params_shape, dtype=tf.float32, initializer=tf.ones_initializer())
12 | normalized = (inputs - mean) / ((variance + epsilon) ** .5)
13 | outputs = gamma * normalized + beta
14 | return outputs
15 |
16 |
17 | def embedding(inputs, vocab_size=None, embedding_size=None, zero_pad=False, scale=False, scope="embedding", reuse=None, initializer=None):
18 | with tf.variable_scope(scope, reuse=reuse):
19 | if initializer:
20 | lookup_table = initializer
21 | else:
22 | lookup_table = tf.get_variable('lookup_table', dtype=tf.float32, shape=[vocab_size, embedding_size], initializer=tf.contrib.layers.xavier_initializer())
23 | if zero_pad:
24 | lookup_table = tf.concat((tf.zeros(shape=[1, embedding_size]), lookup_table[1:, :]), 0)
25 | outputs = tf.nn.embedding_lookup(lookup_table, inputs)
26 | if scale:
27 | outputs = outputs * (embedding_size ** 0.5)
28 | return outputs
29 |
30 |
31 | def positional_encoding(inputs, num_units, max_len, zero_pad=True, scale=True, scope="positional_encoding", reuse=None):
32 | # N, T = inputs.get_shape().as_list()
33 | # N, T = tf.shape(inputs)
34 | inputs_shape = tf.shape(inputs)
35 | N = inputs_shape[0]
36 | T_real = inputs_shape[1]
37 | T = max_len
38 | with tf.variable_scope(scope, reuse=reuse, dtype=tf.float32):
39 | position_ind = tf.tile(tf.expand_dims(tf.range(T_real), 0), [N, 1])
40 | position_enc = np.array([[pos / np.power(10000, 2 * (i // 2) /num_units) for i in range(num_units)] for pos in range(T)], dtype=np.float32)
41 | position_enc[:, 0::2] = np.sin(position_enc[:, 0::2])
42 | position_enc[:, 1::2] = np.cos(position_enc[:, 1::2])
43 | lookup_table = tf.convert_to_tensor(position_enc)
44 | if zero_pad:
45 | lookup_table = tf.concat((tf.zeros(shape=[1, num_units]), lookup_table[1:, :]), 0)
46 | outputs = tf.nn.embedding_lookup(lookup_table, position_ind)
47 | if scale:
48 | outputs = outputs * (num_units ** 0.5)
49 | return outputs
50 |
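For reference, the lookup table built above follows the standard sinusoidal encoding, with $d$ = num_units:

$$PE_{(pos,\,2i)} = \sin\!\left(pos / 10000^{2i/d}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(pos / 10000^{2i/d}\right)$$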
51 |
52 | def multihead_attention(queries, keys, num_units=None, num_heads=8, dropout_rate=0, is_training=True, causality=False, scope="multihead_attention", reuse=None):
53 | with tf.variable_scope(scope, reuse=reuse):
54 | if num_units is None:
55 | num_units = queries.get_shape().as_list()[-1]
56 | Q = tf.layers.dense(queries, num_units, activation=tf.nn.relu) # (N, T_q, C)
57 | K = tf.layers.dense(keys, num_units, activation=tf.nn.relu) # (N, T_k, C)
58 | V = tf.layers.dense(keys, num_units, activation=tf.nn.relu) # (N, T_k, C)
59 |
60 | Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0) # (h*N, T_q, C/h)
61 | K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0) # (h*N, T_k, C/h)
62 | V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0) # (h*N, T_k, C/h)
63 |
64 | outputs = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1])) # (h*N, T_q, T_k)
65 | outputs = outputs / (K_.get_shape().as_list()[-1] ** 0.5)
66 | key_masks = tf.sign(tf.abs(tf.reduce_sum(keys, axis=-1))) # (N, T_k)
67 | key_masks = tf.tile(key_masks, [num_heads, 1]) # (h*N, T_k)
68 | key_masks = tf.tile(tf.expand_dims(key_masks, 1), [1, tf.shape(queries)[1], 1]) # (h*N, T_q, T_k)
69 | paddings = tf.ones_like(outputs) * (-2 ** 32 + 1)
70 | outputs = tf.where(tf.equal(key_masks, 0), paddings, outputs) # (h*N, T_q, T_k)
71 | if causality:
72 | diag_vals = tf.ones_like(outputs[0, :, :]) # (T_q, T_k)
73 | tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense() # (T_q, T_k)
74 | masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(outputs)[0], 1, 1]) # (h*N, T_q, T_k)
75 | paddings = tf.ones_like(masks) * (-2 ** 32 + 1)
76 | outputs = tf.where(tf.equal(masks, 0), paddings, outputs) # (h*N, T_q, T_k)
77 | outputs = tf.nn.softmax(outputs) # (h*N, T_q, T_k)
78 | attn_weights = outputs
79 | query_masks = tf.sign(tf.abs(tf.reduce_sum(queries, axis=-1))) # (N, T_q)
80 | query_masks = tf.tile(query_masks, [num_heads, 1]) # (h*N, T_q)
81 | query_masks = tf.tile(tf.expand_dims(query_masks, -1), [1, 1, tf.shape(keys)[1]]) # (h*N, T_q, T_k)
82 |         outputs *= query_masks  # zero out rows for padded query positions. (h*N, T_q, T_k)
83 | outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=tf.convert_to_tensor(is_training))
84 | outputs = tf.matmul(outputs, V_) # ( h*N, T_q, C/h)
85 | outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2) # (N, T_q, C)
86 | outputs += queries
87 | outputs = normalize(outputs, scope=scope) # (N, T_q, C)
88 |
89 | return outputs, attn_weights
90 |
91 |
92 | def feedforward(inputs, num_units=[2048, 512], scope="feed_forward", reuse=None):
93 | with tf.variable_scope(scope, reuse=reuse):
94 | params = {"inputs": inputs, "filters": num_units[0], "kernel_size": 1, "activation": tf.nn.relu, "use_bias": True}
95 | outputs = tf.layers.conv1d(**params)
96 | params = {"inputs": outputs, "filters": num_units[1], "kernel_size": 1, "activation": None, "use_bias": True}
97 | outputs = tf.layers.conv1d(**params)
98 | outputs += inputs
99 | outputs = normalize(outputs, scope=scope)
100 | return outputs
101 |
102 |
103 | def label_smoothing(inputs, epsilon=0.1):
104 | K = inputs.get_shape().as_list()[-1] # number of channels
105 | return ((1 - epsilon) * inputs) + (epsilon / K)
106 |
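A quick worked example of the smoothing above: with epsilon = 0.1 and K = 2 channels, a one-hot target [1.0, 0.0] becomes 0.9 * [1.0, 0.0] + 0.1 / 2 = [0.95, 0.05].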
107 |
--------------------------------------------------------------------------------
/output/readme.md:
--------------------------------------------------------------------------------
1 | The output is saved here.
2 |
--------------------------------------------------------------------------------