├── README.md
├── data
│   └── data_preprocess.py
├── image
│   ├── model.png
│   └── result.png
├── model
│   ├── __init__.py
│   ├── data_helpers.py
│   ├── eval.py
│   ├── metrics.py
│   ├── model_DIM.py
│   └── train.py
└── scripts
    ├── compute_recall.py
    ├── test.sh
    └── train.sh
/README.md:
--------------------------------------------------------------------------------
1 | # Dually Interactive Matching Network for Personalized Response Selection in Retrieval-Based Chatbots
2 | This repository contains the source code and dataset for the EMNLP 2019 paper [Dually Interactive Matching Network for Personalized Response Selection in Retrieval-Based Chatbots](https://www.aclweb.org/anthology/D19-1193.pdf) by Gu et al.
3 |
4 | Our proposed Dually Interactive Matching Network (DIM) achieves new state-of-the-art performance for response selection on the PERSONA-CHAT dataset.
5 |
6 | ## Model overview
7 | ![model](image/model.png)
8 |
9 | ## Results
10 | ![result](image/result.png)
11 |
12 | ## Dependencies
13 | - Python 2.7
14 | - TensorFlow 1.4.0
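
For example, the TensorFlow dependency can be installed with pip (assuming a Python 2.7 environment):
```
pip install tensorflow==1.4.0
```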
15 |
16 | ## Dataset
17 | You can download the PERSONA-CHAT dataset [here](https://drive.google.com/open?id=1gNyVL5pSMO6DnTIlA9ORNIrd2zm8f3QH) or from [ParlAI](https://parl.ai/), and unzip it into the ```data``` folder.
18 | Run the following commands; the processed files will be stored in ```data/personachat_processed/```.
19 | ```
20 | cd data
21 | python data_preprocess.py
22 | ```
23 | Then, download the embedding and vocab files [here](https://drive.google.com/open?id=1gGZfQ-m7EGo5Z1Ts93Ta8GPJpdIQqckC) and unzip them into the ```data/personachat_processed/``` folder.
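
Each line of the processed files holds five tab-separated fields: the ```_eos_```-delimited context, the ```|```-joined response candidates, the label (the index of the ground-truth response among the candidates), the partner's persona, and your own persona (```NA``` when a side is absent). A minimal parsing sketch, with the field layout taken from ```data/data_preprocess.py``` (the file name below is an assumption):
```
# Sketch: parse one line of a processed PERSONA-CHAT file.
with open("data/personachat_processed/processed_train_self_original.txt") as f:
    line = f.readline().rstrip("\n")

context, candidates, label, partner_persona, self_persona = line.split("\t")
utterances = (context + " ").split(" _eos_ ")[:-1]  # same split as in model/data_helpers.py
candidates = candidates.split("|")                  # shuffled response candidates
label = int(label)                                  # index of the ground-truth response
```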
24 |
25 | ## Train a new model
26 | ```
27 | cd scripts
28 | bash train.sh
29 | ```
30 | The training process is recorded in the ```log_DIM_train.txt``` file.
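
```train.sh``` passes the data paths and hyperparameters to ```model/train.py``` as TensorFlow flags. As a rough sketch of such an invocation (only the four flags shown are taken from the definitions in ```model/train.py```; the file names are assumptions, so check ```train.sh``` for the exact arguments):
```
cd ..    # run from the repository root so the model package is importable
python -u -m model.train \
    --train_file data/personachat_processed/processed_train_self_original.txt \
    --valid_file data/personachat_processed/processed_valid_self_original.txt \
    --vocab_file data/personachat_processed/vocab.txt \
    --char_vocab_file data/personachat_processed/char_vocab.txt \
    > scripts/log_DIM_train.txt 2>&1
```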
31 |
32 | ## Test a trained model
33 | ```
34 | bash test.sh
35 | ```
36 | The testing process is recorded in the ```log_DIM_test.txt``` file. You will also get a ```persona_test_out.txt``` file, which records the score for each context-response pair. Run the following command to compute the Recall metric.
37 | ```
38 | python compute_recall.py
39 | ```
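
The source of ```compute_recall.py``` is not shown here; below is a minimal illustrative sketch of a Recall@k computation over ```persona_test_out.txt```, assuming the file follows the ```query_id  document_id  score  rank  relevance``` format written by ```model/eval.py``` (this is not necessarily the actual script):
```
from collections import defaultdict

# Group the scored candidates by query, then check whether the
# relevant response is ranked within the top k for each query.
results = defaultdict(list)
with open("persona_test_out.txt") as f:
    next(f)  # skip the header line
    for line in f:
        query_id, doc_id, score, rank, relevance = line.rstrip("\n").split("\t")
        results[query_id].append((float(score), float(relevance)))

for k in (1, 2, 5):
    hits = sum(any(rel > 0 for _, rel in sorted(cands, reverse=True)[:k])
               for cands in results.values())
    print("Recall@{}: {:.4f}".format(k, hits / float(len(results))))
```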
40 |
41 | ## Cite
42 | If you use the code, please cite the following paper:
43 | **"Dually Interactive Matching Network for Personalized Response Selection in Retrieval-Based Chatbots"**
44 | Jia-Chen Gu, Zhen-Hua Ling, Xiaodan Zhu, Quan Liu. _EMNLP (2019)_
45 |
46 | ```
47 | @inproceedings{gu-etal-2019-dually,
48 | title = "Dually Interactive Matching Network for Personalized Response Selection in Retrieval-Based Chatbots",
49 | author = "Gu, Jia-Chen and
50 | Ling, Zhen-Hua and
51 | Zhu, Xiaodan and
52 | Liu, Quan",
53 | booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)",
54 | month = nov,
55 | year = "2019",
56 | address = "Hong Kong, China",
57 | publisher = "Association for Computational Linguistics",
58 | url = "https://www.aclweb.org/anthology/D19-1193",
59 | pages = "1845--1854",
60 | }
61 | ```
62 |
--------------------------------------------------------------------------------
/data/data_preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | from nltk.tokenize import WordPunctTokenizer
5 |
6 |
7 | def tokenize(text):
8 | return WordPunctTokenizer().tokenize(text)
9 |
10 |
11 | def data_process_none(input_path, output_path, fname):
12 |
13 | dialogues = []
14 | dialogue = []
15 | with open(os.path.join(input_path, fname), "r") as f:
16 | for line in f:
17 | line = line.decode('utf-8').strip()
18 | if line.split()[0] == "1": # new dialogue
19 | dialogues.append(dialogue)
20 | dialogue = []
21 | dialogue.append(line)
22 |
23 | dialogues.append(dialogue)
24 | dialogues.remove([])
25 | print("{} is composed of {} dialogues".format(fname, len(dialogues)))
26 |
27 | context_candidates = []
28 | for dialogue in dialogues:
29 | context_history = []
30 | for turn in dialogue:
31 | fields = turn.split("\t")
32 | context = " ".join(tokenize(fields[0])[1:])
33 | response = fields[1]
34 | candidates = fields[-1].split("|")
35 | random.shuffle(candidates)
36 | label = candidates.index(response)
37 |
38 | context_history.append(context)
39 | # (context, candidates, label, partner's persona, your persona)
40 | context_candidates.append( [" _eos_ ".join(context_history) + " _eos_",
41 | "|".join(candidates),
42 | str(label),
43 | "NA",
44 | "NA"] )
45 | context_history.append(response)
46 |
47 | print("{} is composed of {} context-candidates".format(fname, len(context_candidates)))
48 |
49 | with open(os.path.join(output_path, "processed_{}".format(fname)), "w") as f:
50 | print("Saving dataset to processed_{} ...".format(fname))
51 | for dialogue in context_candidates:
52 | f.write(("\t".join(dialogue) + "\n").encode('utf-8'))
53 |
54 |
55 | def data_process_self(input_path, output_path, fname):
56 |
57 | dialogues = []
58 | dialogue = []
59 | with open(os.path.join(input_path, fname), "r") as f:
60 | for line in f:
61 | line = line.decode('utf-8').strip()
62 | if line.split()[0] == "1": # new dialogue
63 | dialogues.append(dialogue)
64 | dialogue = []
65 | dialogue.append(line)
66 |
67 | dialogues.append(dialogue)
68 | dialogues.remove([])
69 | print("{} is composed of {} dialogues".format(fname, len(dialogues)))
70 |
71 | context_candidates = []
72 | for dialogue in dialogues:
73 | persona = []
74 | context_history = []
75 | for line in dialogue:
76 | fields = line.strip().split("\t")
77 |
78 | if len(fields) == 1:
79 | persona.append((" ").join(tokenize(fields[0])[4:]))
80 | if len(fields) == 4:
81 | context = " ".join(tokenize(fields[0])[1:])
82 | response = fields[1]
83 | candidates = fields[-1].split("|")
84 | random.shuffle(candidates)
85 | label = candidates.index(response)
86 |
87 | context_history.append(context)
88 | # (context, candidates, label, partner's persona, your persona)
89 | context_candidates.append( [" _eos_ ".join(context_history) + " _eos_",
90 | "|".join(candidates),
91 | str(label),
92 | "NA",
93 | "|".join(persona)] )
94 | context_history.append(response)
95 | print("{} is composed of {} context-candidates".format(fname, len(context_candidates)))
96 |
97 | with open(os.path.join(output_path, "processed_{}".format(fname)), "w") as f:
98 | print("Saving dataset to processed_{} ...".format(fname))
99 | for dialogue in context_candidates:
100 | f.write(("\t".join(dialogue) + "\n").encode('utf-8'))
101 |
102 |
103 | def data_process_other(input_path, output_path, fname):
104 |
105 | dialogues = []
106 | dialogue = []
107 | with open(os.path.join(input_path, fname), "r") as f:
108 | for line in f:
109 | line = line.decode('utf-8').strip()
110 | if line.split()[0] == "1": # new dialogue
111 | dialogues.append(dialogue)
112 | dialogue = []
113 | dialogue.append(line)
114 |
115 | dialogues.append(dialogue)
116 | dialogues.remove([])
117 | print("{} is composed of {} dialogues".format(fname, len(dialogues)))
118 |
119 | context_candidates = []
120 | for dialogue in dialogues:
121 | persona = []
122 | context_history = []
123 | for line in dialogue:
124 | fields = line.strip().split("\t")
125 |
126 | if len(fields) == 1:
127 | persona.append((" ").join(tokenize(fields[0])[6:]))
128 | if len(fields) == 4:
129 | context = " ".join(tokenize(fields[0])[1:])
130 | response = fields[1]
131 | candidates = fields[-1].split("|")
132 | random.shuffle(candidates)
133 | label = candidates.index(response)
134 |
135 | context_history.append(context)
136 | # (context, candidates, label, partner's persona, your persona)
137 | context_candidates.append( [" _eos_ ".join(context_history) + " _eos_",
138 | "|".join(candidates),
139 | str(label),
140 | "|".join(persona),
141 | "NA"] )
142 | context_history.append(response)
143 | print("{} is composed of {} context-candidates".format(fname, len(context_candidates)))
144 |
145 | with open(os.path.join(output_path, "processed_{}".format(fname)), "w") as f:
146 | print("Saving dataset to processed_{} ...".format(fname))
147 | for dialogue in context_candidates:
148 | f.write(("\t".join(dialogue) + "\n").encode('utf-8'))
149 |
150 |
151 | def data_process_both(input_path, output_path, fname):
152 |
153 | dialogues = []
154 | dialogue = []
155 | with open(os.path.join(input_path, fname), "r") as f:
156 | for line in f:
157 | line = line.decode('utf-8').strip()
158 | if line.split()[0] == "1": # new dialogue
159 | dialogues.append(dialogue)
160 | dialogue = []
161 | dialogue.append(line)
162 |
163 | dialogues.append(dialogue)
164 | dialogues.remove([])
165 | print("{} is composed of {} dialogues".format(fname, len(dialogues)))
166 |
167 | context_candidates = []
168 | for dialogue in dialogues:
169 | self_persona = []
170 | other_persona = []
171 | context_history = []
172 | for line in dialogue:
173 | fields = line.strip().split("\t")
174 |
175 | if len(fields) == 1:
176 | if fields[0].split()[1] == "your":
177 | self_persona.append((" ").join(tokenize(fields[0])[4:]))
178 | if fields[0].split()[1] == "partner's":
179 | other_persona.append((" ").join(tokenize(fields[0])[6:]))
180 | if len(fields) == 4:
181 | context = " ".join(tokenize(fields[0])[1:])
182 | response = fields[1]
183 | candidates = fields[-1].split("|")
184 | random.shuffle(candidates)
185 | label = candidates.index(response)
186 |
187 | context_history.append(context)
188 | # (context, candidates, label, partner's persona, your persona)
189 | context_candidates.append( [" _eos_ ".join(context_history) + " _eos_",
190 | "|".join(candidates),
191 | str(label),
192 | "|".join(other_persona),
193 | "|".join(self_persons)] )
194 | context_history.append(response)
195 | print("{} is composed of {} context-candidates".format(fname, len(context_candidates)))
196 |
197 | with open(os.path.join(output_path, "processed_{}".format(fname)), "w") as f:
198 | print("Saving dataset to processed_{} ...".format(fname))
199 | for dialogue in context_candidates:
200 | f.write(("\t".join(dialogue) + "\n").encode('utf-8'))
201 |
202 | if __name__ == '__main__':
203 |
204 | input_path = "./personachat"
205 | output_path = "./personachat_processed"
206 |
207 | if not os.path.exists(output_path):
208 | os.makedirs(output_path)
209 |
210 | files = os.listdir(input_path)
211 | files_none = [file for file in files if file.split("_")[1] == "none"]
212 | files_self = [file for file in files if file.split("_")[1] == "self"]
213 | files_other = [file for file in files if file.split("_")[1] == "other"]
214 | files_both = [file for file in files if file.split("_")[1] == "both"]
215 |
216 | print("There are {} files to process.\nStart processing data ...".format(len(files)))
217 |
218 | for file in files_none:
219 | print("Preprocessing {} ...".format(file))
220 | data_process_none(input_path, output_path, file)
221 | print("="*60)
222 |
223 | for file in files_self:
224 | print("Preprocessing {} ...".format(file))
225 | data_process_self(input_path, output_path, file)
226 | print("="*60)
227 |
228 | for file in files_other:
229 | print("Preprocessing {} ...".format(file))
230 | data_process_other(input_path, output_path, file)
231 | print("="*60)
232 |
233 | for file in files_both:
234 | print("Preprocessing {} ...".format(file))
235 | data_process_both(input_path, output_path, file)
236 | print("="*60)
237 |
238 | print("data preprocess done!")
239 |
--------------------------------------------------------------------------------
/image/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JasonForJoy/DIM/6faa4c61f57c28cfbd9a4ade52c29490a6e40d06/image/model.png
--------------------------------------------------------------------------------
/image/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JasonForJoy/DIM/6faa4c61f57c28cfbd9a4ade52c29490a6e40d06/image/result.png
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/model/data_helpers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 |
4 |
5 | def load_vocab(fname):
6 | '''
7 | vocab = {"I": 0, ...}
8 | '''
9 | vocab={}
10 | with open(fname, 'rt') as f:
11 | for i,line in enumerate(f):
12 | word = line.decode('utf-8').strip()
13 | vocab[word] = i
14 | return vocab
15 |
16 | def load_char_vocab(fname):
17 | '''
18 | charVocab = {"U": 0, "!": 1, ...}
19 | '''
20 | charVocab={}
21 | with open(fname, 'rt') as f:
22 | for line in f:
23 | fields = line.strip().split('\t')
24 | char_id = int(fields[0])
25 | ch = fields[1]
26 | charVocab[ch] = char_id
27 | return charVocab
28 |
29 | def to_vec(tokens, vocab, maxlen):
30 | '''
31 | length: length of the input sequence
32 | vec: map the token to the vocab_id, return a varied-length array [3, 6, 4, 3, ...]
33 | '''
34 | n = len(tokens)
35 | length = 0
36 | vec=[]
37 | for i in range(n):
38 | length += 1
39 | if tokens[i] in vocab:
40 | vec.append(vocab[tokens[i]])
41 | else:
42 | vec.append(vocab["fiance"])  # OOV tokens are mapped to the id of an arbitrary in-vocab word ("fiance")
43 | return length, np.array(vec)
44 |
45 | def load_dataset(fname, vocab, max_utter_num, max_utter_len, max_response_len, max_persona_len):
46 |
47 | dataset=[]
48 | with open(fname, 'rt') as f:
49 | for us_id, line in enumerate(f):
50 | line = line.decode('utf-8').strip()
51 | fields = line.split('\t')
52 |
53 | # context utterances
54 | context = fields[0]
55 | utterances = (context + " ").split(' _eos_ ')[:-1]
56 | utterances = [utterance + " _eos_" for utterance in utterances]
57 | utterances = utterances[-max_utter_num:] # select the last max_utter_num utterances
58 | us_tokens = []
59 | us_vec = []
60 | us_len = []
61 | for utterance in utterances:
62 | u_tokens = utterance.split(' ')[:max_utter_len] # select the head max_utter_len tokens in every utterance
63 | u_len, u_vec = to_vec(u_tokens, vocab, max_utter_len)
64 | us_tokens.append(u_tokens)
65 | us_vec.append(u_vec)
66 | us_len.append(u_len)
67 | us_num = len(utterances)
68 |
69 | # responses
70 | responses = fields[1].split("|")
71 | rs_tokens = []
72 | rs_vec = []
73 | rs_len = []
74 | for response in responses:
75 | r_tokens = response.split(' ')[:max_response_len] # select the head max_response_len tokens in every candidate
76 | r_len, r_vec = to_vec(r_tokens, vocab, max_response_len)
77 | rs_tokens.append(r_tokens)
78 | rs_vec.append(r_vec)
79 | rs_len.append(r_len)
80 |
81 | # label
82 | label = int(fields[2])
83 | ps_tokens, ps_vec, ps_len, ps_num = [], [], [], 0  # default to empty persona; lines matched by neither branch below (e.g. the "none" setting) would otherwise reuse stale values
84 | # other persona
85 | if fields[3] != "NA" and fields[4] == "NA":
86 | personas = fields[3].split("|")
87 | ps_tokens = []
88 | ps_vec = []
89 | ps_len = []
90 | for persona in personas:
91 | p_tokens = persona.split(' ')[:max_persona_len] # select the head max_persona_len tokens in every persona
92 | p_len, p_vec = to_vec(p_tokens, vocab, max_persona_len)
93 | ps_tokens.append(p_tokens)
94 | ps_vec.append(p_vec)
95 | ps_len.append(p_len)
96 | ps_num = len(personas)
97 |
98 | # self persona
99 | if fields[3] == "NA" and fields[4] != "NA":
100 | personas = fields[4].split("|")
101 | ps_tokens = []
102 | ps_vec = []
103 | ps_len = []
104 | for persona in personas:
105 | p_tokens = persona.split(' ')[:max_persona_len] # select the head max_persona_len tokens in every persona
106 | p_len, p_vec = to_vec(p_tokens, vocab, max_persona_len)
107 | ps_tokens.append(p_tokens)
108 | ps_vec.append(p_vec)
109 | ps_len.append(p_len)
110 | ps_num = len(personas)
111 |
112 | dataset.append((us_id, us_tokens, us_vec, us_len, us_num, rs_tokens, rs_vec, rs_len, label, ps_tokens, ps_vec, ps_len, ps_num))
113 |
114 | return dataset
115 |
116 |
117 | def normalize_vec(vec, maxlen):
118 | '''
119 | pad the original vec to the same maxlen
120 | [3, 4, 7] maxlen=5 --> [3, 4, 7, 0, 0]
121 | '''
122 | if len(vec) == maxlen:
123 | return vec
124 |
125 | new_vec = np.zeros(maxlen, dtype='int32')
126 | for i in range(len(vec)):
127 | new_vec[i] = vec[i]
128 | return new_vec
129 |
130 |
131 | def charVec(tokens, charVocab, maxlen, maxWordLength):
132 | '''
133 | chars = np.array( (maxlen, maxWordLength) ) 0 if not found in charVocab or None
134 | word_lengths = np.array( maxlen ) 1 if None
135 | '''
136 | n = len(tokens)
137 | if n > maxlen:
138 | n = maxlen
139 |
140 | chars = np.zeros((maxlen, maxWordLength), dtype=np.int32)
141 | word_lengths = np.ones(maxlen, dtype=np.int32)
142 | for i in range(n):
143 | token = tokens[i][:maxWordLength]
144 | word_lengths[i] = len(token)
145 | row = chars[i]
146 | for idx, ch in enumerate(token):
147 | if ch in charVocab:
148 | row[idx] = charVocab[ch]
149 |
150 | return chars, word_lengths
151 |
152 |
153 | def batch_iter(data, batch_size, num_epochs, max_utter_num, max_utter_len, max_response_num, max_response_len,
154 | max_persona_num, max_persona_len, charVocab, max_word_length, shuffle=True):
155 | """
156 | Generates a batch iterator for a dataset.
157 | """
158 | data_size = len(data)
159 | num_batches_per_epoch = (data_size + batch_size - 1) // batch_size  # ceiling division, so no empty final batch is yielded
160 | for epoch in range(num_epochs):
161 | # Shuffle the data at each epoch
162 | if shuffle:
163 | random.shuffle(data)
164 | for batch_num in range(num_batches_per_epoch):
165 | start_index = batch_num * batch_size
166 | end_index = min((batch_num + 1) * batch_size, data_size)
167 |
168 | x_utterances = []
169 | x_utterances_len = []
170 | x_responses = []
171 | x_responses_len = []
172 |
173 | x_labels = []
174 | x_ids = []
175 | x_utterances_num = []
176 |
177 | x_utterances_char=[]
178 | x_utterances_char_len=[]
179 | x_responses_char=[]
180 | x_responses_char_len=[]
181 |
182 | x_personas = []
183 | x_personas_len = []
184 | x_personas_char=[]
185 | x_personas_char_len=[]
186 | x_personas_num = []
187 |
188 | for rowIdx in range(start_index, end_index):
189 | us_id, us_tokens, us_vec, us_len, us_num, rs_tokens, rs_vec, rs_len, label, ps_tokens, ps_vec, ps_len, ps_num = data[rowIdx]
190 |
191 | # normalize us_vec and us_len
192 | new_utters_vec = np.zeros((max_utter_num, max_utter_len), dtype='int32')
193 | new_utters_len = np.zeros((max_utter_num, ), dtype='int32')
194 | for i in range(len(us_len)):
195 | new_utter_vec = normalize_vec(us_vec[i], max_utter_len)
196 | new_utters_vec[i] = new_utter_vec
197 | new_utters_len[i] = us_len[i]
198 | x_utterances.append(new_utters_vec)
199 | x_utterances_len.append(new_utters_len)
200 |
201 | # normalize rs_vec and rs_len
202 | new_responses_vec = np.zeros((max_response_num, max_response_len), dtype='int32')
203 | new_responses_len = np.zeros((max_response_num, ), dtype='int32')
204 | for i in range(len(rs_len)):
205 | new_response_vec = normalize_vec(rs_vec[i], max_response_len)
206 | new_responses_vec[i] = new_response_vec
207 | new_responses_len[i] = rs_len[i]
208 | x_responses.append(new_responses_vec)
209 | x_responses_len.append(new_responses_len)
210 |
211 | x_labels.append(label)
212 | x_ids.append(us_id)
213 | x_utterances_num.append(us_num)
214 |
215 | # normalize us_CharVec and us_CharLen
216 | uttersCharVec = np.zeros((max_utter_num, max_utter_len, max_word_length), dtype='int32')
217 | uttersCharLen = np.ones((max_utter_num, max_utter_len), dtype='int32')
218 | for i in range(len(us_len)):
219 | utterCharVec, utterCharLen = charVec(us_tokens[i], charVocab, max_utter_len, max_word_length)
220 | uttersCharVec[i] = utterCharVec
221 | uttersCharLen[i] = utterCharLen
222 | x_utterances_char.append(uttersCharVec)
223 | x_utterances_char_len.append(uttersCharLen)
224 |
225 | # normalize rs_CharVec and rs_CharLen
226 | rsCharVec = np.zeros((max_response_num, max_response_len, max_word_length), dtype='int32')
227 | rsCharLen = np.ones((max_response_num, max_response_len), dtype='int32')
228 | for i in range(len(rs_len)):  # iterate over the response candidates
229 | rCharVec, rCharLen = charVec(rs_tokens[i], charVocab, max_response_len, max_word_length)
230 | rsCharVec[i] = rCharVec
231 | rsCharLen[i] = rCharLen
232 | x_responses_char.append(rsCharVec)
233 | x_responses_char_len.append(rsCharLen)
234 |
235 | # normalize ps_vec and ps_len
236 | new_personas_vec = np.zeros((max_persona_num, max_persona_len), dtype='int32')
237 | new_personas_len = np.zeros((max_persona_num, ), dtype='int32')
238 | for i in range(len(ps_len)):
239 | new_persona_vec = normalize_vec(ps_vec[i], max_persona_len)
240 | new_personas_vec[i] = new_persona_vec
241 | new_personas_len[i] = ps_len[i]
242 | x_personas.append(new_personas_vec)
243 | x_personas_len.append(new_personas_len)
244 |
245 | # normalize ps_CharVec and ps_CharLen
246 | psCharVec = np.zeros((max_persona_num, max_persona_len, max_word_length), dtype='int32')
247 | psCharLen = np.ones((max_persona_num, max_persona_len), dtype='int32')
248 | for i in range(len(ps_len)):
249 | pCharVec, pCharLen = charVec(ps_tokens[i], charVocab, max_persona_len, max_word_length)
250 | psCharVec[i] = pCharVec
251 | psCharLen[i] = pCharLen
252 | x_personas_char.append(psCharVec)
253 | x_personas_char_len.append(psCharLen)
254 |
255 | x_personas_num.append(ps_num)
256 |
257 | yield np.array(x_utterances), np.array(x_utterances_len), np.array(x_responses), np.array(x_responses_len), \
258 | np.array(x_utterances_num), np.array(x_labels), x_ids, \
259 | np.array(x_utterances_char), np.array(x_utterances_char_len), np.array(x_responses_char), np.array(x_responses_char_len), \
260 | np.array(x_personas), np.array(x_personas_len), np.array(x_personas_char), np.array(x_personas_char_len), np.array(x_personas_num)
261 |
262 |
--------------------------------------------------------------------------------
/model/eval.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import os
4 | import time
5 | import datetime
6 | import operator
7 | from model import metrics
8 | from collections import defaultdict
9 | from model import data_helpers
10 |
11 | # Files
12 | tf.flags.DEFINE_string("test_file", "", "path to test file")
13 | tf.flags.DEFINE_string("vocab_file", "", "vocabulary file")
14 | tf.flags.DEFINE_string("char_vocab_file", "", "vocabulary file")
15 | tf.flags.DEFINE_string("output_file", "", "prediction output file")
16 |
17 | # Model Hyperparameters
18 | tf.flags.DEFINE_integer("max_utter_num", 15, "max utterance number")
19 | tf.flags.DEFINE_integer("max_utter_len", 20, "max utterance length")
20 | tf.flags.DEFINE_integer("max_response_num", 20, "max response candidate number")
21 | tf.flags.DEFINE_integer("max_response_len", 20, "max response length")
22 | tf.flags.DEFINE_integer("max_persona_num", 5, "max persona number")
23 | tf.flags.DEFINE_integer("max_persona_len", 15, "max persona length")
24 | tf.flags.DEFINE_integer("max_word_length", 18, "max word length")
25 |
26 | # Test parameters
27 | tf.flags.DEFINE_integer("batch_size", 32, "Batch Size (default: 64)")
28 | tf.flags.DEFINE_string("checkpoint_dir", "", "Checkpoint directory from training run")
29 |
30 | # Misc Parameters
31 | tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")
32 | tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")
33 |
34 | FLAGS = tf.flags.FLAGS
35 | FLAGS._parse_flags()
36 | print("\nParameters:")
37 | for attr, value in sorted(FLAGS.__flags.items()):
38 | print("{}={}".format(attr.upper(), value))
39 | print("")
40 |
41 | vocab = data_helpers.load_vocab(FLAGS.vocab_file)
42 | print('vocabulary size: {}'.format(len(vocab)))
43 | charVocab = data_helpers.load_char_vocab(FLAGS.char_vocab_file)
44 | print('charVocab size: {}'.format(len(charVocab)))
45 |
46 | test_dataset = data_helpers.load_dataset(FLAGS.test_file, vocab, FLAGS.max_utter_num, FLAGS.max_utter_len, FLAGS.max_response_len, FLAGS.max_persona_len)
47 | print('test dataset size: {}'.format(len(test_dataset)))
48 |
49 | print("\nEvaluating...\n")
50 |
51 | checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
52 | print(checkpoint_file)
53 |
54 | graph = tf.Graph()
55 | with graph.as_default():
56 | session_conf = tf.ConfigProto(
57 | allow_soft_placement=FLAGS.allow_soft_placement,
58 | log_device_placement=FLAGS.log_device_placement)
59 | sess = tf.Session(config=session_conf)
60 | with sess.as_default():
61 | # Load the saved meta graph and restore variables
62 | saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
63 | saver.restore(sess, checkpoint_file)
64 |
65 | # Get the placeholders from the graph by name
66 | utterances = graph.get_operation_by_name("utterances").outputs[0]
67 | utterances_len = graph.get_operation_by_name("utterances_len").outputs[0]
68 |
69 | responses = graph.get_operation_by_name("responses").outputs[0]
70 | responses_len = graph.get_operation_by_name("responses_len").outputs[0]
71 |
72 | utterances_num = graph.get_operation_by_name("utterances_num").outputs[0]
73 | dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]
74 |
75 | u_char_feature = graph.get_operation_by_name("utterances_char").outputs[0]
76 | u_char_len = graph.get_operation_by_name("utterances_char_len").outputs[0]
77 |
78 | r_char_feature = graph.get_operation_by_name("responses_char").outputs[0]
79 | r_char_len = graph.get_operation_by_name("responses_char_len").outputs[0]
80 |
81 | personas = graph.get_operation_by_name("personas").outputs[0]
82 | personas_len = graph.get_operation_by_name("personas_len").outputs[0]
83 | p_char_feature = graph.get_operation_by_name("personas_char").outputs[0]
84 | p_char_len = graph.get_operation_by_name("personas_char_len").outputs[0]
85 | personas_num = graph.get_operation_by_name("personas_num").outputs[0]
86 |
87 | # Tensors we want to evaluate
88 | pred_prob = graph.get_operation_by_name("prediction_layer/prob").outputs[0]
89 |
90 | results = defaultdict(list)
91 | num_test = 0
92 | test_batches = data_helpers.batch_iter(test_dataset, FLAGS.batch_size, 1, FLAGS.max_utter_num, FLAGS.max_utter_len, \
93 | FLAGS.max_response_num, FLAGS.max_response_len, FLAGS.max_persona_num, FLAGS.max_persona_len, \
94 | charVocab, FLAGS.max_word_length, shuffle=False)
95 | for test_batch in test_batches:
96 | x_utterances, x_utterances_len, x_response, x_response_len,\
97 | x_utters_num, x_target, x_ids, \
98 | x_u_char, x_u_char_len, x_r_char, x_r_char_len, \
99 | x_personas, x_personas_len, x_p_char, x_p_char_len, x_personas_num = test_batch
100 | feed_dict = {
101 | utterances: x_utterances,
102 | utterances_len: x_utterances_len,
103 | responses: x_response,
104 | responses_len: x_response_len,
105 | utterances_num: x_utters_num,
106 | dropout_keep_prob: 1.0,
107 | u_char_feature: x_u_char,
108 | u_char_len: x_u_char_len,
109 | r_char_feature: x_r_char,
110 | r_char_len: x_r_char_len,
111 | personas: x_personas,
112 | personas_len: x_personas_len,
113 | p_char_feature: x_p_char,
114 | p_char_len: x_p_char_len,
115 | personas_num: x_personas_num
116 | }
117 | predicted_prob = sess.run(pred_prob, feed_dict)
118 | num_test += len(predicted_prob)
119 | print('num_test_sample={}'.format(num_test))
120 |
121 | for i in range(len(predicted_prob)):
122 | probs = predicted_prob[i]
123 | us_id = x_ids[i]
124 | label = x_target[i]
125 | labels = np.zeros(FLAGS.max_response_num)
126 | labels[label] = 1
127 | for r_id, prob in enumerate(probs):
128 | results[us_id].append((str(r_id), labels[r_id], prob))
129 |
130 | accu, precision, recall, f1, loss = metrics.classification_metrics(results)
131 | print('Accuracy: {}, Precision: {} Recall: {} F1: {} Loss: {}'.format(accu, precision, recall, f1, loss))
132 |
133 | mvp = metrics.mean_average_precision(results)
134 | mrr = metrics.mean_reciprocal_rank(results)
135 | top_1_precision = metrics.top_1_precision(results)
136 | total_valid_query = metrics.get_num_valid_query(results)
137 | print('MAP (mean average precision): {}\tMRR (mean reciprocal rank): {}\tTop-1 precision: {}\tNum_query: {}'.format(mvp, mrr, top_1_precision, total_valid_query))
138 |
139 | out_path = FLAGS.output_file
140 | print("Saving evaluation to {}".format(out_path))
141 | with open(out_path, 'w') as f:
142 | f.write("query_id\tdocument_id\tscore\trank\trelevance\n")
143 | for us_id, v in results.items():
144 | v.sort(key=operator.itemgetter(2), reverse=True)
145 | for i, rec in enumerate(v):
146 | r_id, label, prob_score = rec
147 | rank = i+1
148 | f.write('{}\t{}\t{}\t{}\t{}\n'.format(us_id, r_id, prob_score, rank, label))
149 |
--------------------------------------------------------------------------------
/model/metrics.py:
--------------------------------------------------------------------------------
1 | import operator
2 | import math
3 |
4 | def is_valid_query(v):
5 | num_pos = 0
6 | num_neg = 0
7 | for aid, label, score in v:
8 | if label > 0:
9 | num_pos += 1
10 | else:
11 | num_neg += 1
12 | if num_pos > 0 and num_neg > 0:
13 | return True
14 | else:
15 | return False
16 |
17 | def get_num_valid_query(results):
18 | num_query = 0
19 | for k, v in results.items():
20 | if not is_valid_query(v):
21 | continue
22 | num_query += 1
23 | return num_query
24 |
25 | def top_1_precision(results):
26 | num_query = 0
27 | top_1_correct = 0.0
28 | for k, v in results.items():
29 | if not is_valid_query(v):
30 | continue
31 | num_query += 1
32 | sorted_v = sorted(v, key=operator.itemgetter(2), reverse=True)
33 | aid, label, score = sorted_v[0]
34 | if label > 0:
35 | top_1_correct += 1
36 |
37 | if num_query > 0:
38 | return top_1_correct/num_query
39 | else:
40 | return 0.0
41 |
42 | def mean_reciprocal_rank(results):
43 | num_query = 0
44 | mrr = 0.0
45 | for k, v in results.items():
46 | if not is_valid_query(v):
47 | continue
48 |
49 | num_query += 1
50 | sorted_v = sorted(v, key=operator.itemgetter(2), reverse=True)
51 | for i, rec in enumerate(sorted_v):
52 | aid, label, score = rec
53 | if label > 0:
54 | mrr += 1.0/(i+1)
55 | break
56 |
57 | if num_query == 0:
58 | return 0.0
59 | else:
60 | mrr = mrr/num_query
61 | return mrr
62 |
63 | def mean_average_precision(results):
64 | num_query = 0
65 | mvp = 0.0
66 | for k, v in results.items():
67 | if not is_valid_query(v):
68 | continue
69 |
70 | num_query += 1
71 | sorted_v = sorted(v, key=operator.itemgetter(2), reverse=True)
72 | num_relevant_doc = 0.0
73 | avp = 0.0
74 | for i, rec in enumerate(sorted_v):
75 | aid, label, score = rec
76 | if label == 1:
77 | num_relevant_doc += 1
78 | precision = num_relevant_doc/(i+1)
79 | avp += precision
80 | avp = avp/num_relevant_doc
81 | mvp += avp
82 |
83 | if num_query == 0:
84 | return 0.0
85 | else:
86 | mvp = mvp/num_query
87 | return mvp
88 |
89 | def classification_metrics(results):
90 | total_num = 0
91 | total_correct = 0
92 | true_positive = 0
93 | positive_correct = 0
94 | predicted_positive = 0
95 |
96 | loss = 0.0
97 | for k, v in results.items():
98 | for rec in v:
99 | total_num += 1
100 | aid, label, score = rec
101 |
102 |
103 | if score > 0.5:
104 | predicted_positive += 1
105 |
106 | if label > 0:
107 | true_positive += 1
108 | loss += -math.log(score+1e-12)
109 | else:
110 | loss += -math.log(1.0 - score + 1e-12)
111 |
112 | if score > 0.5 and label > 0:
113 | total_correct += 1
114 | positive_correct += 1
115 |
116 | if score < 0.5 and label < 0.5:
117 | total_correct += 1
118 |
119 | accuracy = float(total_correct)/total_num
120 | precision = float(positive_correct)/(predicted_positive+1e-12)
121 | recall = float(positive_correct)/true_positive
122 | F1 = 2.0 * precision * recall/(1e-12+precision + recall)
123 | return accuracy, precision, recall, F1, loss/total_num
124 |
--------------------------------------------------------------------------------
/model/model_DIM.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 | FLAGS = tf.flags.FLAGS
5 |
6 | def get_embeddings(vocab):
7 | print("get_embedding")
8 | initializer = load_word_embeddings(vocab, FLAGS.embedding_dim)
9 | return tf.constant(initializer, name="word_embedding")
10 |
11 | def get_char_embedding(charVocab):
12 | print("get_char_embedding")
13 | char_size = len(charVocab)
14 | embeddings = np.zeros((char_size, char_size), dtype='float32')
15 | for i in range(1, char_size):
16 | embeddings[i, i] = 1.0
17 | return tf.constant(embeddings, name="word_char_embedding")
18 |
19 | def load_embed_vectors(fname, dim):
20 | vectors = {}
21 | for line in open(fname, 'rt'):
22 | items = line.strip().split(' ')
23 | if len(items[0]) <= 0:
24 | continue
25 | vec = [float(items[i]) for i in range(1, dim+1)]
26 | vectors[items[0]] = vec
27 | return vectors
28 |
29 | def load_word_embeddings(vocab, dim):
30 | vectors = load_embed_vectors(FLAGS.embedded_vector_file, dim)
31 | vocab_size = len(vocab)
32 | embeddings = np.zeros((vocab_size, dim), dtype='float32')
33 | for word, code in vocab.items():
34 | if word in vectors:
35 | embeddings[code] = vectors[word]
36 | #else:
37 | # embeddings[code] = np.random.uniform(-0.25, 0.25, dim)
38 | return embeddings
39 |
40 |
41 | def lstm_layer(inputs, input_seq_len, rnn_size, dropout_keep_prob, scope, scope_reuse=False):
42 | with tf.variable_scope(scope, reuse=scope_reuse) as vs:
43 | fw_cell = tf.contrib.rnn.LSTMCell(rnn_size, forget_bias=1.0, state_is_tuple=True, reuse=scope_reuse)
44 | fw_cell = tf.contrib.rnn.DropoutWrapper(fw_cell, output_keep_prob=dropout_keep_prob)
45 | bw_cell = tf.contrib.rnn.LSTMCell(rnn_size, forget_bias=1.0, state_is_tuple=True, reuse=scope_reuse)
46 | bw_cell = tf.contrib.rnn.DropoutWrapper(bw_cell, output_keep_prob=dropout_keep_prob)
47 | rnn_outputs, rnn_states = tf.nn.bidirectional_dynamic_rnn(cell_fw=fw_cell, cell_bw=bw_cell,
48 | inputs=inputs,
49 | sequence_length=input_seq_len,
50 | dtype=tf.float32)
51 | return rnn_outputs, rnn_states
52 |
53 | def cnn_layer(inputs, filter_sizes, num_filters, scope=None, scope_reuse=False):
54 | with tf.variable_scope(scope, reuse=scope_reuse):
55 | input_size = inputs.get_shape()[2].value
56 |
57 | outputs = []
58 | for i, filter_size in enumerate(filter_sizes):
59 | with tf.variable_scope("conv_{}".format(i)):
60 | w = tf.get_variable("w", [filter_size, input_size, num_filters])
61 | b = tf.get_variable("b", [num_filters])
62 | conv = tf.nn.conv1d(inputs, w, stride=1, padding="VALID") # [num_words, num_chars - filter_size, num_filters]
63 | h = tf.nn.relu(tf.nn.bias_add(conv, b)) # [num_words, num_chars - filter_size, num_filters]
64 | pooled = tf.reduce_max(h, 1) # [num_words, num_filters]
65 | outputs.append(pooled)
66 | return tf.concat(outputs, 1) # [num_words, num_filters * len(filter_sizes)]
67 |
68 |
69 | def attended_response(similarity_matrix, contexts, flattened_utters_len, max_utter_len, max_utter_num):
70 | # similarity_matrix: [batch_size, max_response_num, max_response_len, max_utter_num*max_utter_len]
71 | # contexts: [batch_size, max_utter_num*max_utter_len, dim]
72 | # flattened_utters_len: [batch_size* max_utter_num, ]
73 | max_response_num = similarity_matrix.get_shape()[1].value
74 |
75 | # masked similarity_matrix
76 | mask_c = tf.sequence_mask(flattened_utters_len, max_utter_len, dtype=tf.float32) # [batch_size*max_utter_num, max_utter_len]
77 | mask_c = tf.reshape(mask_c, [-1, max_utter_num*max_utter_len]) # [batch_size, max_utter_num*max_utter_len]
78 | mask_c = tf.expand_dims(mask_c, 1) # [batch_size, 1, max_utter_num*max_utter_len]
79 | mask_c = tf.expand_dims(mask_c, 2) # [batch_size, 1, 1, max_utter_num*max_utter_len]
80 | similarity_matrix = similarity_matrix * mask_c + -1e9 * (1-mask_c) # [batch_size, max_response_num, response_len, max_utter_num*max_utter_len]
81 |
82 | attention_weight_for_c = tf.nn.softmax(similarity_matrix, dim=-1) # [batch_size, max_response_num, response_len, max_utter_num*max_utter_len]
83 | contexts_tiled = tf.tile(tf.expand_dims(contexts, 1), [1, max_response_num, 1, 1])# [batch_size, max_response_num, max_utter_num*max_utter_len, dim]
84 | attended_response = tf.matmul(attention_weight_for_c, contexts_tiled) # [batch_size, max_response_num, response_len, dim]
85 |
86 | return attended_response
87 |
88 | def attended_context(similarity_matrix, responses, flattened_responses_len, max_response_len, max_response_num):
89 | # similarity_matrix: [batch_size, max_response_num, max_response_len, max_utter_num*max_utter_len]
90 | # responses: [batch_size, max_response_num, max_response_len, dim]
91 | # flattened_responses_len: [batch_size* max_response_num, ]
92 |
93 | # masked similarity_matrix
94 | mask_r = tf.sequence_mask(flattened_responses_len, max_response_len, dtype=tf.float32) # [batch_size*max_response_num, max_response_len]
95 | mask_r = tf.reshape(mask_r, [-1, max_response_num, max_response_len]) # [batch_size, max_response_num, max_response_len]
96 | mask_r = tf.expand_dims(mask_r, -1) # [batch_size, max_response_num, max_response_len, 1]
97 | similarity_matrix = similarity_matrix * mask_r + -1e9 * (1-mask_r) # [batch_size, max_response_num, max_response_len, max_utter_num*max_utter_len]
98 |
99 | attention_weight_for_r = tf.nn.softmax(tf.transpose(similarity_matrix, perm=[0,1,3,2]), dim=-1) # [batch_size, max_response_num, max_utter_num*max_utter_len, response_len]
100 | attended_context = tf.matmul(attention_weight_for_r, responses) # [batch_size, max_response_num, max_utter_num*max_utter_len, dim]
101 |
102 | return attended_context
103 |
104 |
105 | class DIM(object):
106 | def __init__(
107 | self, max_utter_num, max_utter_len, max_response_num, max_response_len, max_persona_num, max_persona_len,
108 | vocab_size, embedding_size, vocab, rnn_size, maxWordLength, charVocab, l2_reg_lambda=0.0):
109 |
110 | self.utterances = tf.placeholder(tf.int32, [None, max_utter_num, max_utter_len], name="utterances")
111 | self.utterances_len = tf.placeholder(tf.int32, [None, max_utter_num], name="utterances_len")
112 | self.utters_num = tf.placeholder(tf.int32, [None], name="utterances_num")
113 |
114 | self.responses = tf.placeholder(tf.int32, [None, max_response_num, max_response_len], name="responses")
115 | self.responses_len = tf.placeholder(tf.int32, [None, max_response_num], name="responses_len")
116 |
117 | self.personas = tf.placeholder(tf.int32, [None, max_persona_num, max_persona_len], name="personas")
118 | self.personas_len = tf.placeholder(tf.int32, [None, max_persona_num], name="personas_len")
119 | self.personas_num = tf.placeholder(tf.int32, [None], name="personas_num")
120 |
121 | self.target = tf.placeholder(tf.int64, [None], name="target")
122 | self.dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob")
123 |
124 | self.u_charVec = tf.placeholder(tf.int32, [None, max_utter_num, max_utter_len, maxWordLength], name="utterances_char")
125 | self.u_charLen = tf.placeholder(tf.int32, [None, max_utter_num, max_utter_len], name="utterances_char_len")
126 |
127 | self.r_charVec = tf.placeholder(tf.int32, [None, max_response_num, max_response_len, maxWordLength], name="responses_char")
128 | self.r_charLen = tf.placeholder(tf.int32, [None, max_response_num, max_response_len], name="responses_char_len")
129 |
130 | self.p_charVec = tf.placeholder(tf.int32, [None, max_persona_num, max_persona_len, maxWordLength], name="personas_char")
131 | self.p_charLen = tf.placeholder(tf.int32, [None, max_persona_num, max_persona_len], name="personas_char_len")
132 |
133 | l2_loss = tf.constant(1.0)  # constant placeholder; the effective L2 penalty comes from the REGULARIZATION_LOSSES collection below
134 |
135 | # =============================== Embedding layer ===============================
136 | # word embedding
137 | with tf.name_scope("embedding"):
138 | W = get_embeddings(vocab)
139 | utterances_embedded = tf.nn.embedding_lookup(W, self.utterances) # [batch_size, max_utter_num, max_utter_len, word_dim]
140 | responses_embedded = tf.nn.embedding_lookup(W, self.responses) # [batch_size, max_response_num, max_response_len, word_dim]
141 | personas_embedded = tf.nn.embedding_lookup(W, self.personas) # [batch_size, max_persona_num, max_persona_len, word_dim]
142 | print("original utterances_embedded: {}".format(utterances_embedded.get_shape()))
143 | print("original responses_embedded: {}".format(responses_embedded.get_shape()))
144 | print("original personas_embedded: {}".format(personas_embedded.get_shape()))
145 |
146 | with tf.name_scope('char_embedding'):
147 | char_W = get_char_embedding(charVocab)
148 | utterances_char_embedded = tf.nn.embedding_lookup(char_W, self.u_charVec) # [batch_size, max_utter_num, max_utter_len, maxWordLength, char_dim]
149 | responses_char_embedded = tf.nn.embedding_lookup(char_W, self.r_charVec) # [batch_size, max_response_num, max_response_len, maxWordLength, char_dim]
150 | personas_char_embedded = tf.nn.embedding_lookup(char_W, self.p_charVec) # [batch_size, max_persona_num, max_persona_len, maxWordLength, char_dim]
151 | print("utterances_char_embedded: {}".format(utterances_char_embedded.get_shape()))
152 | print("responses_char_embedded: {}".format(responses_char_embedded.get_shape()))
153 | print("personas_char_embedded: {}".format(personas_char_embedded.get_shape()))
154 |
155 | char_dim = utterances_char_embedded.get_shape()[-1].value
156 | utterances_char_embedded = tf.reshape(utterances_char_embedded, [-1, maxWordLength, char_dim]) # [batch_size*max_utter_num*max_utter_len, maxWordLength, char_dim]
157 | responses_char_embedded = tf.reshape(responses_char_embedded, [-1, maxWordLength, char_dim]) # [batch_size*max_response_num*max_response_len, maxWordLength, char_dim]
158 | personas_char_embedded = tf.reshape(personas_char_embedded, [-1, maxWordLength, char_dim]) # [batch_size*max_persona_num*max_persona_len, maxWordLength, char_dim]
159 |
160 | # char embedding
161 | utterances_cnn_char_emb = cnn_layer(utterances_char_embedded, filter_sizes=[3, 4, 5], num_filters=50, scope="CNN_char_emb", scope_reuse=False) # [batch_size*max_utter_num*max_utter_len, emb]
162 | cnn_char_dim = utterances_cnn_char_emb.get_shape()[1].value
163 | utterances_cnn_char_emb = tf.reshape(utterances_cnn_char_emb, [-1, max_utter_num, max_utter_len, cnn_char_dim]) # [batch_size, max_utter_num, max_utter_len, emb]
164 |
165 | responses_cnn_char_emb = cnn_layer(responses_char_embedded, filter_sizes=[3, 4, 5], num_filters=50, scope="CNN_char_emb", scope_reuse=True) # [batch_size*max_response_num*max_response_len, emb]
166 | responses_cnn_char_emb = tf.reshape(responses_cnn_char_emb, [-1, max_response_num, max_response_len, cnn_char_dim]) # [batch_size, max_response_num, max_response_len, emb]
167 |
168 | personas_cnn_char_emb = cnn_layer(personas_char_embedded, filter_sizes=[3, 4, 5], num_filters=50, scope="CNN_char_emb", scope_reuse=True) # [batch_size*max_persona_num*max_persona_len, emb]
169 | personas_cnn_char_emb = tf.reshape(personas_cnn_char_emb, [-1, max_persona_num, max_persona_len, cnn_char_dim]) # [batch_size, max_persona_num, max_persona_len, emb]
170 |
171 | utterances_embedded = tf.concat(axis=-1, values=[utterances_embedded, utterances_cnn_char_emb]) # [batch_size, max_utter_num, max_utter_len, emb]
172 | responses_embedded = tf.concat(axis=-1, values=[responses_embedded, responses_cnn_char_emb]) # [batch_size, max_response_num, max_response_len, emb]
173 | personas_embedded = tf.concat(axis=-1, values=[personas_embedded, personas_cnn_char_emb]) # [batch_size, max_persona_num, max_persona_len, emb]
174 | utterances_embedded = tf.nn.dropout(utterances_embedded, keep_prob=self.dropout_keep_prob)
175 | responses_embedded = tf.nn.dropout(responses_embedded, keep_prob=self.dropout_keep_prob)
176 | personas_embedded = tf.nn.dropout(personas_embedded, keep_prob=self.dropout_keep_prob)
177 | print("utterances_embedded: {}".format(utterances_embedded.get_shape()))
178 | print("responses_embedded: {}".format(responses_embedded.get_shape()))
179 | print("personas_embedded: {}".format(personas_embedded.get_shape()))
180 |
181 |
182 | # =============================== Encoding layer ===============================
183 | with tf.variable_scope("encoding_layer") as vs:
184 |
185 | emb_dim = utterances_embedded.get_shape()[-1].value
186 | flattened_utterances_embedded = tf.reshape(utterances_embedded, [-1, max_utter_len, emb_dim]) # [batch_size*max_utter_num, max_utter_len, emb]
187 | flattened_utterances_len = tf.reshape(self.utterances_len, [-1]) # [batch_size*max_utter_num, ]
188 | flattened_responses_embedded = tf.reshape(responses_embedded, [-1, max_response_len, emb_dim]) # [batch_size*max_response_num, max_response_len, emb]
189 | flattened_responses_len = tf.reshape(self.responses_len, [-1]) # [batch_size*max_response_num, ]
190 | flattened_personas_embedded = tf.reshape(personas_embedded, [-1, max_persona_len, emb_dim]) # [batch_size*max_persona_num, max_persona_len, emb]
191 | flattened_personas_len = tf.reshape(self.personas_len, [-1]) # [batch_size*max_persona_num, ]
192 |
193 | rnn_scope_name = "bidirectional_rnn"
194 | u_rnn_output, u_rnn_states = lstm_layer(flattened_utterances_embedded, flattened_utterances_len, rnn_size, self.dropout_keep_prob, rnn_scope_name, scope_reuse=False)
195 | utterances_output = tf.concat(axis=2, values=u_rnn_output) # [batch_size*max_utter_num, max_utter_len, rnn_size*2]
196 | r_rnn_output, r_rnn_states = lstm_layer(flattened_responses_embedded, flattened_responses_len, rnn_size, self.dropout_keep_prob, rnn_scope_name, scope_reuse=True)
197 | responses_output = tf.concat(axis=2, values=r_rnn_output) # [batch_size*max_response_num, max_response_len, rnn_size*2]
198 | p_rnn_output, p_rnn_states = lstm_layer(flattened_personas_embedded, flattened_personas_len, rnn_size, self.dropout_keep_prob, rnn_scope_name, scope_reuse=True)
199 | personas_output = tf.concat(axis=2, values=p_rnn_output) # [batch_size*max_persona_num, max_persona_len, rnn_size*2]
200 | print("encoded utterances : {}".format(utterances_output.shape))
201 | print("encoded responses : {}".format(responses_output.shape))
202 | print("encoded personas : {}".format(personas_output.shape))
203 |
204 |
205 | # =============================== Matching layer ===============================
206 | with tf.variable_scope("matching_layer") as vs:
207 |
208 | output_dim = utterances_output.get_shape()[-1].value
209 | utterances_output = tf.reshape(utterances_output, [-1, max_utter_num*max_utter_len, output_dim]) # [batch_size, max_utter_num*max_utter_len, rnn_size*2]
210 | utterances_output_tiled = tf.tile(tf.expand_dims(utterances_output, 1), [1, max_response_num, 1, 1]) # [batch_size, max_response_num, max_utter_num*max_utter_len, rnn_size*2]
211 | responses_output = tf.reshape(responses_output, [-1, max_response_num, max_response_len, output_dim]) # [batch_size, max_response_num, max_response_len, rnn_size*2]
212 | personas_output = tf.reshape(personas_output, [-1, max_persona_num*max_persona_len, output_dim]) # [batch_size, max_persona_num*max_persona_len, rnn_size*2]
213 | personas_output_tiled = tf.tile(tf.expand_dims(personas_output, 1), [1, max_response_num, 1, 1]) # [batch_size, max_response_num, max_persona_num*max_persona_len, rnn_size*2]
214 |
215 | # 1. cross-attention between context and response
216 | similarity_UR = tf.matmul(responses_output, # [batch_size, max_response_num, response_len, max_utter_num*max_utter_len]
217 | tf.transpose(utterances_output_tiled, perm=[0,1,3,2]), name='similarity_matrix_UR')
218 | attended_utterances_output_ur = attended_context(similarity_UR, responses_output, flattened_responses_len, max_response_len, max_response_num) # [batch_size, max_response_num, max_utter_num*max_utter_len, dim]
219 | attended_responses_output_ur = attended_response(similarity_UR, utterances_output, flattened_utterances_len, max_utter_len, max_utter_num) # [batch_size, max_response_num, response_len, dim]
220 |
221 | m_u_ur = tf.concat(axis=-1, values=[utterances_output_tiled, attended_utterances_output_ur, tf.multiply(utterances_output_tiled, attended_utterances_output_ur), utterances_output_tiled-attended_utterances_output_ur]) # [batch_size, max_response_num, max_utter_num*max_utter_len, dim]
222 | m_r_ur = tf.concat(axis=-1, values=[responses_output, attended_responses_output_ur, tf.multiply(responses_output, attended_responses_output_ur), responses_output-attended_responses_output_ur]) # [batch_size, max_response_num, response_len, dim]
223 | concat_dim = m_u_ur.get_shape()[-1].value
224 | m_u_ur = tf.reshape(m_u_ur, [-1, max_utter_len, concat_dim]) # [batch_size*max_response_num*max_utter_num, max_utter_len, dim]
225 | m_r_ur = tf.reshape(m_r_ur, [-1, max_response_len, concat_dim]) # [batch_size*max_response_num, max_response_len, dim]
226 |
227 | rnn_scope_cross = 'bidirectional_rnn_cross'
228 | rnn_size_layer_2 = rnn_size
229 | tiled_flattened_utterances_len = tf.reshape(tf.tile(tf.expand_dims(self.utterances_len, 1), [1, max_response_num, 1]), [-1, ]) # [batch_size*max_response_num*max_utter_num, ]
230 | u_ur_rnn_output, u_ur_rnn_state = lstm_layer(m_u_ur, tiled_flattened_utterances_len, rnn_size_layer_2, self.dropout_keep_prob, rnn_scope_cross, scope_reuse=False)
231 | r_ur_rnn_output, r_ur_rnn_state = lstm_layer(m_r_ur, flattened_responses_len, rnn_size_layer_2, self.dropout_keep_prob, rnn_scope_cross, scope_reuse=True)
232 | utterances_output_cross_ur = tf.concat(axis=-1, values=u_ur_rnn_output) # [batch_size*max_response_num*max_utter_num, max_utter_len, rnn_size*2]
233 | responses_output_cross_ur = tf.concat(axis=-1, values=r_ur_rnn_output) # [batch_size*max_response_num, max_response_len, rnn_size*2]
234 | print("establish cross-attention between context and response")
235 |
236 |
237 | # 2. cross-attention between persona and response without decay
238 | similarity_PR = tf.matmul(responses_output, # [batch_size, max_response_num, response_len, max_persona_num*max_persona_len]
239 | tf.transpose(personas_output_tiled, perm=[0,1,3,2]), name='similarity_matrix_PR')
240 | attended_personas_output_pr = attended_context(similarity_PR, responses_output, flattened_responses_len, max_response_len, max_response_num) # [batch_size, max_response_num, max_persona_num*max_persona_len, dim]
241 | attended_responses_output_pr = attended_response(similarity_PR, personas_output, flattened_personas_len, max_persona_len, max_persona_num) # [batch_size, max_response_num, response_len, dim]
242 |
243 | m_p_pr = tf.concat(axis=-1, values=[personas_output_tiled, attended_personas_output_pr, tf.multiply(personas_output_tiled, attended_personas_output_pr), personas_output_tiled-attended_personas_output_pr]) # [batch_size, max_response_num, max_persona_num*max_persona_len, dim]
244 | m_r_pr = tf.concat(axis=-1, values=[responses_output, attended_responses_output_pr, tf.multiply(responses_output, attended_responses_output_pr), responses_output-attended_responses_output_pr]) # [batch_size, max_response_num, response_len, dim]
245 | m_p_pr = tf.reshape(m_p_pr, [-1, max_persona_len, concat_dim]) # [batch_size*max_response_num*max_persona_num, max_persona_len, dim]
246 | m_r_pr = tf.reshape(m_r_pr, [-1, max_response_len, concat_dim]) # [batch_size*max_response_num, max_response_len, dim]
247 |
248 | tiled_flattened_personas_len = tf.reshape(tf.tile(tf.expand_dims(self.personas_len, 1), [1, max_response_num, 1]), [-1, ]) # [batch_size*max_response_num*max_persona_num, ]
249 | p_pr_rnn_output, p_pr_rnn_state = lstm_layer(m_p_pr, tiled_flattened_personas_len, rnn_size_layer_2, self.dropout_keep_prob, rnn_scope_cross, scope_reuse=True)
250 | r_pr_rnn_output, r_pr_rnn_state = lstm_layer(m_r_pr, flattened_responses_len, rnn_size_layer_2, self.dropout_keep_prob, rnn_scope_cross, scope_reuse=True)
251 | personas_output_cross_pr = tf.concat(axis=-1, values=p_pr_rnn_output) # [batch_size*max_response_num*max_persona_num, max_persona_len, rnn_size*2]
252 | responses_output_cross_pr = tf.concat(axis=-1, values=r_pr_rnn_output) # [batch_size*max_response_num, max_response_len, rnn_size*2]
253 | print("establish cross-attention between persona and response")
254 |
255 |
256 | # =============================== Aggregation layer ===============================
257 | with tf.variable_scope("aggregation_layer") as vs:
258 | # aggregate utterance across utterance_len
259 | final_utterances_max = tf.reduce_max(utterances_output_cross_ur, axis=1)
260 | final_utterances_state = tf.concat(axis=1, values=[u_ur_rnn_state[0].h, u_ur_rnn_state[1].h])
261 | final_utterances = tf.concat(axis=1, values=[final_utterances_max, final_utterances_state]) # [batch_size*max_response_num*max_utter_num, 4*rnn_size]
262 |
263 | # aggregate utterance across utterance_num
264 | final_utterances = tf.reshape(final_utterances, [-1, max_utter_num, output_dim*2]) # [batch_size*max_response_num, max_utter_num, 4*rnn_size]
265 | tiled_utters_num = tf.reshape(tf.tile(tf.expand_dims(self.utters_num, 1), [1, max_response_num]), [-1, ]) # [batch_size*max_response_num, ]
266 | rnn_scope_aggre = "bidirectional_rnn_aggregation"
267 | final_utterances_output, final_utterances_state = lstm_layer(final_utterances, tiled_utters_num, rnn_size, self.dropout_keep_prob, rnn_scope_aggre, scope_reuse=False)
268 | final_utterances_output = tf.concat(axis=2, values=final_utterances_output) # [batch_size*max_response_num, max_utter_num, 2*rnn_size]
269 | final_utterances_max = tf.reduce_max(final_utterances_output, axis=1) # [batch_size*max_response_num, 2*rnn_size]
270 | final_utterances_state = tf.concat(axis=1, values=[final_utterances_state[0].h, final_utterances_state[1].h]) # [batch_size*max_response_num, 2*rnn_size]
271 | aggregated_utterances = tf.concat(axis=1, values=[final_utterances_max, final_utterances_state]) # [batch_size*max_response_num, 4*rnn_size]
272 |
273 | # aggregate response across response_len
274 | final_responses_max = tf.reduce_max(responses_output_cross_ur, axis=1) # [batch_size*max_response_num, 2*rnn_size]
275 | final_responses_state = tf.concat(axis=1, values=[r_ur_rnn_state[0].h, r_ur_rnn_state[1].h]) # [batch_size*max_response_num, 2*rnn_size]
276 | aggregated_responses_ur = tf.concat(axis=1, values=[final_responses_max, final_responses_state]) # [batch_size*max_response_num, 4*rnn_size]
277 | print("establish RNN aggregation on context and response")
278 |
279 |
280 | # aggregate persona across persona_len
281 | final_personas_max = tf.reduce_max(personas_output_cross_pr, axis=1) # [batch_size*max_response_num*max_persona_num, 2*rnn_size]
282 | final_personas_state = tf.concat(axis=1, values=[p_pr_rnn_state[0].h, p_pr_rnn_state[1].h]) # [batch_size*max_response_num*max_persona_num, 2*rnn_size]
283 | final_personas = tf.concat(axis=1, values=[final_personas_max, final_personas_state]) # [batch_size*max_response_num*max_persona_num, 4*rnn_size]
284 |
285 | # aggregate persona across persona_num
286 | # 1. RNN aggregation
287 | # final_personas = tf.reshape(final_personas, [-1, max_persona_num, output_dim*2]) # [batch_size*max_response_num, max_persona_num, 4*rnn_size]
288 | # tiled_personas_num = tf.reshape(tf.tile(tf.expand_dims(self.personas_num, 1), [1, max_response_num]), [-1, ]) # [batch_size*max_response_num, ]
289 | # final_personas_output, final_personas_state = lstm_layer(final_personas, tiled_personas_num, rnn_size, self.dropout_keep_prob, rnn_scope_aggre, scope_reuse=True)
290 | # final_personas_output = tf.concat(axis=2, values=final_personas_output) # [batch_size*max_response_num, max_persona_num, 2*rnn_size]
291 | # final_personas_max = tf.reduce_max(final_personas_output, axis=1) # [batch_size*max_response_num, 2*rnn_size]
292 | # final_personas_state = tf.concat(axis=1, values=[final_personas_state[0].h, final_personas_state[1].h]) # [batch_size*max_response_num, 2*rnn_size]
293 | # aggregated_personas = tf.concat(axis=1, values=[final_personas_max, final_personas_state]) # [batch_size*max_response_num, 4*rnn_size]
294 | # print("establish RNN aggregation on persona")
295 | # 2. ATT aggregation
296 | final_personas = tf.reshape(final_personas, [-1, max_persona_num, output_dim*2]) # [batch_size*max_response_num, max_persona_num, 4*rnn_size]
297 | pers_w = tf.get_variable("pers_w", [output_dim*2, 1], initializer=tf.contrib.layers.xavier_initializer())
298 | pers_b = tf.get_variable("pers_b", shape=[1, ], initializer=tf.zeros_initializer())
299 | pers_weights = tf.nn.relu(tf.einsum('aij,jk->aik', final_personas, pers_w) + pers_b) # [batch_size*max_response_num, max_persona_num, 1]
300 | tiled_personas_num = tf.reshape(tf.tile(tf.expand_dims(self.personas_num, 1), [1, max_response_num]), [-1, ]) # [batch_size*max_response_num, ]
301 | mask_p = tf.expand_dims(tf.sequence_mask(tiled_personas_num, max_persona_num, dtype=tf.float32), -1) # [batch_size*max_response_num, max_persona_num, 1]
302 | pers_weights = pers_weights * mask_p + -1e9 * (1-mask_p) # [batch_size*max_response_num, max_persona_num, 1]
303 | pers_weights = tf.nn.softmax(pers_weights, dim=1)
304 | aggregated_personas = tf.matmul(tf.transpose(pers_weights, [0, 2, 1]), final_personas) # [batch_size*max_response_num, 1, 4*rnn_size]
305 | aggregated_personas = tf.squeeze(aggregated_personas, [1]) # [batch_size*max_response_num, 4*rnn_size]
306 |
307 | # aggregate response across response_len
308 | final_responses_max = tf.reduce_max(responses_output_cross_pr, axis=1) # [batch_size*max_response_num, 2*rnn_size]
309 | final_responses_state = tf.concat(axis=1, values=[r_pr_rnn_state[0].h, r_pr_rnn_state[1].h]) # [batch_size*max_response_num, 2*rnn_size]
310 | aggregated_responses_pr = tf.concat(axis=1, values=[final_responses_max, final_responses_state]) # [batch_size*max_response_num, 4*rnn_size]
311 | print("establish ATT aggregation on persona and response")
312 |
313 | joined_feature = tf.concat(axis=1, values=[aggregated_utterances, aggregated_responses_ur, aggregated_personas, aggregated_responses_pr]) # [batch_size*max_response_num, 16*rnn_size(3200)]
314 | print("joined feature: {}".format(joined_feature.get_shape()))
315 |
316 |
317 | # =============================== Prediction layer ===============================
318 | with tf.variable_scope("prediction_layer") as vs:
319 | hidden_input_size = joined_feature.get_shape()[1].value
320 | hidden_output_size = 256
321 | regularizer = tf.contrib.layers.l2_regularizer(l2_reg_lambda)
322 | #regularizer = None
323 |                 # dropout on the MLP input and output
324 | joined_feature = tf.nn.dropout(joined_feature, keep_prob=self.dropout_keep_prob)
325 | full_out = tf.contrib.layers.fully_connected(joined_feature, hidden_output_size,
326 | activation_fn=tf.nn.relu,
327 | reuse=False,
328 | trainable=True,
329 | scope="projected_layer") # [batch_size*max_response_num, hidden_output_size(256)]
330 | full_out = tf.nn.dropout(full_out, keep_prob=self.dropout_keep_prob)
331 |
332 | last_weight_dim = full_out.get_shape()[1].value
333 | print("last_weight_dim: {}".format(last_weight_dim))
334 | bias = tf.Variable(tf.constant(0.1, shape=[1]), name="bias")
335 | s_w = tf.get_variable("s_w", shape=[last_weight_dim, 1], initializer=tf.contrib.layers.xavier_initializer())
336 | logits = tf.reshape(tf.matmul(full_out, s_w) + bias, [-1, max_response_num]) # [batch_size, max_response_num]
337 | print("logits: {}".format(logits.get_shape()))
338 |
339 | self.probs = tf.nn.softmax(logits, name="prob") # [batch_size, max_response_num]
340 |
341 | losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=self.target)
342 | self.mean_loss = tf.reduce_mean(losses, name="mean_loss") + l2_reg_lambda * l2_loss + sum(
343 | tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
344 |
345 | with tf.name_scope("accuracy"):
346 | correct_prediction = tf.equal(tf.argmax(self.probs, 1), self.target)
347 | self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"), name="accuracy")
348 |
--------------------------------------------------------------------------------
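
The "ATT aggregation" branch in model_DIM.py above pools a variable number of persona sentence vectors into one vector with a learned, mask-aware attention. The following is a minimal NumPy sketch of that pooling only, with random stand-ins for the learned `pers_w`/`pers_b` variables (it is an illustration, not the trained model):

```
import numpy as np

np.random.seed(0)

def masked_attention_pooling(personas, personas_num, w, b):
    # personas:     [batch, max_persona_num, dim] persona sentence vectors
    # personas_num: [batch] number of valid (non-padded) persona sentences
    batch, max_num, dim = personas.shape
    scores = np.maximum(np.matmul(personas, w) + b, 0.0)  # ReLU scores, [batch, max_num, 1]
    # give padded persona slots a large negative score before the softmax
    mask = (np.arange(max_num)[None, :] < personas_num[:, None]).astype(float)[:, :, None]
    scores = scores * mask - 1e9 * (1.0 - mask)
    exp = np.exp(scores - scores.max(axis=1, keepdims=True))  # numerically stable softmax
    weights = exp / exp.sum(axis=1, keepdims=True)            # normalized over persona slots
    return (weights * personas).sum(axis=1)                   # weighted sum -> [batch, dim]

personas = np.random.randn(2, 5, 8)        # toy batch: 2 examples, 5 slots, 8-d vectors
w, b = np.random.randn(8, 1), np.zeros(1)  # stand-ins for the learned pers_w / pers_b
pooled = masked_attention_pooling(personas, np.array([3, 5]), w, b)
print(pooled.shape)  # (2, 8)
```

The `-1e9` additive mask mirrors the TF code: padded slots receive a huge negative score, so the softmax assigns them effectively zero weight and only the valid persona sentences contribute to the pooled vector.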
/model/train.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import os
4 | import time
5 | import datetime
6 | import operator
7 | from collections import defaultdict
8 | from model import metrics
9 | from model import data_helpers
10 | from model.model_DIM import DIM
11 |
12 |
13 | # Files
14 | tf.flags.DEFINE_string("train_file", "", "path to train file")
15 | tf.flags.DEFINE_string("valid_file", "", "path to valid file")
16 | tf.flags.DEFINE_string("vocab_file", "", "vocabulary file")
17 | tf.flags.DEFINE_string("char_vocab_file", "", "path to char vocab file")
18 | tf.flags.DEFINE_string("embedded_vector_file", "", "pre-trained embedded word vector")
19 |
20 | # Model Hyperparameters
21 | tf.flags.DEFINE_integer("max_utter_num", 15, "max utterance number")
22 | tf.flags.DEFINE_integer("max_utter_len", 20, "max utterance length")
23 | tf.flags.DEFINE_integer("max_response_num", 20, "max response candidate number")
24 | tf.flags.DEFINE_integer("max_response_len", 20, "max response length")
25 | tf.flags.DEFINE_integer("max_persona_num", 5, "max persona number")
26 | tf.flags.DEFINE_integer("max_persona_len", 15, "max persona length")
27 | tf.flags.DEFINE_integer("max_word_length", 18, "max word length")
28 | tf.flags.DEFINE_integer("embedding_dim", 200, "dimensionality of word embedding")
29 | tf.flags.DEFINE_integer("rnn_size", 200, "number of RNN units")
30 |
31 | # Training parameters
32 | tf.flags.DEFINE_integer("batch_size", 128, "batch size (default: 128)")
33 | tf.flags.DEFINE_float("l2_reg_lambda", 0, "L2 regularization lambda (default: 0)")
34 | tf.flags.DEFINE_float("dropout_keep_prob", 1.0, "dropout keep probability (default: 1.0)")
35 | tf.flags.DEFINE_integer("num_epochs", 1000000, "number of training epochs (default: 1000000)")
36 | tf.flags.DEFINE_integer("evaluate_every", 1000, "evaluate model on valid dataset after this many steps (default: 1000)")
37 |
38 | # Misc Parameters
39 | tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow soft device placement")
40 | tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")
41 |
42 | FLAGS = tf.flags.FLAGS
43 | FLAGS._parse_flags()
44 | print("\nParameters:")
45 | for attr, value in sorted(FLAGS.__flags.items()):
46 | print("{}={}".format(attr.upper(), value))
47 | print("")
48 |
49 | # Load data
50 | print("Loading data...")
51 |
52 | vocab = data_helpers.load_vocab(FLAGS.vocab_file)
53 | print('vocabulary size: {}'.format(len(vocab)))
54 | charVocab = data_helpers.load_char_vocab(FLAGS.char_vocab_file)
55 | print('charVocab size: {}'.format(len(charVocab)))
56 |
57 | train_dataset = data_helpers.load_dataset(FLAGS.train_file, vocab, FLAGS.max_utter_num, FLAGS.max_utter_len, FLAGS.max_response_len, FLAGS.max_persona_len)
58 | print('train dataset size: {}'.format(len(train_dataset)))
59 | valid_dataset = data_helpers.load_dataset(FLAGS.valid_file, vocab, FLAGS.max_utter_num, FLAGS.max_utter_len, FLAGS.max_response_len, FLAGS.max_persona_len)
60 | print('valid dataset size: {}'.format(len(valid_dataset)))
61 |
62 |
63 | with tf.Graph().as_default():
64 | session_conf = tf.ConfigProto(
65 | allow_soft_placement=FLAGS.allow_soft_placement,
66 | log_device_placement=FLAGS.log_device_placement)
67 | sess = tf.Session(config=session_conf)
68 | with sess.as_default():
69 | dim = DIM(
70 | max_utter_num=FLAGS.max_utter_num,
71 | max_utter_len=FLAGS.max_utter_len,
72 | max_response_num=FLAGS.max_response_num,
73 | max_response_len=FLAGS.max_response_len,
74 | max_persona_num=FLAGS.max_persona_num,
75 | max_persona_len=FLAGS.max_persona_len,
76 | vocab_size=len(vocab),
77 | embedding_size=FLAGS.embedding_dim,
78 | vocab=vocab,
79 | rnn_size=FLAGS.rnn_size,
80 | maxWordLength=FLAGS.max_word_length,
81 | charVocab=charVocab,
82 | l2_reg_lambda=FLAGS.l2_reg_lambda)
83 | # Define Training procedure
84 | global_step = tf.Variable(0, name="global_step", trainable=False)
85 | starter_learning_rate = 0.001
86 | learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step,
87 | 5000, 0.96, staircase=True)
88 | optimizer = tf.train.AdamOptimizer(learning_rate)
89 | grads_and_vars = optimizer.compute_gradients(dim.mean_loss)
90 | train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)
91 |
92 | # Keep track of gradient values and sparsity (optional)
93 | """
94 | grad_summaries = []
95 | for g, v in grads_and_vars:
96 | if g is not None:
97 | grad_hist_summary = tf.histogram_summary("{}/grad/hist".format(v.name), g)
98 | sparsity_summary = tf.scalar_summary("{}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
99 | grad_summaries.append(grad_hist_summary)
100 | grad_summaries.append(sparsity_summary)
101 | grad_summaries_merged = tf.merge_summary(grad_summaries)
102 | """
103 |
104 | # Output directory for models and summaries
105 | timestamp = str(int(time.time()))
106 | out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
107 | print("Writing to {}\n".format(out_dir))
108 |
109 | # Summaries for loss and accuracy
110 | """
111 | loss_summary = tf.scalar_summary("loss", dim.mean_loss)
112 | acc_summary = tf.scalar_summary("accuracy", dim.accuracy)
113 |
114 | # Train Summaries
115 | train_summary_op = tf.merge_summary([loss_summary, acc_summary, grad_summaries_merged])
116 | train_summary_dir = os.path.join(out_dir, "summaries", "train")
117 | train_summary_writer = tf.train.SummaryWriter(train_summary_dir, sess.graph_def)
118 |
119 | # Dev summaries
120 | dev_summary_op = tf.merge_summary([loss_summary, acc_summary])
121 | dev_summary_dir = os.path.join(out_dir, "summaries", "dev")
122 | dev_summary_writer = tf.train.SummaryWriter(dev_summary_dir, sess.graph_def)
123 | """
124 |
125 |         # Checkpoint directory. TensorFlow assumes this directory already exists, so we need to create it
126 | checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
127 | checkpoint_prefix = os.path.join(checkpoint_dir, "model")
128 | if not os.path.exists(checkpoint_dir):
129 | os.makedirs(checkpoint_dir)
130 | saver = tf.train.Saver(tf.global_variables())
131 |
132 | # Initialize all variables
133 | sess.run(tf.global_variables_initializer())
134 |
135 | def train_step(x_utterances, x_utterances_len, x_response, x_response_len,
136 | x_utters_num, x_target, x_ids,
137 | x_u_char, x_u_char_len, x_r_char, x_r_char_len,
138 | x_personas, x_personas_len, x_p_char, x_p_char_len, x_personas_num):
139 | """
140 | A single training step
141 | """
142 | feed_dict = {
143 | dim.utterances: x_utterances,
144 | dim.utterances_len: x_utterances_len,
145 | dim.responses: x_response,
146 | dim.responses_len: x_response_len,
147 | dim.utters_num: x_utters_num,
148 | dim.target: x_target,
149 | dim.dropout_keep_prob: FLAGS.dropout_keep_prob,
150 | dim.u_charVec: x_u_char,
151 | dim.u_charLen: x_u_char_len,
152 | dim.r_charVec: x_r_char,
153 | dim.r_charLen: x_r_char_len,
154 | dim.personas: x_personas,
155 | dim.personas_len: x_personas_len,
156 | dim.p_charVec: x_p_char,
157 | dim.p_charLen: x_p_char_len,
158 | dim.personas_num: x_personas_num
159 | }
160 |
161 | _, step, loss, accuracy, predicted_prob = sess.run(
162 | [train_op, global_step, dim.mean_loss, dim.accuracy, dim.probs],
163 | feed_dict)
164 |
165 | if step % 100 == 0:
166 | time_str = datetime.datetime.now().isoformat()
167 | print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))
168 | #train_summary_writer.add_summary(summaries, step)
169 |
170 |
171 | def dev_step():
172 | results = defaultdict(list)
173 | num_test = 0
174 | num_correct = 0.0
175 | valid_batches = data_helpers.batch_iter(valid_dataset, FLAGS.batch_size, 1, FLAGS.max_utter_num, FLAGS.max_utter_len, \
176 | FLAGS.max_response_num, FLAGS.max_response_len, FLAGS.max_persona_num, FLAGS.max_persona_len, \
177 | charVocab, FLAGS.max_word_length, shuffle=True)
178 | for valid_batch in valid_batches:
179 | x_utterances, x_utterances_len, x_response, x_response_len, \
180 | x_utters_num, x_target, x_ids, \
181 | x_u_char, x_u_char_len, x_r_char, x_r_char_len, \
182 | x_personas, x_personas_len, x_p_char, x_p_char_len, x_personas_num = valid_batch
183 | feed_dict = {
184 | dim.utterances: x_utterances,
185 | dim.utterances_len: x_utterances_len,
186 | dim.responses: x_response,
187 | dim.responses_len: x_response_len,
188 | dim.utters_num: x_utters_num,
189 | dim.target: x_target,
190 | dim.dropout_keep_prob: 1.0,
191 | dim.u_charVec: x_u_char,
192 | dim.u_charLen: x_u_char_len,
193 | dim.r_charVec: x_r_char,
194 | dim.r_charLen: x_r_char_len,
195 | dim.personas: x_personas,
196 | dim.personas_len: x_personas_len,
197 | dim.p_charVec: x_p_char,
198 | dim.p_charLen: x_p_char_len,
199 | dim.personas_num: x_personas_num
200 | }
201 | batch_accuracy, predicted_prob = sess.run([dim.accuracy, dim.probs], feed_dict)
202 |
203 | num_test += len(predicted_prob)
204 | if num_test % 1000 == 0:
205 | print(num_test)
206 | num_correct += len(predicted_prob) * batch_accuracy
207 |
208 | # predicted_prob = [batch_size, max_response_num]
209 | for i in range(len(predicted_prob)):
210 | probs = predicted_prob[i]
211 | us_id = x_ids[i]
212 | label = x_target[i]
213 | labels = np.zeros(FLAGS.max_response_num)
214 | labels[label] = 1
215 | for r_id, prob in enumerate(probs):
216 | results[us_id].append((str(r_id), labels[r_id], prob))
217 |
218 |         # calculate classification and ranking metrics
219 | print('num_test_samples: {} test_accuracy: {}'.format(num_test, num_correct/num_test))
220 | accu, precision, recall, f1, loss = metrics.classification_metrics(results)
221 | print('Accuracy: {}, Precision: {} Recall: {} F1: {} Loss: {}'.format(accu, precision, recall, f1, loss))
222 |
223 | mvp = metrics.mean_average_precision(results)
224 | mrr = metrics.mean_reciprocal_rank(results)
225 | top_1_precision = metrics.top_1_precision(results)
226 | total_valid_query = metrics.get_num_valid_query(results)
227 |             print('MAP (mean average precision): {}\tMRR (mean reciprocal rank): {}\tTop-1 precision: {}\tNum_query: {}'.format(mvp, mrr, top_1_precision, total_valid_query))
228 |
229 | return mrr
230 |
231 | best_mrr = 0.0
232 | batches = data_helpers.batch_iter(train_dataset, FLAGS.batch_size, FLAGS.num_epochs, FLAGS.max_utter_num, FLAGS.max_utter_len, \
233 | FLAGS.max_response_num, FLAGS.max_response_len, FLAGS.max_persona_num, FLAGS.max_persona_len, \
234 | charVocab, FLAGS.max_word_length, shuffle=True)
235 | for batch in batches:
236 | x_utterances, x_utterances_len, x_response, x_response_len, \
237 | x_utters_num, x_target, x_ids, \
238 | x_u_char, x_u_char_len, x_r_char, x_r_char_len, \
239 | x_personas, x_personas_len, x_p_char, x_p_char_len, x_personas_num = batch
240 | train_step(x_utterances, x_utterances_len, x_response, x_response_len, x_utters_num, x_target, x_ids, x_u_char, x_u_char_len, x_r_char, x_r_char_len, x_personas, x_personas_len, x_p_char, x_p_char_len, x_personas_num)
241 | current_step = tf.train.global_step(sess, global_step)
242 | if current_step % FLAGS.evaluate_every == 0:
243 | print("\nEvaluation:")
244 | valid_mrr = dev_step()
245 | if valid_mrr > best_mrr:
246 | best_mrr = valid_mrr
247 | path = saver.save(sess, checkpoint_prefix, global_step=current_step)
248 | print("Saved model checkpoint to {}\n".format(path))
249 |
--------------------------------------------------------------------------------
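
train.py drives Adam with a staircase exponential learning-rate decay (starter rate 0.001, multiplied by 0.96 every 5000 steps). With `staircase=True` the schedule reduces to a simple closed form; a plain-Python sketch of the first few decay boundaries:

```
# Closed form of the staircase tf.train.exponential_decay call in train.py:
# lr(step) = 0.001 * 0.96 ** (step // 5000)
def staircase_lr(step, base=0.001, decay_steps=5000, decay_rate=0.96):
    return base * decay_rate ** (step // decay_steps)

for step in (0, 4999, 5000, 10000, 50000):
    print("step {}: lr = {:g}".format(step, staircase_lr(step)))
# step 0 and 4999: lr = 0.001; step 5000: 0.00096; step 10000: 0.0009216; step 50000: ~0.000664833
```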
/scripts/compute_recall.py:
--------------------------------------------------------------------------------
1 |
2 | test_out_filename = "persona_test_out.txt"
3 |
4 | with open(test_out_filename, 'r') as f:
5 | cur_q_id = None
6 | num_query = 0
7 | recall = {"recall@1": 0,
8 | "recall@2": 0,
9 | "recall@5": 0,
10 | "recall@10": 0}
11 |
12 | lines = f.readlines()
13 |     for line in lines[1:]:   # skip the first (header) line
14 | line = line.strip().split('\t')
15 | line = [float(ele) for ele in line]
16 |
17 | if cur_q_id is None:
18 | cur_q_id = line[0]
19 | num_query += 1
20 | elif line[0] != cur_q_id:
21 | cur_q_id = line[0]
22 | num_query += 1
23 |
24 |         if line[4] == 1.0:   # column 4: ground-truth label of this candidate
25 |             rank = line[3]   # column 3: rank the model assigned to it
26 |
27 | if rank <= 1:
28 | recall["recall@1"] += 1
29 | if rank <= 2:
30 | recall["recall@2"] += 1
31 | if rank <= 5:
32 | recall["recall@5"] += 1
33 | if rank <= 10:
34 | recall["recall@10"] += 1
35 |
36 | recall["recall@1"] = recall["recall@1"] / float(num_query)
37 | recall["recall@2"] = recall["recall@2"] / float(num_query)
38 | recall["recall@5"] = recall["recall@5"] / float(num_query)
39 | recall["recall@10"] = recall["recall@10"] / float(num_query)
40 | print("num_query = {}".format(num_query))
41 | print("recall@1 = {}".format(recall["recall@1"]))
42 | print("recall@2 = {}".format(recall["recall@2"]))
43 | print("recall@5 = {}".format(recall["recall@5"]))
44 | print("recall@10 = {}".format(recall["recall@10"]))
45 |
--------------------------------------------------------------------------------
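
compute_recall.py skips a header line and then reads one tab-separated row of floats per candidate; the indices it uses imply that column 0 carries the query id, column 3 the rank the model assigned, and column 4 the ground-truth label. A hypothetical two-query file consistent with that parsing (the response_id and score columns are illustrative guesses; only columns 0, 3, and 4 are read):

```
query_id	response_id	score	rank	label
0	0	0.91	1	1.0
0	1	0.40	2	0.0
1	1	0.85	1	0.0
1	2	0.07	2	1.0
```

Here the true response of query 0 sits at rank 1 and that of query 1 at rank 2, so under this reading num_query = 2, recall@1 = 1/2 = 0.5, and recall@2 (and above) = 1.0. Note the script parses every field as a float, so the header must occupy exactly the first line.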
/scripts/test.sh:
--------------------------------------------------------------------------------
1 | cur_dir=$(pwd)
2 | parentdir="$(dirname $cur_dir)"
3 |
4 | DATA_DIR=${parentdir}/data/personachat_processed
5 |
6 | latest_run=$(ls -dt runs/* | head -n 1)
7 | latest_checkpoint=${latest_run}/checkpoints
8 | # latest_checkpoint=runs/1556416288/checkpoints
9 | echo $latest_checkpoint
10 |
11 | test_file=$DATA_DIR/processed_test_self_original.txt # for self_original
12 | # test_file=$DATA_DIR/processed_test_self_revised.txt # for self_revised
13 | # test_file=$DATA_DIR/processed_test_other_original.txt # for other_original
14 | # test_file=$DATA_DIR/processed_test_other_revised.txt # for other_revised
15 | vocab_file=$DATA_DIR/vocab.txt
16 | char_vocab_file=$DATA_DIR/char_vocab.txt
17 | output_file=./persona_test_out.txt
18 |
19 | max_utter_num=15
20 | max_utter_len=20
21 | max_response_num=20
22 | max_response_len=20
23 | max_persona_num=5
24 | max_persona_len=15
25 | max_word_length=18
26 | batch_size=32
27 |
28 | PKG_DIR=${parentdir}
29 |
30 | PYTHONPATH=${PKG_DIR}:$PYTHONPATH CUDA_VISIBLE_DEVICES=3 python -u ${PKG_DIR}/model/eval.py \
31 | --test_file $test_file \
32 | --vocab_file $vocab_file \
33 | --char_vocab_file $char_vocab_file \
34 | --output_file $output_file \
35 | --max_utter_num $max_utter_num \
36 | --max_utter_len $max_utter_len \
37 | --max_response_num $max_response_num \
38 | --max_response_len $max_response_len \
39 | --max_persona_num $max_persona_num \
40 | --max_persona_len $max_persona_len \
41 | --max_word_length $max_word_length \
42 | --batch_size $batch_size \
43 | --checkpoint_dir $latest_checkpoint > log_DIM_test.txt 2>&1 &
44 |
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
1 | cur_dir=$(pwd)
2 | parentdir="$(dirname $cur_dir)"
3 |
4 | DATA_DIR=${parentdir}/data/personachat_processed
5 |
6 | # for self_original
7 | train_file=$DATA_DIR/processed_train_self_original.txt
8 | valid_file=$DATA_DIR/processed_valid_self_original.txt
9 | # for self_revised
10 | # train_file=$DATA_DIR/processed_train_self_revised.txt
11 | # valid_file=$DATA_DIR/processed_valid_self_revised.txt
12 | # for other_original
13 | # train_file=$DATA_DIR/processed_train_other_original.txt
14 | # valid_file=$DATA_DIR/processed_valid_other_original.txt
15 | # for other_revised
16 | # train_file=$DATA_DIR/processed_train_other_revised.txt
17 | # valid_file=$DATA_DIR/processed_valid_other_revised.txt
18 |
19 | vocab_file=$DATA_DIR/vocab.txt
20 | char_vocab_file=$DATA_DIR/char_vocab.txt
21 | embedded_vector_file=$DATA_DIR/glove_42B_300d_vec_plus_word2vec_100.txt
22 |
23 | max_utter_num=15
24 | max_utter_len=20
25 | max_response_num=20
26 | max_response_len=20
27 | max_persona_num=5
28 | max_persona_len=15
29 | max_word_length=18
30 | embedding_dim=400
31 | rnn_size=200
32 |
33 | batch_size=16
34 | lambda=0
35 | dropout_keep_prob=0.8
36 | num_epochs=10
37 | evaluate_every=500
38 |
39 | PKG_DIR=${parentdir}
40 |
41 | PYTHONPATH=${PKG_DIR}:$PYTHONPATH CUDA_VISIBLE_DEVICES=3 python -u ${PKG_DIR}/model/train.py \
42 | --train_file $train_file \
43 | --valid_file $valid_file \
44 | --vocab_file $vocab_file \
45 | --char_vocab_file $char_vocab_file \
46 | --embedded_vector_file $embedded_vector_file \
47 | --max_utter_num $max_utter_num \
48 | --max_utter_len $max_utter_len \
49 | --max_response_num $max_response_num \
50 | --max_response_len $max_response_len \
51 | --max_persona_num $max_persona_num \
52 | --max_persona_len $max_persona_len \
53 | --max_word_length $max_word_length \
54 | --embedding_dim $embedding_dim \
55 | --rnn_size $rnn_size \
56 | --batch_size $batch_size \
57 | --l2_reg_lambda $lambda \
58 | --dropout_keep_prob $dropout_keep_prob \
59 | --num_epochs $num_epochs \
60 | --evaluate_every $evaluate_every > log_DIM_train.txt 2>&1 &
61 |
--------------------------------------------------------------------------------
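
train.sh sets embedding_dim=400, matching the embedding file name (300-d GloVe 42B vectors concatenated with 100-d word2vec vectors). Assuming the common GloVe-style text format of one token followed by its values per line (an assumption about the file, not something the scripts guarantee), a quick sanity check that the file and the flag agree:

```
# Hypothetical sanity check: verify the pre-trained embedding file holds
# 400-dimensional vectors, matching embedding_dim=400 in train.sh.
# Assumes the GloVe-style line format "token v1 v2 ... vN".
path = "data/personachat_processed/glove_42B_300d_vec_plus_word2vec_100.txt"
with open(path) as f:
    first_row = f.readline().rstrip().split()
dim = len(first_row) - 1  # first field is the token itself
assert dim == 400, "expected 400-d vectors, got {}".format(dim)
print("embedding dim: {}".format(dim))
```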