├── README.md
├── data
    └── data_preprocess.py
├── image
    ├── model.png
    └── result.png
├── model
    ├── __init__.py
    ├── data_helpers.py
    ├── eval.py
    ├── metrics.py
    ├── model_DIM.py
    └── train.py
└── scripts
    ├── compute_recall.py
    ├── test.sh
    └── train.sh


/README.md:
--------------------------------------------------------------------------------
# Dually Interactive Matching Network for Personalized Response Selection in Retrieval-Based Chatbots
This repository contains the source code and dataset for the EMNLP 2019 paper [Dually Interactive Matching Network for Personalized Response Selection in Retrieval-Based Chatbots](https://www.aclweb.org/anthology/D19-1193.pdf) by Gu et al.

Our proposed Dually Interactive Matching Network (DIM) achieved new state-of-the-art response selection performance on the PERSONA-CHAT dataset at the time of publication.

## Model overview


## Results


## Dependencies
Python 2.7
TensorFlow 1.4.0

## Dataset
You can download the PERSONA-CHAT dataset [here](https://drive.google.com/open?id=1gNyVL5pSMO6DnTIlA9ORNIrd2zm8f3QH) or from [ParlAI](https://parl.ai/), and unzip it into the ```data``` folder. The preprocessing script reads the raw files from ```data/personachat/```.
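Each raw file holds one numbered line per dialogue line, and the numbering restarts at ```1``` for every new dialogue. Persona lines (```your persona: ...``` or ```partner's persona: ...```) have a single field; dialogue-turn lines have four tab-separated fields, of which ```data_preprocess.py``` uses the first (turn number plus context utterance), the second (ground-truth response) and the last (```|```-separated candidates). The following is a minimal Python 3 sketch, not part of this repository, of what the script does with one turn. Assumptions: ```turn_to_example``` is a hypothetical helper, whitespace splitting stands in for NLTK's ```WordPunctTokenizer```, and the sample lines are made up.

```python
import random


def turn_to_example(line, history, rng):
    """Turn one raw dialogue-turn line into (context, candidates, label).

    Simplified mirror of the per-turn logic in data_process_none:
    str.split() stands in for NLTK's WordPunctTokenizer.
    """
    fields = line.strip().split("\t")
    context = " ".join(fields[0].split()[1:])  # drop the leading turn number
    response = fields[1]
    candidates = fields[-1].split("|")
    rng.shuffle(candidates)
    label = candidates.index(response)  # index of the true response after shuffling
    history.append(context)
    example = (" _eos_ ".join(history) + " _eos_", "|".join(candidates), str(label))
    history.append(response)  # the response becomes part of the next turn's context
    return example


# Hypothetical raw line: turn number + context, response, an empty field, "|"-separated candidates.
raw_turn = "1 hi , how are you ?\ti am fine .\t\tgood morning|i am fine .|i like dogs"

history = []
context, candidates, label = turn_to_example(raw_turn, history, random.Random(0))
print(context)                            # hi , how are you ? _eos_
print(candidates.split("|")[int(label)])  # i am fine .
```

Because ```history``` is shared across calls, the context of each later turn in a dialogue contains every earlier utterance and response, each terminated by ```_eos_```.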
Run the following commands to preprocess the raw files; the processed files are written to ```data/personachat_processed/```.
```
cd data
python data_preprocess.py
```
Then download the embedding and vocabulary files [here](https://drive.google.com/open?id=1gGZfQ-m7EGo5Z1Ts93Ta8GPJpdIQqckC) and unzip them into the ```data/personachat_processed/``` folder.

## Train a new model
```
cd scripts
bash train.sh
```
The training log is written to ```log_DIM_train.txt```.

## Test a trained model
```
bash test.sh
```
The test log is written to ```log_DIM_test.txt```. Testing also produces ```persona_test_out.txt```, which records the score of each context-response pair. Run the following command to compute Recall from these scores:
```
python compute_recall.py
```

## Cite
If you use the code, please cite the following paper:
**"Dually Interactive Matching Network for Personalized Response Selection in Retrieval-Based Chatbots"**
Jia-Chen Gu, Zhen-Hua Ling, Xiaodan Zhu, Quan Liu.
_EMNLP (2019)_

```
@inproceedings{gu-etal-2019-dually,
    title = "Dually Interactive Matching Network for Personalized Response Selection in Retrieval-Based Chatbots",
    author = "Gu, Jia-Chen and
      Ling, Zhen-Hua and
      Zhu, Xiaodan and
      Liu, Quan",
    booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)",
    month = nov,
    year = "2019",
    address = "Hong Kong, China",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/D19-1193",
    pages = "1845--1854",
}
```

--------------------------------------------------------------------------------
/data/data_preprocess.py:
--------------------------------------------------------------------------------
import os
import random
import numpy as np
from nltk.tokenize import WordPunctTokenizer


def tokenize(text):
    return WordPunctTokenizer().tokenize(text)


def data_process_none(input_path, output_path, fname):

    dialogues = []
    dialogue = []
    with open(os.path.join(input_path, fname), "r") as f:
        for line in f:
            line = line.decode('utf-8').strip()
            if line.split()[0] == "1":  # new dialogue
                dialogues.append(dialogue)
                dialogue = []
            dialogue.append(line)

    dialogues.append(dialogue)
    dialogues.remove([])
    print("{} is composed of {} dialogues".format(fname, len(dialogues)))

    context_candidates = []
    for dialogue in dialogues:
        context_history = []
        for turn in dialogue:
            fields = turn.split("\t")
            context = " ".join(tokenize(fields[0])[1:])
            response = fields[1]
            candidates = fields[-1].split("|")
            random.shuffle(candidates)
            label = candidates.index(response)

            context_history.append(context)
            # (context, candidates, label, partner's persona, your persona)
            context_candidates.append([" _eos_ ".join(context_history) + " _eos_",
                                       "|".join(candidates),
                                       str(label),
                                       "NA",
                                       "NA"])
            context_history.append(response)

    print("{} is composed of {} context-candidates".format(fname, len(context_candidates)))

    with open(os.path.join(output_path, "processed_{}".format(fname)), "w") as f:
        print("Saving dataset to processed_{} ...".format(fname))
        for dialogue in context_candidates:
            f.write(("\t".join(dialogue) + "\n").encode('utf-8'))


def data_process_self(input_path, output_path, fname):

    dialogues = []
    dialogue = []
    with open(os.path.join(input_path, fname), "r") as f:
        for line in f:
            line = line.decode('utf-8').strip()
            if line.split()[0] == "1":  # new dialogue
                dialogues.append(dialogue)
                dialogue = []
            dialogue.append(line)

    dialogues.append(dialogue)
    dialogues.remove([])
    print("{} is composed of {} dialogues".format(fname, len(dialogues)))

    context_candidates = []
    for dialogue in dialogues:
        persona = []
        context_history = []
        for line in dialogue:
            fields = line.strip().split("\t")

            if len(fields) == 1:
                persona.append(" ".join(tokenize(fields[0])[4:]))
            if len(fields) == 4:
                context = " ".join(tokenize(fields[0])[1:])
                response = fields[1]
                candidates = fields[-1].split("|")
                random.shuffle(candidates)
                label = candidates.index(response)

                context_history.append(context)
                # (context, candidates, label, partner's persona, your persona)
                context_candidates.append([" _eos_ ".join(context_history) + " _eos_",
                                           "|".join(candidates),
                                           str(label),
                                           "NA",
                                           "|".join(persona)])
                context_history.append(response)
    print("{} is composed of {} context-candidates".format(fname, len(context_candidates)))

    with open(os.path.join(output_path, "processed_{}".format(fname)), "w") as f:
        print("Saving dataset to processed_{} ...".format(fname))
        for dialogue in context_candidates:
            f.write(("\t".join(dialogue) + "\n").encode('utf-8'))


def data_process_other(input_path, output_path, fname):

    dialogues = []
    dialogue = []
    with open(os.path.join(input_path, fname), "r") as f:
        for line in f:
            line = line.decode('utf-8').strip()
            if line.split()[0] == "1":  # new dialogue
                dialogues.append(dialogue)
                dialogue = []
            dialogue.append(line)

    dialogues.append(dialogue)
    dialogues.remove([])
    print("{} is composed of {} dialogues".format(fname, len(dialogues)))

    context_candidates = []
    for dialogue in dialogues:
        persona = []
        context_history = []
        for line in dialogue:
            fields = line.strip().split("\t")

            if len(fields) == 1:
                persona.append(" ".join(tokenize(fields[0])[6:]))
            if len(fields) == 4:
                context = " ".join(tokenize(fields[0])[1:])
                response = fields[1]
                candidates = fields[-1].split("|")
                random.shuffle(candidates)
                label = candidates.index(response)

                context_history.append(context)
                # (context, candidates, label, partner's persona, your persona)
                context_candidates.append([" _eos_ ".join(context_history) + " _eos_",
                                           "|".join(candidates),
                                           str(label),
                                           "|".join(persona),
                                           "NA"])
                context_history.append(response)
    print("{} is composed of {} context-candidates".format(fname, len(context_candidates)))

    with open(os.path.join(output_path, "processed_{}".format(fname)), "w") as f:
        print("Saving dataset to processed_{} ...".format(fname))
        for dialogue in context_candidates:
            f.write(("\t".join(dialogue) + "\n").encode('utf-8'))


def data_process_both(input_path, output_path, fname):

    dialogues = []
    dialogue = []
    with open(os.path.join(input_path, fname), "r") as f:
        for line in f:
            line = line.decode('utf-8').strip()
            if line.split()[0] == "1":  # new dialogue
                dialogues.append(dialogue)
                dialogue = []
            dialogue.append(line)

    dialogues.append(dialogue)
    dialogues.remove([])
    print("{} is composed of {} dialogues".format(fname, len(dialogues)))

    context_candidates = []
    for dialogue in dialogues:
        self_persona = []
        other_persona = []
        context_history = []
        for line in dialogue:
            fields = line.strip().split("\t")

            if len(fields) == 1:
                if fields[0].split()[1] == "your":
                    self_persona.append(" ".join(tokenize(fields[0])[4:]))
                if fields[0].split()[1] == "partner's":
                    other_persona.append(" ".join(tokenize(fields[0])[6:]))
            if len(fields) == 4:
                context = " ".join(tokenize(fields[0])[1:])
                response = fields[1]
                candidates = fields[-1].split("|")
                random.shuffle(candidates)
                label = candidates.index(response)

                context_history.append(context)
                # (context, candidates, label, partner's persona, your persona)
                context_candidates.append([" _eos_ ".join(context_history) + " _eos_",
                                           "|".join(candidates),
                                           str(label),
                                           "|".join(other_persona),
                                           "|".join(self_persona)])
                context_history.append(response)
    print("{} is composed of {} context-candidates".format(fname, len(context_candidates)))

    with open(os.path.join(output_path, "processed_{}".format(fname)), "w") as f:
        print("Saving dataset to processed_{} ...".format(fname))
        for dialogue in context_candidates:
            f.write(("\t".join(dialogue) + "\n").encode('utf-8'))

if __name__ == '__main__':

    input_path = "./personachat"
    output_path = "./personachat_processed"

    if not os.path.exists(output_path):
        os.makedirs(output_path)

    files = [file for file in os.listdir(input_path)]
    files_none = [file for file in files if file.split("_")[1] == "none"]
    files_self = [file for file in files if file.split("_")[1] == "self"]
    files_other = [file for file in files if file.split("_")[1] == "other"]
    files_both = [file for file in files if file.split("_")[1] == "both"]

    print("There are {} files to process.\nStart processing data ...".format(len(files)))

    for file in files_none:
        print("Preprocessing {} ...".format(file))
        data_process_none(input_path, output_path, file)
        print("="*60)

    for file in files_self:
        print("Preprocessing {} ...".format(file))
        data_process_self(input_path, output_path, file)
        print("="*60)

    for file in files_other:
        print("Preprocessing {} ...".format(file))
        data_process_other(input_path, output_path, file)
        print("="*60)

    for file in files_both:
        print("Preprocessing {} ...".format(file))
        data_process_both(input_path, output_path, file)
        print("="*60)

    print("data preprocess done!")

--------------------------------------------------------------------------------
/image/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JasonForJoy/DIM/6faa4c61f57c28cfbd9a4ade52c29490a6e40d06/image/model.png
--------------------------------------------------------------------------------
/image/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JasonForJoy/DIM/6faa4c61f57c28cfbd9a4ade52c29490a6e40d06/image/result.png
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------



--------------------------------------------------------------------------------
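The processed files written by `data/data_preprocess.py` hold one example per line with five tab-separated fields: the context (utterances joined with ` _eos_ `), the `|`-separated candidates, the label index, the partner's persona and your persona, where `NA` marks an unused persona column. The next file, `model/data_helpers.py`, reads these lines back. The following is a minimal Python 3 sketch, not part of this repository, of how `load_dataset`, `to_vec` and `normalize_vec` turn one such line into zero-padded id arrays. Assumptions: `parse_processed_line` is a hypothetical helper, a single `max_len` replaces the separate utterance, response and persona limits, and the toy vocabulary and sample line are made up. The fallback of unknown tokens to the id of `"fiance"` is taken from `to_vec`.

```python
import numpy as np


def parse_processed_line(line, vocab, max_len, oov_token="fiance"):
    """Split one processed line into zero-padded id arrays (simplified sketch)."""
    context, candidates, label, partner_persona, own_persona = line.rstrip("\r\n").split("\t")
    # Re-split the context into utterances, keeping each utterance's trailing _eos_ token.
    utterances = [u + " _eos_" for u in (context + " ").split(" _eos_ ")[:-1]]
    # Use whichever persona column is filled; "NA" marks an unused column.
    persona_field = partner_persona if partner_persona != "NA" else own_persona
    personas = [] if persona_field == "NA" else persona_field.split("|")

    def to_padded(text):
        tokens = text.split(" ")[:max_len]  # keep the first max_len tokens
        ids = [vocab.get(t, vocab[oov_token]) for t in tokens]  # unknown tokens -> oov_token id
        padded = np.zeros(max_len, dtype=np.int32)  # zero-pad to a fixed length
        padded[:len(ids)] = ids
        return padded, len(ids)

    return ([to_padded(u) for u in utterances],
            [to_padded(c) for c in candidates.split("|")],
            int(label),
            [to_padded(p) for p in personas])


vocab = {"fiance": 1, "hi": 2, "how": 3, "are": 4, "you": 5, "?": 6,
         "_eos_": 7, "i": 8, "like": 9, "dogs": 10}
# Hypothetical processed line from a "self" file: partner persona is "NA", own persona is filled.
line = "hi how are you ? _eos_\ti like dogs|i like cats\t0\tNA\ti like dogs"
utters, cands, label, personas = parse_processed_line(line, vocab, max_len=8)
print(utters[0][0])  # [2 3 4 5 6 7 0 0]
print(cands[1][0])   # [8 9 1 0 0 0 0 0]  ("cats" is out of vocabulary)
```

Keeping the true length next to each padded array mirrors what `batch_iter` feeds into the `*_len` placeholders, so the model can mask the zero padding.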
/model/data_helpers.py:
--------------------------------------------------------------------------------
import numpy as np
import random


def load_vocab(fname):
    '''
    vocab = {"I": 0, ...}
    '''
    vocab = {}
    with open(fname, 'rt') as f:
        for i, line in enumerate(f):
            word = line.decode('utf-8').strip()
            vocab[word] = i
    return vocab

def load_char_vocab(fname):
    '''
    charVocab = {"U": 0, "!": 1, ...}
    '''
    charVocab = {}
    with open(fname, 'rt') as f:
        for line in f:
            fields = line.strip().split('\t')
            char_id = int(fields[0])
            ch = fields[1]
            charVocab[ch] = char_id
    return charVocab

def to_vec(tokens, vocab, maxlen):
    '''
    length: length of the input sequence
    vec: map the token to the vocab_id, return a varied-length array [3, 6, 4, 3, ...]
    maxlen is not used here; callers truncate the tokens before calling.
    '''
    n = len(tokens)
    length = 0
    vec = []
    for i in range(n):
        length += 1
        if tokens[i] in vocab:
            vec.append(vocab[tokens[i]])
        else:
            vec.append(vocab["fiance"])  # out-of-vocabulary tokens are mapped to the id of "fiance"
    return length, np.array(vec)

def load_dataset(fname, vocab, max_utter_num, max_utter_len, max_response_len, max_persona_len):

    dataset = []
    with open(fname, 'rt') as f:
        for us_id, line in enumerate(f):
            line = line.decode('utf-8').strip()
            fields = line.split('\t')

            # context utterances
            context = fields[0]
            utterances = (context + " ").split(' _eos_ ')[:-1]
            utterances = [utterance + " _eos_" for utterance in utterances]
            utterances = utterances[-max_utter_num:]  # select the last max_utter_num utterances
            us_tokens = []
            us_vec = []
            us_len = []
            for utterance in utterances:
                u_tokens = utterance.split(' ')[:max_utter_len]  # select the head max_utter_len tokens in every utterance
                u_len, u_vec = to_vec(u_tokens, vocab, max_utter_len)
                us_tokens.append(u_tokens)
                us_vec.append(u_vec)
                us_len.append(u_len)
            us_num = len(utterances)

            # responses
            responses = fields[1].split("|")
            rs_tokens = []
            rs_vec = []
            rs_len = []
            for response in responses:
                r_tokens = response.split(' ')[:max_response_len]  # select the head max_response_len tokens in every candidate
                r_len, r_vec = to_vec(r_tokens, vocab, max_response_len)
                rs_tokens.append(r_tokens)
                rs_vec.append(r_vec)
                rs_len.append(r_len)

            # label
            label = int(fields[2])

            # personas default to empty, so rows that do not have exactly one persona column
            # filled neither raise a NameError nor silently reuse the previous row's personas
            ps_tokens, ps_vec, ps_len, ps_num = [], [], [], 0

            # other persona
            if fields[3] != "NA" and fields[4] == "NA":
                personas = fields[3].split("|")
                ps_tokens = []
                ps_vec = []
                ps_len = []
                for persona in personas:
                    p_tokens = persona.split(' ')[:max_persona_len]  # select the head max_persona_len tokens in every persona
                    p_len, p_vec = to_vec(p_tokens, vocab, max_persona_len)
                    ps_tokens.append(p_tokens)
                    ps_vec.append(p_vec)
                    ps_len.append(p_len)
                ps_num = len(personas)

            # self persona
            if fields[3] == "NA" and fields[4] != "NA":
                personas = fields[4].split("|")
                ps_tokens = []
                ps_vec = []
                ps_len = []
                for persona in personas:
                    p_tokens = persona.split(' ')[:max_persona_len]  # select the head max_persona_len tokens in every persona
                    p_len, p_vec = to_vec(p_tokens, vocab, max_persona_len)
                    ps_tokens.append(p_tokens)
                    ps_vec.append(p_vec)
                    ps_len.append(p_len)
                ps_num = len(personas)

            dataset.append((us_id, us_tokens, us_vec, us_len, us_num, rs_tokens, rs_vec, rs_len, label, ps_tokens, ps_vec, ps_len, ps_num))

    return dataset


def normalize_vec(vec, maxlen):
    '''
    pad the original vec to the same maxlen
    [3, 4, 7] maxlen=5 --> [3, 4, 7, 0, 0]
    '''
    if len(vec) == maxlen:
        return vec

    new_vec = np.zeros(maxlen, dtype='int32')
    for i in range(len(vec)):
        new_vec[i] = vec[i]
    return new_vec


def charVec(tokens, charVocab, maxlen, maxWordLength):
    '''
    chars = np.array( (maxlen, maxWordLength) ) 0 if not found in charVocab or None
    word_lengths = np.array( maxlen ) 1 if None
    '''
    n = len(tokens)
    if n > maxlen:
        n = maxlen

    chars = np.zeros((maxlen, maxWordLength), dtype=np.int32)
    word_lengths = np.ones(maxlen, dtype=np.int32)
    for i in range(n):
        token = tokens[i][:maxWordLength]
        word_lengths[i] = len(token)
        row = chars[i]
        for idx, ch in enumerate(token):
            if ch in charVocab:
                row[idx] = charVocab[ch]

    return chars, word_lengths


def batch_iter(data, batch_size, num_epochs, max_utter_num, max_utter_len, max_response_num, max_response_len,
               max_persona_num, max_persona_len, charVocab, max_word_length, shuffle=True):
    """
    Generates a batch iterator for a dataset.
    """
    data_size = len(data)
    # ceiling division, so no empty batch is produced when data_size is a multiple of batch_size
    num_batches_per_epoch = (data_size + batch_size - 1) // batch_size
    for epoch in range(num_epochs):
        # Shuffle the data at each epoch
        if shuffle:
            random.shuffle(data)
        for batch_num in range(num_batches_per_epoch):
            start_index = batch_num * batch_size
            end_index = min((batch_num + 1) * batch_size, data_size)

            x_utterances = []
            x_utterances_len = []
            x_responses = []
            x_responses_len = []

            x_labels = []
            x_ids = []
            x_utterances_num = []

            x_utterances_char = []
            x_utterances_char_len = []
            x_responses_char = []
            x_responses_char_len = []

            x_personas = []
            x_personas_len = []
            x_personas_char = []
            x_personas_char_len = []
            x_personas_num = []

            for rowIdx in range(start_index, end_index):
                us_id, us_tokens, us_vec, us_len, us_num, rs_tokens, rs_vec, rs_len, label, ps_tokens, ps_vec, ps_len, ps_num = data[rowIdx]

                # normalize us_vec and us_len
                new_utters_vec = np.zeros((max_utter_num, max_utter_len), dtype='int32')
                new_utters_len = np.zeros((max_utter_num, ), dtype='int32')
                for i in range(len(us_len)):
                    new_utter_vec = normalize_vec(us_vec[i], max_utter_len)
                    new_utters_vec[i] = new_utter_vec
                    new_utters_len[i] = us_len[i]
                x_utterances.append(new_utters_vec)
                x_utterances_len.append(new_utters_len)

                # normalize rs_vec and rs_len
                new_responses_vec = np.zeros((max_response_num, max_response_len), dtype='int32')
                new_responses_len = np.zeros((max_response_num, ), dtype='int32')
                for i in range(len(rs_len)):
                    new_response_vec = normalize_vec(rs_vec[i], max_response_len)
                    new_responses_vec[i] = new_response_vec
                    new_responses_len[i] = rs_len[i]
                x_responses.append(new_responses_vec)
                x_responses_len.append(new_responses_len)

                x_labels.append(label)
                x_ids.append(us_id)
                x_utterances_num.append(us_num)

                # normalize us_CharVec and us_CharLen
                uttersCharVec = np.zeros((max_utter_num, max_utter_len, max_word_length), dtype='int32')
                uttersCharLen = np.ones((max_utter_num, max_utter_len), dtype='int32')
                for i in range(len(us_len)):
                    utterCharVec, utterCharLen = charVec(us_tokens[i], charVocab, max_utter_len, max_word_length)
                    uttersCharVec[i] = utterCharVec
                    uttersCharLen[i] = utterCharLen
                x_utterances_char.append(uttersCharVec)
                x_utterances_char_len.append(uttersCharLen)

                # normalize rs_CharVec and rs_CharLen
                rsCharVec = np.zeros((max_response_num, max_response_len, max_word_length), dtype='int32')
                rsCharLen = np.ones((max_response_num, max_response_len), dtype='int32')
                for i in range(len(rs_len)):  # one entry per response candidate
                    rCharVec, rCharLen = charVec(rs_tokens[i], charVocab, max_response_len, max_word_length)
                    rsCharVec[i] = rCharVec
                    rsCharLen[i] = rCharLen
                x_responses_char.append(rsCharVec)
                x_responses_char_len.append(rsCharLen)

                # normalize ps_vec and ps_len
                new_personas_vec = np.zeros((max_persona_num, max_persona_len), dtype='int32')
                new_personas_len = np.zeros((max_persona_num, ), dtype='int32')
                for i in range(len(ps_len)):
                    new_persona_vec = normalize_vec(ps_vec[i], max_persona_len)
                    new_personas_vec[i] = new_persona_vec
                    new_personas_len[i] = ps_len[i]
                x_personas.append(new_personas_vec)
                x_personas_len.append(new_personas_len)

                # normalize ps_CharVec and ps_CharLen
                psCharVec = np.zeros((max_persona_num, max_persona_len, max_word_length), dtype='int32')
                psCharLen = np.ones((max_persona_num, max_persona_len), dtype='int32')
                for i in range(len(ps_len)):
                    pCharVec, pCharLen = charVec(ps_tokens[i], charVocab, max_persona_len, max_word_length)
                    psCharVec[i] = pCharVec
                    psCharLen[i] = pCharLen
                x_personas_char.append(psCharVec)
                x_personas_char_len.append(psCharLen)

                x_personas_num.append(ps_num)

            yield np.array(x_utterances), np.array(x_utterances_len), np.array(x_responses), np.array(x_responses_len), \
                  np.array(x_utterances_num), np.array(x_labels), x_ids, \
                  np.array(x_utterances_char), np.array(x_utterances_char_len), np.array(x_responses_char), np.array(x_responses_char_len), \
                  np.array(x_personas), np.array(x_personas_len), np.array(x_personas_char), np.array(x_personas_char_len), np.array(x_personas_num)

--------------------------------------------------------------------------------
/model/eval.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
import os
import time
import datetime
import operator
import metrics
from collections import defaultdict
from model import data_helpers

# Files
tf.flags.DEFINE_string("test_file", "", "path to test file")
tf.flags.DEFINE_string("vocab_file", "", "vocabulary file")
tf.flags.DEFINE_string("char_vocab_file", "", "character vocabulary file")
tf.flags.DEFINE_string("output_file", "", "prediction output file")

# Model Hyperparameters
tf.flags.DEFINE_integer("max_utter_num", 15, "max utterance number")
tf.flags.DEFINE_integer("max_utter_len", 20, "max utterance length")
tf.flags.DEFINE_integer("max_response_num", 20, "max response candidate number")
tf.flags.DEFINE_integer("max_response_len", 20, "max response length")
tf.flags.DEFINE_integer("max_persona_num", 5, "max persona number")
tf.flags.DEFINE_integer("max_persona_len", 15, "max persona length")
tf.flags.DEFINE_integer("max_word_length", 18, "max word length")

# Test parameters
tf.flags.DEFINE_integer("batch_size", 32, "Batch Size (default: 32)")
tf.flags.DEFINE_string("checkpoint_dir", "", "Checkpoint directory from training run")

# Misc Parameters
tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow soft device placement")
tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")

FLAGS = tf.flags.FLAGS
FLAGS._parse_flags()
print("\nParameters:")
for attr, value in sorted(FLAGS.__flags.items()):
    print("{}={}".format(attr.upper(), value))
print("")

vocab = data_helpers.load_vocab(FLAGS.vocab_file)
print('vocabulary size: {}'.format(len(vocab)))
charVocab = data_helpers.load_char_vocab(FLAGS.char_vocab_file)
print('charVocab size: {}'.format(len(charVocab)))

test_dataset = data_helpers.load_dataset(FLAGS.test_file, vocab, FLAGS.max_utter_num, FLAGS.max_utter_len, FLAGS.max_response_len, FLAGS.max_persona_len)
print('test dataset size: {}'.format(len(test_dataset)))

print("\nEvaluating...\n")

checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
print(checkpoint_file)

graph = tf.Graph()
with graph.as_default():
    session_conf = tf.ConfigProto(
        allow_soft_placement=FLAGS.allow_soft_placement,
        log_device_placement=FLAGS.log_device_placement)
    sess = tf.Session(config=session_conf)
    with sess.as_default():
        # Load the saved meta graph and restore variables
        saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
        saver.restore(sess, checkpoint_file)

        # Get the placeholders from the graph by name
        utterances = graph.get_operation_by_name("utterances").outputs[0]
        utterances_len = graph.get_operation_by_name("utterances_len").outputs[0]

        responses = graph.get_operation_by_name("responses").outputs[0]
        responses_len = graph.get_operation_by_name("responses_len").outputs[0]

        utterances_num = graph.get_operation_by_name("utterances_num").outputs[0]
        dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]

        u_char_feature = graph.get_operation_by_name("utterances_char").outputs[0]
        u_char_len = graph.get_operation_by_name("utterances_char_len").outputs[0]

        r_char_feature = graph.get_operation_by_name("responses_char").outputs[0]
        r_char_len = graph.get_operation_by_name("responses_char_len").outputs[0]

        personas = graph.get_operation_by_name("personas").outputs[0]
        personas_len = graph.get_operation_by_name("personas_len").outputs[0]
        p_char_feature = graph.get_operation_by_name("personas_char").outputs[0]
        p_char_len = graph.get_operation_by_name("personas_char_len").outputs[0]
        personas_num = graph.get_operation_by_name("personas_num").outputs[0]

        # Tensors we want to evaluate
        pred_prob = graph.get_operation_by_name("prediction_layer/prob").outputs[0]

        results = defaultdict(list)
        num_test = 0
        test_batches = data_helpers.batch_iter(test_dataset, FLAGS.batch_size, 1, FLAGS.max_utter_num, FLAGS.max_utter_len,
                                               FLAGS.max_response_num, FLAGS.max_response_len, FLAGS.max_persona_num,
                                               FLAGS.max_persona_len, charVocab, FLAGS.max_word_length, shuffle=False)
        for test_batch in test_batches:
            x_utterances, x_utterances_len, x_response, x_response_len, \
                x_utters_num, x_target, x_ids, \
                x_u_char, x_u_char_len, x_r_char, x_r_char_len, \
                x_personas, x_personas_len, x_p_char, x_p_char_len, x_personas_num = test_batch
            feed_dict = {
                utterances: x_utterances,
                utterances_len: x_utterances_len,
                responses: x_response,
                responses_len: x_response_len,
                utterances_num: x_utters_num,
                dropout_keep_prob: 1.0,
                u_char_feature: x_u_char,
                u_char_len: x_u_char_len,
                r_char_feature: x_r_char,
                r_char_len: x_r_char_len,
                personas: x_personas,
                personas_len: x_personas_len,
                p_char_feature: x_p_char,
                p_char_len: x_p_char_len,
                personas_num: x_personas_num
            }
            predicted_prob = sess.run(pred_prob, feed_dict)
            num_test += len(predicted_prob)
            print('num_test_sample={}'.format(num_test))

            for i in range(len(predicted_prob)):
                probs = predicted_prob[i]
                us_id = x_ids[i]
                label = x_target[i]
                labels = np.zeros(FLAGS.max_response_num)
                labels[label] = 1
                for r_id, prob in enumerate(probs):
                    results[us_id].append((str(r_id), labels[r_id], prob))

        accu, precision, recall, f1, loss = metrics.classification_metrics(results)
        print('Accuracy: {}, Precision: {} Recall: {} F1: {} Loss: {}'.format(accu, precision, recall, f1, loss))

        mvp = metrics.mean_average_precision(results)
        mrr = metrics.mean_reciprocal_rank(results)
        top_1_precision = metrics.top_1_precision(results)
        total_valid_query = metrics.get_num_valid_query(results)
        print('MAP (mean average precision): {}\tMRR (mean reciprocal rank): {}\tTop-1 precision: {}\tNum_query: {}'.format(mvp, mrr, top_1_precision, total_valid_query))

        out_path = FLAGS.output_file
        print("Saving evaluation to {}".format(out_path))
        with open(out_path, 'w') as f:
            f.write("query_id\tdocument_id\tscore\trank\trelevance\n")
            for us_id, v in results.items():
                v.sort(key=operator.itemgetter(2), reverse=True)
                for i, rec in enumerate(v):
                    r_id, label, prob_score = rec
                    rank = i+1
                    f.write('{}\t{}\t{}\t{}\t{}\n'.format(us_id, r_id, prob_score, rank, label))
--------------------------------------------------------------------------------
/model/metrics.py:
--------------------------------------------------------------------------------
import operator
import math

def is_valid_query(v):
    num_pos = 0
    num_neg = 0
    for aid, label, score in v:
        if label > 0:
            num_pos += 1
        else:
            num_neg += 1
    if num_pos > 0 and num_neg > 0:
        return True
    else:
        return False

def get_num_valid_query(results):
    num_query = 0
    for k, v in results.items():
        if not is_valid_query(v):
            continue
        num_query += 1
    return num_query

def top_1_precision(results):
    num_query = 0
    top_1_correct = 0.0
    for k, v in results.items():
        if not is_valid_query(v):
            continue
        num_query += 1
        sorted_v = sorted(v, key=operator.itemgetter(2), reverse=True)
        aid, label, score = sorted_v[0]
        if label > 0:
            top_1_correct += 1

    if num_query > 0:
        return top_1_correct/num_query
    else:
        return 0.0

def mean_reciprocal_rank(results):
    num_query = 0
    mrr = 0.0
    for k, v in results.items():
        if not is_valid_query(v):
            continue

        num_query += 1
        sorted_v = sorted(v, key=operator.itemgetter(2), reverse=True)
        for i, rec in enumerate(sorted_v):
            aid, label, score = rec
            if label > 0:
                mrr += 1.0/(i+1)
                break

    if num_query == 0:
        return 0.0
    else:
        mrr = mrr/num_query
        return mrr

def mean_average_precision(results):
    num_query = 0
| mvp = 0.0 66 | for k, v in results.items(): 67 | if not is_valid_query(v): 68 | continue 69 | 70 | num_query += 1 71 | sorted_v = sorted(v, key=operator.itemgetter(2), reverse=True) 72 | num_relevant_doc = 0.0 73 | avp = 0.0 74 | for i, rec in enumerate(sorted_v): 75 | aid, label, score = rec 76 | if label == 1: 77 | num_relevant_doc += 1 78 | precision = num_relevant_doc/(i+1) 79 | avp += precision 80 | avp = avp/num_relevant_doc 81 | mvp += avp 82 | 83 | if num_query == 0: 84 | return 0.0 85 | else: 86 | mvp = mvp/num_query 87 | return mvp 88 | 89 | def classification_metrics(results): 90 | total_num = 0 91 | total_correct = 0 92 | true_positive = 0 93 | positive_correct = 0 94 | predicted_positive = 0 95 | 96 | loss = 0.0; 97 | for k, v in results.items(): 98 | for rec in v: 99 | total_num += 1 100 | aid, label, score = rec 101 | 102 | 103 | if score > 0.5: 104 | predicted_positive += 1 105 | 106 | if label > 0: 107 | true_positive += 1 108 | loss += -math.log(score+1e-12) 109 | else: 110 | loss += -math.log(1.0 - score + 1e-12); 111 | 112 | if score > 0.5 and label > 0: 113 | total_correct += 1 114 | positive_correct += 1 115 | 116 | if score < 0.5 and label < 0.5: 117 | total_correct += 1 118 | 119 | accuracy = float(total_correct)/total_num 120 | precision = float(positive_correct)/(predicted_positive+1e-12) 121 | recall = float(positive_correct)/true_positive 122 | F1 = 2.0 * precision * recall/(1e-12+precision + recall) 123 | return accuracy, precision, recall, F1, loss/total_num; 124 | -------------------------------------------------------------------------------- /model/model_DIM.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | FLAGS = tf.flags.FLAGS 5 | 6 | def get_embeddings(vocab): 7 | print("get_embedding") 8 | initializer = load_word_embeddings(vocab, FLAGS.embedding_dim) 9 | return tf.constant(initializer, name="word_embedding") 10 | 11 | def 
get_char_embedding(charVocab): 12 | print("get_char_embedding") 13 | char_size = len(charVocab) 14 | embeddings = np.zeros((char_size, char_size), dtype='float32') 15 | for i in range(1, char_size): 16 | embeddings[i, i] = 1.0 17 | return tf.constant(embeddings, name="word_char_embedding") 18 | 19 | def load_embed_vectors(fname, dim): 20 | vectors = {} 21 | for line in open(fname, 'rt'): 22 | items = line.strip().split(' ') 23 | if len(items[0]) <= 0: 24 | continue 25 | vec = [float(items[i]) for i in range(1, dim+1)] 26 | vectors[items[0]] = vec 27 | return vectors 28 | 29 | def load_word_embeddings(vocab, dim): 30 | vectors = load_embed_vectors(FLAGS.embedded_vector_file, dim) 31 | vocab_size = len(vocab) 32 | embeddings = np.zeros((vocab_size, dim), dtype='float32') 33 | for word, code in vocab.items(): 34 | if word in vectors: 35 | embeddings[code] = vectors[word] 36 | #else: 37 | # embeddings[code] = np.random.uniform(-0.25, 0.25, dim) 38 | return embeddings 39 | 40 | 41 | def lstm_layer(inputs, input_seq_len, rnn_size, dropout_keep_prob, scope, scope_reuse=False): 42 | with tf.variable_scope(scope, reuse=scope_reuse) as vs: 43 | fw_cell = tf.contrib.rnn.LSTMCell(rnn_size, forget_bias=1.0, state_is_tuple=True, reuse=scope_reuse) 44 | fw_cell = tf.contrib.rnn.DropoutWrapper(fw_cell, output_keep_prob=dropout_keep_prob) 45 | bw_cell = tf.contrib.rnn.LSTMCell(rnn_size, forget_bias=1.0, state_is_tuple=True, reuse=scope_reuse) 46 | bw_cell = tf.contrib.rnn.DropoutWrapper(bw_cell, output_keep_prob=dropout_keep_prob) 47 | rnn_outputs, rnn_states = tf.nn.bidirectional_dynamic_rnn(cell_fw=fw_cell, cell_bw=bw_cell, 48 | inputs=inputs, 49 | sequence_length=input_seq_len, 50 | dtype=tf.float32) 51 | return rnn_outputs, rnn_states 52 | 53 | def cnn_layer(inputs, filter_sizes, num_filters, scope=None, scope_reuse=False): 54 | with tf.variable_scope(scope, reuse=scope_reuse): 55 | input_size = inputs.get_shape()[2].value 56 | 57 | outputs = [] 58 | for i, filter_size in 
enumerate(filter_sizes): 59 | with tf.variable_scope("conv_{}".format(i)): 60 | w = tf.get_variable("w", [filter_size, input_size, num_filters]) 61 | b = tf.get_variable("b", [num_filters]) 62 | conv = tf.nn.conv1d(inputs, w, stride=1, padding="VALID") # [num_words, num_chars - filter_size + 1, num_filters] 63 | h = tf.nn.relu(tf.nn.bias_add(conv, b)) # [num_words, num_chars - filter_size + 1, num_filters] 64 | pooled = tf.reduce_max(h, 1) # [num_words, num_filters] 65 | outputs.append(pooled) 66 | return tf.concat(outputs, 1) # [num_words, num_filters * len(filter_sizes)] 67 | 68 | 69 | def attended_response(similarity_matrix, contexts, flattened_utters_len, max_utter_len, max_utter_num): 70 | # similarity_matrix: [batch_size, max_response_num, max_response_len, max_utter_num*max_utter_len] 71 | # contexts: [batch_size, max_utter_num*max_utter_len, dim] 72 | # flattened_utters_len: [batch_size* max_utter_num, ] 73 | max_response_num = similarity_matrix.get_shape()[1].value 74 | 75 | # masked similarity_matrix 76 | mask_c = tf.sequence_mask(flattened_utters_len, max_utter_len, dtype=tf.float32) # [batch_size*max_utter_num, max_utter_len] 77 | mask_c = tf.reshape(mask_c, [-1, max_utter_num*max_utter_len]) # [batch_size, max_utter_num*max_utter_len] 78 | mask_c = tf.expand_dims(mask_c, 1) # [batch_size, 1, max_utter_num*max_utter_len] 79 | mask_c = tf.expand_dims(mask_c, 2) # [batch_size, 1, 1, max_utter_num*max_utter_len] 80 | similarity_matrix = similarity_matrix * mask_c + -1e9 * (1-mask_c) # [batch_size, max_response_num, response_len, max_utter_num*max_utter_len] 81 | 82 | attention_weight_for_c = tf.nn.softmax(similarity_matrix, dim=-1) # [batch_size, max_response_num, response_len, max_utter_num*max_utter_len] 83 | contexts_tiled = tf.tile(tf.expand_dims(contexts, 1), [1, max_response_num, 1, 1]) # [batch_size, max_response_num, max_utter_num*max_utter_len, dim] 84 | attended_response = tf.matmul(attention_weight_for_c, contexts_tiled) # [batch_size, max_response_num,
response_len, dim] 85 | 86 | return attended_response 87 | 88 | def attended_context(similarity_matrix, responses, flattened_responses_len, max_response_len, max_response_num): 89 | # similarity_matrix: [batch_size, max_response_num, max_response_len, max_utter_num*max_utter_len] 90 | # responses: [batch_size, max_response_num, max_response_len, dim] 91 | # flattened_responses_len: [batch_size* max_response_num, ] 92 | 93 | # masked similarity_matrix 94 | mask_r = tf.sequence_mask(flattened_responses_len, max_response_len, dtype=tf.float32) # [batch_size*max_response_num, max_response_len] 95 | mask_r = tf.reshape(mask_r, [-1, max_response_num, max_response_len]) # [batch_size, max_response_num, max_response_len] 96 | mask_r = tf.expand_dims(mask_r, -1) # [batch_size, max_response_num, max_response_len, 1] 97 | similarity_matrix = similarity_matrix * mask_r + -1e9 * (1-mask_r) # [batch_size, max_response_num, max_response_len, max_utter_num*max_utter_len] 98 | 99 | attention_weight_for_r = tf.nn.softmax(tf.transpose(similarity_matrix, perm=[0,1,3,2]), dim=-1) # [batch_size, max_response_num, max_utter_num*max_utter_len, response_len] 100 | attended_context = tf.matmul(attention_weight_for_r, responses) # [batch_size, max_response_num, max_utter_num*max_utter_len, dim] 101 | 102 | return attended_context 103 | 104 | 105 | class DIM(object): 106 | def __init__( 107 | self, max_utter_num, max_utter_len, max_response_num, max_response_len, max_persona_num, max_persona_len, 108 | vocab_size, embedding_size, vocab, rnn_size, maxWordLength, charVocab, l2_reg_lambda=0.0): 109 | 110 | self.utterances = tf.placeholder(tf.int32, [None, max_utter_num, max_utter_len], name="utterances") 111 | self.utterances_len = tf.placeholder(tf.int32, [None, max_utter_num], name="utterances_len") 112 | self.utters_num = tf.placeholder(tf.int32, [None], name="utterances_num") 113 | 114 | self.responses = tf.placeholder(tf.int32, [None, max_response_num, max_response_len], name="responses") 
115 | self.responses_len = tf.placeholder(tf.int32, [None, max_response_num], name="responses_len") 116 | 117 | self.personas = tf.placeholder(tf.int32, [None, max_persona_num, max_persona_len], name="personas") 118 | self.personas_len = tf.placeholder(tf.int32, [None, max_persona_num], name="personas_len") 119 | self.personas_num = tf.placeholder(tf.int32, [None], name="personas_num") 120 | 121 | self.target = tf.placeholder(tf.int64, [None], name="target") 122 | self.dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob") 123 | 124 | self.u_charVec = tf.placeholder(tf.int32, [None, max_utter_num, max_utter_len, maxWordLength], name="utterances_char") 125 | self.u_charLen = tf.placeholder(tf.int32, [None, max_utter_num, max_utter_len], name="utterances_char_len") 126 | 127 | self.r_charVec = tf.placeholder(tf.int32, [None, max_response_num, max_response_len, maxWordLength], name="responses_char") 128 | self.r_charLen = tf.placeholder(tf.int32, [None, max_response_num, max_response_len], name="responses_char_len") 129 | 130 | self.p_charVec = tf.placeholder(tf.int32, [None, max_persona_num, max_persona_len, maxWordLength], name="personas_char") 131 | self.p_charLen = tf.placeholder(tf.int32, [None, max_persona_num, max_persona_len], name="personas_char_len") 132 | 133 | l2_loss = tf.constant(1.0) 134 | 135 | # =============================== Embedding layer =============================== 136 | # word embedding 137 | with tf.name_scope("embedding"): 138 | W = get_embeddings(vocab) 139 | utterances_embedded = tf.nn.embedding_lookup(W, self.utterances) # [batch_size, max_utter_num, max_utter_len, word_dim] 140 | responses_embedded = tf.nn.embedding_lookup(W, self.responses) # [batch_size, max_response_num, max_response_len, word_dim] 141 | personas_embedded = tf.nn.embedding_lookup(W, self.personas) # [batch_size, max_persona_num, max_persona_len, word_dim] 142 | print("original utterances_embedded: {}".format(utterances_embedded.get_shape())) 143 | 
print("original responses_embedded: {}".format(responses_embedded.get_shape())) 144 | print("original personas_embedded: {}".format(personas_embedded.get_shape())) 145 | 146 | with tf.name_scope('char_embedding'): 147 | char_W = get_char_embedding(charVocab) 148 | utterances_char_embedded = tf.nn.embedding_lookup(char_W, self.u_charVec) # [batch_size, max_utter_num, max_utter_len, maxWordLength, char_dim] 149 | responses_char_embedded = tf.nn.embedding_lookup(char_W, self.r_charVec) # [batch_size, max_response_num, max_response_len, maxWordLength, char_dim] 150 | personas_char_embedded = tf.nn.embedding_lookup(char_W, self.p_charVec) # [batch_size, max_persona_num, max_persona_len, maxWordLength, char_dim] 151 | print("utterances_char_embedded: {}".format(utterances_char_embedded.get_shape())) 152 | print("responses_char_embedded: {}".format(responses_char_embedded.get_shape())) 153 | print("personas_char_embedded: {}".format(personas_char_embedded.get_shape())) 154 | 155 | char_dim = utterances_char_embedded.get_shape()[-1].value 156 | utterances_char_embedded = tf.reshape(utterances_char_embedded, [-1, maxWordLength, char_dim]) # [batch_size*max_utter_num*max_utter_len, maxWordLength, char_dim] 157 | responses_char_embedded = tf.reshape(responses_char_embedded, [-1, maxWordLength, char_dim]) # [batch_size*max_response_num*max_response_len, maxWordLength, char_dim] 158 | personas_char_embedded = tf.reshape(personas_char_embedded, [-1, maxWordLength, char_dim]) # [batch_size*max_persona_num*max_persona_len, maxWordLength, char_dim] 159 | 160 | # char embedding 161 | utterances_cnn_char_emb = cnn_layer(utterances_char_embedded, filter_sizes=[3, 4, 5], num_filters=50, scope="CNN_char_emb", scope_reuse=False) # [batch_size*max_utter_num*max_utter_len, emb] 162 | cnn_char_dim = utterances_cnn_char_emb.get_shape()[1].value 163 | utterances_cnn_char_emb = tf.reshape(utterances_cnn_char_emb, [-1, max_utter_num, max_utter_len, cnn_char_dim]) # [batch_size, max_utter_num, 
max_utter_len, emb] 164 | 165 | responses_cnn_char_emb = cnn_layer(responses_char_embedded, filter_sizes=[3, 4, 5], num_filters=50, scope="CNN_char_emb", scope_reuse=True) # [batch_size*max_response_num*max_response_len, emb] 166 | responses_cnn_char_emb = tf.reshape(responses_cnn_char_emb, [-1, max_response_num, max_response_len, cnn_char_dim]) # [batch_size, max_response_num, max_response_len, emb] 167 | 168 | personas_cnn_char_emb = cnn_layer(personas_char_embedded, filter_sizes=[3, 4, 5], num_filters=50, scope="CNN_char_emb", scope_reuse=True) # [batch_size*max_persona_num*max_persona_len, emb] 169 | personas_cnn_char_emb = tf.reshape(personas_cnn_char_emb, [-1, max_persona_num, max_persona_len, cnn_char_dim]) # [batch_size, max_persona_num, max_persona_len, emb] 170 | 171 | utterances_embedded = tf.concat(axis=-1, values=[utterances_embedded, utterances_cnn_char_emb]) # [batch_size, max_utter_num, max_utter_len, emb] 172 | responses_embedded = tf.concat(axis=-1, values=[responses_embedded, responses_cnn_char_emb]) # [batch_size, max_response_num, max_response_len, emb] 173 | personas_embedded = tf.concat(axis=-1, values=[personas_embedded, personas_cnn_char_emb]) # [batch_size, max_persona_num, max_persona_len, emb] 174 | utterances_embedded = tf.nn.dropout(utterances_embedded, keep_prob=self.dropout_keep_prob) 175 | responses_embedded = tf.nn.dropout(responses_embedded, keep_prob=self.dropout_keep_prob) 176 | personas_embedded = tf.nn.dropout(personas_embedded, keep_prob=self.dropout_keep_prob) 177 | print("utterances_embedded: {}".format(utterances_embedded.get_shape())) 178 | print("responses_embedded: {}".format(responses_embedded.get_shape())) 179 | print("personas_embedded: {}".format(personas_embedded.get_shape())) 180 | 181 | 182 | # =============================== Encoding layer =============================== 183 | with tf.variable_scope("encoding_layer") as vs: 184 | 185 | emb_dim = utterances_embedded.get_shape()[-1].value 186 | 
flattened_utterances_embedded = tf.reshape(utterances_embedded, [-1, max_utter_len, emb_dim]) # [batch_size*max_utter_num, max_utter_len, emb] 187 | flattened_utterances_len = tf.reshape(self.utterances_len, [-1]) # [batch_size*max_utter_num, ] 188 | flattened_responses_embedded = tf.reshape(responses_embedded, [-1, max_response_len, emb_dim]) # [batch_size*max_response_num, max_response_len, emb] 189 | flattened_responses_len = tf.reshape(self.responses_len, [-1]) # [batch_size*max_response_num, ] 190 | flattened_personas_embedded = tf.reshape(personas_embedded, [-1, max_persona_len, emb_dim]) # [batch_size*max_persona_num, max_persona_len, emb] 191 | flattened_personas_len = tf.reshape(self.personas_len, [-1]) # [batch_size*max_persona_num, ] 192 | 193 | rnn_scope_name = "bidirectional_rnn" 194 | u_rnn_output, u_rnn_states = lstm_layer(flattened_utterances_embedded, flattened_utterances_len, rnn_size, self.dropout_keep_prob, rnn_scope_name, scope_reuse=False) 195 | utterances_output = tf.concat(axis=2, values=u_rnn_output) # [batch_size*max_utter_num, max_utter_len, rnn_size*2] 196 | r_rnn_output, r_rnn_states = lstm_layer(flattened_responses_embedded, flattened_responses_len, rnn_size, self.dropout_keep_prob, rnn_scope_name, scope_reuse=True) 197 | responses_output = tf.concat(axis=2, values=r_rnn_output) # [batch_size*max_response_num, max_response_len, rnn_size*2] 198 | p_rnn_output, p_rnn_states = lstm_layer(flattened_personas_embedded, flattened_personas_len, rnn_size, self.dropout_keep_prob, rnn_scope_name, scope_reuse=True) 199 | personas_output = tf.concat(axis=2, values=p_rnn_output) # [batch_size*max_persona_num, max_persona_len, rnn_size*2] 200 | print("encoded utterances : {}".format(utterances_output.shape)) 201 | print("encoded responses : {}".format(responses_output.shape)) 202 | print("encoded personas : {}".format(personas_output.shape)) 203 | 204 | 205 | # =============================== Matching layer =============================== 206 | with 
tf.variable_scope("matching_layer") as vs: 207 | 208 | output_dim = utterances_output.get_shape()[-1].value 209 | utterances_output = tf.reshape(utterances_output, [-1, max_utter_num*max_utter_len, output_dim]) # [batch_size, max_utter_num*max_utter_len, rnn_size*2] 210 | utterances_output_tiled = tf.tile(tf.expand_dims(utterances_output, 1), [1, max_response_num, 1, 1]) # [batch_size, max_response_num, max_utter_num*max_utter_len, rnn_size*2] 211 | responses_output = tf.reshape(responses_output, [-1, max_response_num, max_response_len, output_dim]) # [batch_size, max_response_num, max_response_len, rnn_size*2] 212 | personas_output = tf.reshape(personas_output, [-1, max_persona_num*max_persona_len, output_dim]) # [batch_size, max_persona_num*max_persona_len, rnn_size*2] 213 | personas_output_tiled = tf.tile(tf.expand_dims(personas_output, 1), [1, max_response_num, 1, 1]) # [batch_size, max_response_num, max_persona_num*max_persona_len, rnn_size*2] 214 | 215 | # 1. cross-attention between context and response 216 | similarity_UR = tf.matmul(responses_output, # [batch_size, max_response_num, response_len, max_utter_num*max_utter_len] 217 | tf.transpose(utterances_output_tiled, perm=[0,1,3,2]), name='similarity_matrix_UR') 218 | attended_utterances_output_ur = attended_context(similarity_UR, responses_output, flattened_responses_len, max_response_len, max_response_num) # [batch_size, max_response_num, max_utter_num*max_utter_len, dim] 219 | attended_responses_output_ur = attended_response(similarity_UR, utterances_output, flattened_utterances_len, max_utter_len, max_utter_num) # [batch_size, max_response_num, response_len, dim] 220 | 221 | m_u_ur = tf.concat(axis=-1, values=[utterances_output_tiled, attended_utterances_output_ur, tf.multiply(utterances_output_tiled, attended_utterances_output_ur), utterances_output_tiled-attended_utterances_output_ur]) # [batch_size, max_response_num, max_utter_num*max_utter_len, dim] 222 | m_r_ur = tf.concat(axis=-1, 
values=[responses_output, attended_responses_output_ur, tf.multiply(responses_output, attended_responses_output_ur), responses_output-attended_responses_output_ur]) # [batch_size, max_response_num, response_len, dim] 223 | concat_dim = m_u_ur.get_shape()[-1].value 224 | m_u_ur = tf.reshape(m_u_ur, [-1, max_utter_len, concat_dim]) # [batch_size*max_response_num*max_utter_num, max_utter_len, dim] 225 | m_r_ur = tf.reshape(m_r_ur, [-1, max_response_len, concat_dim]) # [batch_size*max_response_num, max_response_len, dim] 226 | 227 | rnn_scope_cross = 'bidirectional_rnn_cross' 228 | rnn_size_layer_2 = rnn_size 229 | tiled_flattened_utterances_len = tf.reshape(tf.tile(tf.expand_dims(self.utterances_len, 1), [1, max_response_num, 1]), [-1, ]) # [batch_size*max_response_num*max_utter_num, ] 230 | u_ur_rnn_output, u_ur_rnn_state = lstm_layer(m_u_ur, tiled_flattened_utterances_len, rnn_size_layer_2, self.dropout_keep_prob, rnn_scope_cross, scope_reuse=False) 231 | r_ur_rnn_output, r_ur_rnn_state = lstm_layer(m_r_ur, flattened_responses_len, rnn_size_layer_2, self.dropout_keep_prob, rnn_scope_cross, scope_reuse=True) 232 | utterances_output_cross_ur = tf.concat(axis=-1, values=u_ur_rnn_output) # [batch_size*max_response_num*max_utter_num, max_utter_len, rnn_size*2] 233 | responses_output_cross_ur = tf.concat(axis=-1, values=r_ur_rnn_output) # [batch_size*max_response_num, max_response_len, rnn_size*2] 234 | print("establish cross-attention between context and response") 235 | 236 | 237 | # 2. 
cross-attention between persona and response without decay 238 | similarity_PR = tf.matmul(responses_output, # [batch_size, max_response_num, response_len, max_persona_num*max_persona_len] 239 | tf.transpose(personas_output_tiled, perm=[0,1,3,2]), name='similarity_matrix_PR') 240 | attended_personas_output_pr = attended_context(similarity_PR, responses_output, flattened_responses_len, max_response_len, max_response_num) # [batch_size, max_response_num, max_persona_num*max_persona_len, dim] 241 | attended_responses_output_pr = attended_response(similarity_PR, personas_output, flattened_personas_len, max_persona_len, max_persona_num) # [batch_size, max_response_num, response_len, dim] 242 | 243 | m_p_pr = tf.concat(axis=-1, values=[personas_output_tiled, attended_personas_output_pr, tf.multiply(personas_output_tiled, attended_personas_output_pr), personas_output_tiled-attended_personas_output_pr]) # [batch_size, max_response_num, max_persona_num*max_persona_len, dim] 244 | m_r_pr = tf.concat(axis=-1, values=[responses_output, attended_responses_output_pr, tf.multiply(responses_output, attended_responses_output_pr), responses_output-attended_responses_output_pr]) # [batch_size, max_response_num, response_len, dim] 245 | m_p_pr = tf.reshape(m_p_pr, [-1, max_persona_len, concat_dim]) # [batch_size*max_response_num*max_persona_num, max_persona_len, dim] 246 | m_r_pr = tf.reshape(m_r_pr, [-1, max_response_len, concat_dim]) # [batch_size*max_response_num, max_response_len, dim] 247 | 248 | tiled_flattened_personas_len = tf.reshape(tf.tile(tf.expand_dims(self.personas_len, 1), [1, max_response_num, 1]), [-1, ]) # [batch_size*max_response_num*max_persona_num, ] 249 | p_pr_rnn_output, p_pr_rnn_state = lstm_layer(m_p_pr, tiled_flattened_personas_len, rnn_size_layer_2, self.dropout_keep_prob, rnn_scope_cross, scope_reuse=True) 250 | r_pr_rnn_output, r_pr_rnn_state = lstm_layer(m_r_pr, flattened_responses_len, rnn_size_layer_2, self.dropout_keep_prob, rnn_scope_cross, 
scope_reuse=True) 251 | personas_output_cross_pr = tf.concat(axis=-1, values=p_pr_rnn_output) # [batch_size*max_response_num*max_persona_num, max_persona_len, rnn_size*2] 252 | responses_output_cross_pr = tf.concat(axis=-1, values=r_pr_rnn_output) # [batch_size*max_response_num, max_response_len, rnn_size*2] 253 | print("establish cross-attention between persona and response") 254 | 255 | 256 | # =============================== Aggregation layer =============================== 257 | with tf.variable_scope("aggregation_layer") as vs: 258 | # aggregate utterance across utterance_len 259 | final_utterances_max = tf.reduce_max(utterances_output_cross_ur, axis=1) 260 | final_utterances_state = tf.concat(axis=1, values=[u_ur_rnn_state[0].h, u_ur_rnn_state[1].h]) 261 | final_utterances = tf.concat(axis=1, values=[final_utterances_max, final_utterances_state]) # [batch_size*max_response_num*max_utter_num, 4*rnn_size] 262 | 263 | # aggregate utterance across utterance_num 264 | final_utterances = tf.reshape(final_utterances, [-1, max_utter_num, output_dim*2]) # [batch_size*max_response_num, max_utter_num, 4*rnn_size] 265 | tiled_utters_num = tf.reshape(tf.tile(tf.expand_dims(self.utters_num, 1), [1, max_response_num]), [-1, ]) # [batch_size*max_response_num, ] 266 | rnn_scope_aggre = "bidirectional_rnn_aggregation" 267 | final_utterances_output, final_utterances_state = lstm_layer(final_utterances, tiled_utters_num, rnn_size, self.dropout_keep_prob, rnn_scope_aggre, scope_reuse=False) 268 | final_utterances_output = tf.concat(axis=2, values=final_utterances_output) # [batch_size*max_response_num, max_utter_num, 2*rnn_size] 269 | final_utterances_max = tf.reduce_max(final_utterances_output, axis=1) # [batch_size*max_response_num, 2*rnn_size] 270 | final_utterances_state = tf.concat(axis=1, values=[final_utterances_state[0].h, final_utterances_state[1].h]) # [batch_size*max_response_num, 2*rnn_size] 271 | aggregated_utterances = tf.concat(axis=1, values=[final_utterances_max, 
final_utterances_state]) # [batch_size*max_response_num, 4*rnn_size] 272 | 273 | # aggregate response across response_len 274 | final_responses_max = tf.reduce_max(responses_output_cross_ur, axis=1) # [batch_size*max_response_num, 2*rnn_size] 275 | final_responses_state = tf.concat(axis=1, values=[r_ur_rnn_state[0].h, r_ur_rnn_state[1].h]) # [batch_size*max_response_num, 2*rnn_size] 276 | aggregated_responses_ur = tf.concat(axis=1, values=[final_responses_max, final_responses_state]) # [batch_size*max_response_num, 4*rnn_size] 277 | print("establish RNN aggregation on context and response") 278 | 279 | 280 | # aggregate persona across persona_len 281 | final_personas_max = tf.reduce_max(personas_output_cross_pr, axis=1) # [batch_size*max_response_num*max_persona_num, 2*rnn_size] 282 | final_personas_state = tf.concat(axis=1, values=[p_pr_rnn_state[0].h, p_pr_rnn_state[1].h]) # [batch_size*max_response_num*max_persona_num, 2*rnn_size] 283 | final_personas = tf.concat(axis=1, values=[final_personas_max, final_personas_state]) # [batch_size*max_response_num*max_persona_num, 4*rnn_size] 284 | 285 | # aggregate persona across persona_num 286 | # 1. 
RNN aggregation 287 | # final_personas = tf.reshape(final_personas, [-1, max_persona_num, output_dim*2]) # [batch_size*max_response_num, max_persona_num, 4*rnn_size] 288 | # tiled_personas_num = tf.reshape(tf.tile(tf.expand_dims(self.personas_num, 1), [1, max_response_num]), [-1, ]) # [batch_size*max_response_num, ] 289 | # final_personas_output, final_personas_state = lstm_layer(final_personas, tiled_personas_num, rnn_size, self.dropout_keep_prob, rnn_scope_aggre, scope_reuse=True) 290 | # final_personas_output = tf.concat(axis=2, values=final_personas_output) # [batch_size*max_response_num, max_persona_num, 2*rnn_size] 291 | # final_personas_max = tf.reduce_max(final_personas_output, axis=1) # [batch_size*max_response_num, 2*rnn_size] 292 | # final_personas_state = tf.concat(axis=1, values=[final_personas_state[0].h, final_personas_state[1].h]) # [batch_size*max_response_num, 2*rnn_size] 293 | # aggregated_personas = tf.concat(axis=1, values=[final_personas_max, final_personas_state]) # [batch_size*max_response_num, 4*rnn_size] 294 | # print("establish RNN aggregation on persona") 295 | # 2. 
ATT aggregation 296 | final_personas = tf.reshape(final_personas, [-1, max_persona_num, output_dim*2]) # [batch_size*max_response_num, max_persona_num, 4*rnn_size] 297 | pers_w = tf.get_variable("pers_w", [output_dim*2, 1], initializer=tf.contrib.layers.xavier_initializer()) 298 | pers_b = tf.get_variable("pers_b", shape=[1, ], initializer=tf.zeros_initializer()) 299 | pers_weights = tf.nn.relu(tf.einsum('aij,jk->aik', final_personas, pers_w) + pers_b) # [batch_size*max_response_num, max_persona_num, 1] 300 | tiled_personas_num = tf.reshape(tf.tile(tf.expand_dims(self.personas_num, 1), [1, max_response_num]), [-1, ]) # [batch_size*max_response_num, ] 301 | mask_p = tf.expand_dims(tf.sequence_mask(tiled_personas_num, max_persona_num, dtype=tf.float32), -1) # [batch_size*max_response_num, max_persona_num, 1] 302 | pers_weights = pers_weights * mask_p + -1e9 * (1-mask_p) # [batch_size*max_response_num, max_persona_num, 1] 303 | pers_weights = tf.nn.softmax(pers_weights, dim=1) 304 | aggregated_personas = tf.matmul(tf.transpose(pers_weights, [0, 2, 1]), final_personas) # [batch_size*max_response_num, 1, 4*rnn_size] 305 | aggregated_personas = tf.squeeze(aggregated_personas, [1]) # [batch_size*max_response_num, 4*rnn_size] 306 | 307 | # aggregate response across response_len 308 | final_responses_max = tf.reduce_max(responses_output_cross_pr, axis=1) # [batch_size*max_response_num, 2*rnn_size] 309 | final_responses_state = tf.concat(axis=1, values=[r_pr_rnn_state[0].h, r_pr_rnn_state[1].h]) # [batch_size*max_response_num, 2*rnn_size] 310 | aggregated_responses_pr = tf.concat(axis=1, values=[final_responses_max, final_responses_state]) # [batch_size*max_response_num, 4*rnn_size] 311 | print("establish ATT aggregation on persona and response") 312 | 313 | joined_feature = tf.concat(axis=1, values=[aggregated_utterances, aggregated_responses_ur, aggregated_personas, aggregated_responses_pr]) # [batch_size*max_response_num, 16*rnn_size(3200)] 314 | print("joined feature: 
{}".format(joined_feature.get_shape())) 315 | 316 | 317 | # =============================== Prediction layer =============================== 318 | with tf.variable_scope("prediction_layer") as vs: 319 | hidden_input_size = joined_feature.get_shape()[1].value 320 | hidden_output_size = 256 321 | regularizer = tf.contrib.layers.l2_regularizer(l2_reg_lambda) 322 | #regularizer = None 323 | # dropout on the MLP input 324 | joined_feature = tf.nn.dropout(joined_feature, keep_prob=self.dropout_keep_prob) 325 | full_out = tf.contrib.layers.fully_connected(joined_feature, hidden_output_size, 326 | activation_fn=tf.nn.relu, weights_regularizer=regularizer, 327 | reuse=False, 328 | trainable=True, 329 | scope="projected_layer") # [batch_size*max_response_num, hidden_output_size(256)] 330 | full_out = tf.nn.dropout(full_out, keep_prob=self.dropout_keep_prob) 331 | 332 | last_weight_dim = full_out.get_shape()[1].value 333 | print("last_weight_dim: {}".format(last_weight_dim)) 334 | bias = tf.Variable(tf.constant(0.1, shape=[1]), name="bias") 335 | s_w = tf.get_variable("s_w", shape=[last_weight_dim, 1], initializer=tf.contrib.layers.xavier_initializer()) 336 | logits = tf.reshape(tf.matmul(full_out, s_w) + bias, [-1, max_response_num]) # [batch_size, max_response_num] 337 | print("logits: {}".format(logits.get_shape())) 338 | 339 | self.probs = tf.nn.softmax(logits, name="prob") # [batch_size, max_response_num] 340 | 341 | losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=self.target) 342 | self.mean_loss = tf.reduce_mean(losses, name="mean_loss") + l2_reg_lambda * l2_loss + sum( 343 | tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) 344 | 345 | with tf.name_scope("accuracy"): 346 | correct_prediction = tf.equal(tf.argmax(self.probs, 1), self.target) 347 | self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"), name="accuracy") 348 | -------------------------------------------------------------------------------- /model/train.py:
-------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import os 4 | import time 5 | import datetime 6 | import operator 7 | from collections import defaultdict 8 | from model import metrics 9 | from model import data_helpers 10 | from model.model_DIM import DIM 11 | 12 | 13 | # Files 14 | tf.flags.DEFINE_string("train_file", "", "path to train file") 15 | tf.flags.DEFINE_string("valid_file", "", "path to valid file") 16 | tf.flags.DEFINE_string("vocab_file", "", "path to vocabulary file") 17 | tf.flags.DEFINE_string("char_vocab_file", "", "path to char vocab file") 18 | tf.flags.DEFINE_string("embedded_vector_file", "", "path to pre-trained word embedding file") 19 | 20 | # Model Hyperparameters 21 | tf.flags.DEFINE_integer("max_utter_num", 15, "max utterance number") 22 | tf.flags.DEFINE_integer("max_utter_len", 20, "max utterance length") 23 | tf.flags.DEFINE_integer("max_response_num", 20, "max response candidate number") 24 | tf.flags.DEFINE_integer("max_response_len", 20, "max response length") 25 | tf.flags.DEFINE_integer("max_persona_num", 5, "max persona number") 26 | tf.flags.DEFINE_integer("max_persona_len", 15, "max persona length") 27 | tf.flags.DEFINE_integer("max_word_length", 18, "max word length") 28 | tf.flags.DEFINE_integer("embedding_dim", 200, "dimensionality of word embedding") 29 | tf.flags.DEFINE_integer("rnn_size", 200, "number of RNN units") 30 | 31 | # Training parameters 32 | tf.flags.DEFINE_integer("batch_size", 128, "batch size (default: 128)") 33 | tf.flags.DEFINE_float("l2_reg_lambda", 0, "L2 regularization lambda (default: 0)") 34 | tf.flags.DEFINE_float("dropout_keep_prob", 1.0, "dropout keep probability (default: 1.0)") 35 | tf.flags.DEFINE_integer("num_epochs", 1000000, "number of training epochs (default: 1000000)") 36 | tf.flags.DEFINE_integer("evaluate_every", 1000, "evaluate model on valid dataset after this many steps (default: 1000)") 37 | 38 | # Misc
Parameters 39 | tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow soft device placement") 40 | tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices") 41 | 42 | FLAGS = tf.flags.FLAGS 43 | FLAGS._parse_flags() 44 | print("\nParameters:") 45 | for attr, value in sorted(FLAGS.__flags.items()): 46 | print("{}={}".format(attr.upper(), value)) 47 | print("") 48 | 49 | # Load data 50 | print("Loading data...") 51 | 52 | vocab = data_helpers.load_vocab(FLAGS.vocab_file) 53 | print('vocabulary size: {}'.format(len(vocab))) 54 | charVocab = data_helpers.load_char_vocab(FLAGS.char_vocab_file) 55 | print('charVocab size: {}'.format(len(charVocab))) 56 | 57 | train_dataset = data_helpers.load_dataset(FLAGS.train_file, vocab, FLAGS.max_utter_num, FLAGS.max_utter_len, FLAGS.max_response_len, FLAGS.max_persona_len) 58 | print('train dataset size: {}'.format(len(train_dataset))) 59 | valid_dataset = data_helpers.load_dataset(FLAGS.valid_file, vocab, FLAGS.max_utter_num, FLAGS.max_utter_len, FLAGS.max_response_len, FLAGS.max_persona_len) 60 | print('valid dataset size: {}'.format(len(valid_dataset))) 61 | 62 | 63 | with tf.Graph().as_default(): 64 | session_conf = tf.ConfigProto( 65 | allow_soft_placement=FLAGS.allow_soft_placement, 66 | log_device_placement=FLAGS.log_device_placement) 67 | sess = tf.Session(config=session_conf) 68 | with sess.as_default(): 69 | dim = DIM( 70 | max_utter_num=FLAGS.max_utter_num, 71 | max_utter_len=FLAGS.max_utter_len, 72 | max_response_num=FLAGS.max_response_num, 73 | max_response_len=FLAGS.max_response_len, 74 | max_persona_num=FLAGS.max_persona_num, 75 | max_persona_len=FLAGS.max_persona_len, 76 | vocab_size=len(vocab), 77 | embedding_size=FLAGS.embedding_dim, 78 | vocab=vocab, 79 | rnn_size=FLAGS.rnn_size, 80 | maxWordLength=FLAGS.max_word_length, 81 | charVocab=charVocab, 82 | l2_reg_lambda=FLAGS.l2_reg_lambda) 83 | # Define Training procedure 84 | global_step = tf.Variable(0,
name="global_step", trainable=False) 85 | starter_learning_rate = 0.001 86 | learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step, 87 | 5000, 0.96, staircase=True) 88 | optimizer = tf.train.AdamOptimizer(learning_rate) 89 | grads_and_vars = optimizer.compute_gradients(dim.mean_loss) 90 | train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step) 91 | 92 | # Keep track of gradient values and sparsity (optional) 93 | """ 94 | grad_summaries = [] 95 | for g, v in grads_and_vars: 96 | if g is not None: 97 | grad_hist_summary = tf.histogram_summary("{}/grad/hist".format(v.name), g) 98 | sparsity_summary = tf.scalar_summary("{}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g)) 99 | grad_summaries.append(grad_hist_summary) 100 | grad_summaries.append(sparsity_summary) 101 | grad_summaries_merged = tf.merge_summary(grad_summaries) 102 | """ 103 | 104 | # Output directory for models and summaries 105 | timestamp = str(int(time.time())) 106 | out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp)) 107 | print("Writing to {}\n".format(out_dir)) 108 | 109 | # Summaries for loss and accuracy 110 | """ 111 | loss_summary = tf.scalar_summary("loss", dim.mean_loss) 112 | acc_summary = tf.scalar_summary("accuracy", dim.accuracy) 113 | 114 | # Train Summaries 115 | train_summary_op = tf.merge_summary([loss_summary, acc_summary, grad_summaries_merged]) 116 | train_summary_dir = os.path.join(out_dir, "summaries", "train") 117 | train_summary_writer = tf.train.SummaryWriter(train_summary_dir, sess.graph_def) 118 | 119 | # Dev summaries 120 | dev_summary_op = tf.merge_summary([loss_summary, acc_summary]) 121 | dev_summary_dir = os.path.join(out_dir, "summaries", "dev") 122 | dev_summary_writer = tf.train.SummaryWriter(dev_summary_dir, sess.graph_def) 123 | """ 124 | 125 | # Checkpoint directory. 
Tensorflow assumes this directory already exists so we need to create it 126 | checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints")) 127 | checkpoint_prefix = os.path.join(checkpoint_dir, "model") 128 | if not os.path.exists(checkpoint_dir): 129 | os.makedirs(checkpoint_dir) 130 | saver = tf.train.Saver(tf.global_variables()) 131 | 132 | # Initialize all variables 133 | sess.run(tf.global_variables_initializer()) 134 | 135 | def train_step(x_utterances, x_utterances_len, x_response, x_response_len, 136 | x_utters_num, x_target, x_ids, 137 | x_u_char, x_u_char_len, x_r_char, x_r_char_len, 138 | x_personas, x_personas_len, x_p_char, x_p_char_len, x_personas_num): 139 | """ 140 | A single training step 141 | """ 142 | feed_dict = { 143 | dim.utterances: x_utterances, 144 | dim.utterances_len: x_utterances_len, 145 | dim.responses: x_response, 146 | dim.responses_len: x_response_len, 147 | dim.utters_num: x_utters_num, 148 | dim.target: x_target, 149 | dim.dropout_keep_prob: FLAGS.dropout_keep_prob, 150 | dim.u_charVec: x_u_char, 151 | dim.u_charLen: x_u_char_len, 152 | dim.r_charVec: x_r_char, 153 | dim.r_charLen: x_r_char_len, 154 | dim.personas: x_personas, 155 | dim.personas_len: x_personas_len, 156 | dim.p_charVec: x_p_char, 157 | dim.p_charLen: x_p_char_len, 158 | dim.personas_num: x_personas_num 159 | } 160 | 161 | _, step, loss, accuracy, predicted_prob = sess.run( 162 | [train_op, global_step, dim.mean_loss, dim.accuracy, dim.probs], 163 | feed_dict) 164 | 165 | if step % 100 == 0: 166 | time_str = datetime.datetime.now().isoformat() 167 | print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy)) 168 | #train_summary_writer.add_summary(summaries, step) 169 | 170 | 171 | def dev_step(): 172 | results = defaultdict(list) 173 | num_test = 0 174 | num_correct = 0.0 175 | valid_batches = data_helpers.batch_iter(valid_dataset, FLAGS.batch_size, 1, FLAGS.max_utter_num, FLAGS.max_utter_len, \ 176 | FLAGS.max_response_num, 
FLAGS.max_response_len, FLAGS.max_persona_num, FLAGS.max_persona_len, \ 177 | charVocab, FLAGS.max_word_length, shuffle=True) 178 | for valid_batch in valid_batches: 179 | x_utterances, x_utterances_len, x_response, x_response_len, \ 180 | x_utters_num, x_target, x_ids, \ 181 | x_u_char, x_u_char_len, x_r_char, x_r_char_len, \ 182 | x_personas, x_personas_len, x_p_char, x_p_char_len, x_personas_num = valid_batch 183 | feed_dict = { 184 | dim.utterances: x_utterances, 185 | dim.utterances_len: x_utterances_len, 186 | dim.responses: x_response, 187 | dim.responses_len: x_response_len, 188 | dim.utters_num: x_utters_num, 189 | dim.target: x_target, 190 | dim.dropout_keep_prob: 1.0, 191 | dim.u_charVec: x_u_char, 192 | dim.u_charLen: x_u_char_len, 193 | dim.r_charVec: x_r_char, 194 | dim.r_charLen: x_r_char_len, 195 | dim.personas: x_personas, 196 | dim.personas_len: x_personas_len, 197 | dim.p_charVec: x_p_char, 198 | dim.p_charLen: x_p_char_len, 199 | dim.personas_num: x_personas_num 200 | } 201 | batch_accuracy, predicted_prob = sess.run([dim.accuracy, dim.probs], feed_dict) 202 | 203 | num_test += len(predicted_prob) 204 | if num_test % 1000 == 0: 205 | print(num_test) 206 | num_correct += len(predicted_prob) * batch_accuracy 207 | 208 | # predicted_prob = [batch_size, max_response_num] 209 | for i in range(len(predicted_prob)): 210 | probs = predicted_prob[i] 211 | us_id = x_ids[i] 212 | label = x_target[i] 213 | labels = np.zeros(FLAGS.max_response_num) 214 | labels[label] = 1 215 | for r_id, prob in enumerate(probs): 216 | results[us_id].append((str(r_id), labels[r_id], prob)) 217 | 218 | #calculate top-1 precision 219 | print('num_test_samples: {} test_accuracy: {}'.format(num_test, num_correct/num_test)) 220 | accu, precision, recall, f1, loss = metrics.classification_metrics(results) 221 | print('Accuracy: {}, Precision: {} Recall: {} F1: {} Loss: {}'.format(accu, precision, recall, f1, loss)) 222 | 223 | mvp = metrics.mean_average_precision(results) 224 | 
            mrr = metrics.mean_reciprocal_rank(results)
            top_1_precision = metrics.top_1_precision(results)
            total_valid_query = metrics.get_num_valid_query(results)
            print('MAP (mean average precision): {}\tMRR (mean reciprocal rank): {}\tTop-1 precision: {}\tNum_query: {}'.format(mvp, mrr, top_1_precision, total_valid_query))

            return mrr

        best_mrr = 0.0
        batches = data_helpers.batch_iter(train_dataset, FLAGS.batch_size, FLAGS.num_epochs, FLAGS.max_utter_num, FLAGS.max_utter_len, \
                                          FLAGS.max_response_num, FLAGS.max_response_len, FLAGS.max_persona_num, FLAGS.max_persona_len, \
                                          charVocab, FLAGS.max_word_length, shuffle=True)
        for batch in batches:
            x_utterances, x_utterances_len, x_response, x_response_len, \
                x_utters_num, x_target, x_ids, \
                x_u_char, x_u_char_len, x_r_char, x_r_char_len, \
                x_personas, x_personas_len, x_p_char, x_p_char_len, x_personas_num = batch
            train_step(x_utterances, x_utterances_len, x_response, x_response_len, x_utters_num, x_target, x_ids,
                       x_u_char, x_u_char_len, x_r_char, x_r_char_len,
                       x_personas, x_personas_len, x_p_char, x_p_char_len, x_personas_num)
            current_step = tf.train.global_step(sess, global_step)
            if current_step % FLAGS.evaluate_every == 0:
                print("\nEvaluation:")
                valid_mrr = dev_step()
                if valid_mrr > best_mrr:
                    best_mrr = valid_mrr
                    path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    print("Saved model checkpoint to {}\n".format(path))

--------------------------------------------------------------------------------
/scripts/compute_recall.py:
--------------------------------------------------------------------------------

test_out_filename = "persona_test_out.txt"

with open(test_out_filename, 'r') as f:
    cur_q_id = None
    num_query = 0
    recall = {"recall@1": 0,
              "recall@2": 0,
              "recall@5": 0,
              "recall@10": 0}

    lines = f.readlines()
    for line in lines[1:]:  # skip the header line
        line = line.strip().split('\t')
        line = [float(ele) for ele in line]

        # Column 0 is the query id; count each new query once
        if cur_q_id is None:
            cur_q_id = line[0]
            num_query += 1
        elif line[0] != cur_q_id:
            cur_q_id = line[0]
            num_query += 1

        # Column 4 is the label and column 3 the rank; only the gold response is scored
        if line[4] == 1.0:
            rank = line[3]

            if rank <= 1:
                recall["recall@1"] += 1
            if rank <= 2:
                recall["recall@2"] += 1
            if rank <= 5:
                recall["recall@5"] += 1
            if rank <= 10:
                recall["recall@10"] += 1

recall["recall@1"] = recall["recall@1"] / float(num_query)
recall["recall@2"] = recall["recall@2"] / float(num_query)
recall["recall@5"] = recall["recall@5"] / float(num_query)
recall["recall@10"] = recall["recall@10"] / float(num_query)
print("num_query = {}".format(num_query))
print("recall@1 = {}".format(recall["recall@1"]))
print("recall@2 = {}".format(recall["recall@2"]))
print("recall@5 = {}".format(recall["recall@5"]))
print("recall@10 = {}".format(recall["recall@10"]))

--------------------------------------------------------------------------------
/scripts/test.sh:
--------------------------------------------------------------------------------
cur_dir=`pwd`
parentdir="$(dirname $cur_dir)"

DATA_DIR=${parentdir}/data/personachat_processed

# Use the most recent run directory; train.py writes checkpoints to ./runs/<timestamp>/checkpoints
latest_run=`ls -dt runs/* |head -n 1`
latest_checkpoint=${latest_run}/checkpoints
# latest_checkpoint=runs/1556416288/checkpoints
echo $latest_checkpoint

test_file=$DATA_DIR/processed_test_self_original.txt      # for self_original
# test_file=$DATA_DIR/processed_test_self_revised.txt     # for self_revised
# test_file=$DATA_DIR/processed_test_other_original.txt   # for other_original
# test_file=$DATA_DIR/processed_test_other_revised.txt    # for other_revised
vocab_file=$DATA_DIR/vocab.txt
char_vocab_file=$DATA_DIR/char_vocab.txt
output_file=./persona_test_out.txt

max_utter_num=15
max_utter_len=20
max_response_num=20
max_response_len=20
max_persona_num=5
max_persona_len=15
max_word_length=18
batch_size=32

PKG_DIR=${parentdir}

# GPU 3 is hard-coded via CUDA_VISIBLE_DEVICES. The trailing & runs evaluation in the background,
# so wait until it has finished (see log_DIM_test.txt) before running compute_recall.py.
PYTHONPATH=${PKG_DIR}:$PYTHONPATH CUDA_VISIBLE_DEVICES=3 python -u ${PKG_DIR}/model/eval.py \
    --test_file $test_file \
    --vocab_file $vocab_file \
    --char_vocab_file $char_vocab_file \
    --output_file $output_file \
    --max_utter_num $max_utter_num \
    --max_utter_len $max_utter_len \
    --max_response_num $max_response_num \
    --max_response_len $max_response_len \
    --max_persona_num $max_persona_num \
    --max_persona_len $max_persona_len \
    --max_word_length $max_word_length \
    --batch_size $batch_size \
    --checkpoint_dir $latest_checkpoint > log_DIM_test.txt 2>&1 &

--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
cur_dir=`pwd`
parentdir="$(dirname $cur_dir)"

DATA_DIR=${parentdir}/data/personachat_processed

# for self_original
train_file=$DATA_DIR/processed_train_self_original.txt
valid_file=$DATA_DIR/processed_valid_self_original.txt
# for self_revised
# train_file=$DATA_DIR/processed_train_self_revised.txt
# valid_file=$DATA_DIR/processed_valid_self_revised.txt
# for other_original
# train_file=$DATA_DIR/processed_train_other_original.txt
# valid_file=$DATA_DIR/processed_valid_other_original.txt
# for other_revised
# train_file=$DATA_DIR/processed_train_other_revised.txt
# valid_file=$DATA_DIR/processed_valid_other_revised.txt

vocab_file=$DATA_DIR/vocab.txt
char_vocab_file=$DATA_DIR/char_vocab.txt
embedded_vector_file=$DATA_DIR/glove_42B_300d_vec_plus_word2vec_100.txt

max_utter_num=15
max_utter_len=20
max_response_num=20
max_response_len=20
max_persona_num=5
max_persona_len=15
max_word_length=18
embedding_dim=400
rnn_size=200

batch_size=16
lambda=0
dropout_keep_prob=0.8
num_epochs=10
evaluate_every=500

PKG_DIR=${parentdir}

PYTHONPATH=${PKG_DIR}:$PYTHONPATH CUDA_VISIBLE_DEVICES=3 python -u ${PKG_DIR}/model/train.py \
    --train_file $train_file \
    --valid_file $valid_file \
    --vocab_file $vocab_file \
    --char_vocab_file $char_vocab_file \
    --embedded_vector_file $embedded_vector_file \
    --max_utter_num $max_utter_num \
    --max_utter_len $max_utter_len \
    --max_response_num $max_response_num \
    --max_response_len $max_response_len \
    --max_persona_num $max_persona_num \
    --max_persona_len $max_persona_len \
    --max_word_length $max_word_length \
    --embedding_dim $embedding_dim \
    --rnn_size $rnn_size \
    --batch_size $batch_size \
    --l2_reg_lambda $lambda \
    --dropout_keep_prob $dropout_keep_prob \
    --num_epochs $num_epochs \
    --evaluate_every $evaluate_every > log_DIM_train.txt 2>&1 &
--------------------------------------------------------------------------------
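`compute_recall.py` reads `persona_test_out.txt` and counts a query as a hit at k when its gold response has rank at most k. The same logic can be sketched as a reusable function over the file's lines, which makes it easy to check without running the model. This is a minimal sketch that assumes the column layout `compute_recall.py` relies on (column 0 query id, column 3 rank, column 4 a 0/1 label); nothing is assumed about columns 1 and 2. Unlike the script, it skips blank lines and raises on empty input. The function name `recall_at_k` is hypothetical.

```python
def recall_at_k(lines, ks=(1, 2, 5, 10)):
    """Recall@k over tab-separated score lines, header already removed.

    Assumed layout (the columns compute_recall.py reads): column 0 is the
    query id, column 3 the candidate's rank, column 4 its 0/1 label.
    Rows of one query must be contiguous: a new query is counted whenever
    the id differs from the previous row's id.
    """
    hits = dict((k, 0) for k in ks)
    num_query = 0
    cur_q_id = None
    for line in lines:
        if not line.strip():
            continue  # blank lines would make float('') raise in compute_recall.py
        fields = [float(x) for x in line.strip().split('\t')]
        q_id, rank, label = fields[0], fields[3], fields[4]
        if q_id != cur_q_id:
            cur_q_id = q_id
            num_query += 1
        if label == 1.0:  # the gold response of this query
            for k in ks:
                if rank <= k:
                    hits[k] += 1
    if num_query == 0:
        raise ValueError("no query rows found")
    return dict((k, hits[k] / float(num_query)) for k in ks)
```

To reproduce the script's numbers, pass `open("persona_test_out.txt").readlines()[1:]`, which drops the header line exactly as `compute_recall.py` does.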
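In `train.py` the learning rate starts at 0.001 and is decayed with `tf.train.exponential_decay(starter_learning_rate, global_step, 5000, 0.96, staircase=True)`. With `staircase=True`, TensorFlow computes `learning_rate * decay_rate ** (global_step // decay_steps)`, so the rate drops by 4% every 5000 optimizer steps and stays flat in between. A plain-Python sketch of that schedule, useful for checking what rate a given step trains at (the helper name `staircase_lr` is hypothetical):

```python
def staircase_lr(step, base_lr=0.001, decay_steps=5000, decay_rate=0.96):
    """Learning rate after `step` optimizer updates under staircase decay.

    Integer division (//) keeps the rate constant within each block of
    `decay_steps` steps, matching staircase=True in tf.train.exponential_decay.
    """
    return base_lr * decay_rate ** (step // decay_steps)


if __name__ == "__main__":
    for step in (0, 4999, 5000, 10000, 50000):
        print("step {:>6}: lr = {:.6g}".format(step, staircase_lr(step)))
```

Because `apply_gradients(..., global_step=global_step)` advances `global_step` once per batch, with `batch_size=16` from `train.sh` each 5000-step plateau covers 80,000 training examples.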
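During validation, `dev_step()` groups scores by sample id into `results[us_id]`, a list of `(r_id, label, prob)` tuples, and the training loop keeps the checkpoint with the best MRR. `metrics.py` is not shown here, so the sketch below is a hypothetical re-implementation using the standard MRR definition (rank candidates by descending score, then average 1/rank of the gold response over queries); the repository's `metrics.mean_reciprocal_rank` may differ in details such as tie handling. `build_results` mirrors the grouping loop in `dev_step()` but uses plain lists instead of NumPy arrays.

```python
from collections import defaultdict


def build_results(ids, targets, probs, results=None):
    """Group candidate scores by sample id, mirroring the loop in dev_step().

    ids[i] is the sample id, targets[i] the index of the gold response and
    probs[i] the list of candidate scores for sample i. Each candidate becomes
    a (r_id, label, prob) tuple, with label 1.0 only at the gold index.
    """
    if results is None:
        results = defaultdict(list)
    for us_id, target, row in zip(ids, targets, probs):
        for r_id, prob in enumerate(row):
            label = 1.0 if r_id == target else 0.0
            results[us_id].append((str(r_id), label, prob))
    return results


def mean_reciprocal_rank(results):
    """Standard MRR: mean over queries of 1/rank of the first gold candidate.

    Candidates are ranked by descending prob (sorted() is stable, so ties
    keep insertion order). Queries without a gold candidate are skipped.
    """
    total, num_valid = 0.0, 0
    for candidates in results.values():
        ranked = sorted(candidates, key=lambda c: c[2], reverse=True)
        for rank, (_, label, _) in enumerate(ranked, 1):
            if label == 1.0:
                total += 1.0 / rank
                num_valid += 1
                break
    return total / num_valid if num_valid else 0.0
```

Passing each validation batch's `x_ids`, `x_target` and `predicted_prob` through `build_results` (sharing one `results` dict) rebuilds the structure that `dev_step()` hands to `metrics`.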