├── DialogueDataset.py ├── MSN.py ├── Metrics.py ├── NeuralNetwork.py ├── README.md ├── checkpoint └── README.md ├── dataset ├── DoubanConversaionCorpus │ └── README.md ├── E_commerce │ └── README.md └── ubuntu_data │ └── README.md ├── log ├── alime.msn.log ├── douban.msn.log └── ubuntu.msn.log └── run.py /DialogueDataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import TensorDataset 3 | 4 | 5 | 6 | class DialogueDataset(TensorDataset): 7 | 8 | def __init__(self, X_utterances, X_responses, y_labels=None): 9 | super(DialogueDataset, self).__init__() 10 | X_utterances = torch.LongTensor(X_utterances) 11 | 12 | X_responses = torch.LongTensor(X_responses) 13 | print("X_utterances: ", X_utterances.size()) 14 | print("X_responses: ", X_responses.size()) 15 | 16 | if y_labels is not None: 17 | y_labels = torch.FloatTensor(y_labels) 18 | print("y_labels: ", y_labels.size()) 19 | self.tensors = [X_utterances, X_responses, y_labels] 20 | else: 21 | self.tensors = [X_utterances, X_responses] 22 | 23 | def __getitem__(self, index): 24 | return tuple(tensor[index] for tensor in self.tensors) 25 | 26 | def __len__(self): 27 | return len(self.tensors[0]) 28 | 29 | -------------------------------------------------------------------------------- /MSN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | from NeuralNetwork import NeuralNetwork 6 | 7 | 8 | class TransformerBlock(nn.Module): 9 | 10 | def __init__(self, input_size, is_layer_norm=False): 11 | super(TransformerBlock, self).__init__() 12 | self.is_layer_norm = is_layer_norm 13 | if is_layer_norm: 14 | self.layer_morm = nn.LayerNorm(normalized_shape=input_size) 15 | 16 | self.relu = nn.ReLU() 17 | self.linear1 = nn.Linear(input_size, input_size) 18 | self.linear2 = nn.Linear(input_size, 
input_size) 19 | self.init_weights() 20 | 21 | def init_weights(self): 22 | init.xavier_normal_(self.linear1.weight) 23 | init.xavier_normal_(self.linear2.weight) 24 | 25 | def FFN(self, X): 26 | return self.linear2(self.relu(self.linear1(X))) 27 | 28 | def forward(self, Q, K, V, episilon=1e-8): 29 | ''' 30 | :param Q: (batch_size, max_r_words, embedding_dim) 31 | :param K: (batch_size, max_u_words, embedding_dim) 32 | :param V: (batch_size, max_u_words, embedding_dim) 33 | :return: output: (batch_size, max_r_words, embedding_dim) same size as Q 34 | ''' 35 | dk = torch.Tensor([max(1.0, Q.size(-1))]).cuda() 36 | 37 | Q_K = Q.bmm(K.permute(0, 2, 1)) / (torch.sqrt(dk) + episilon) 38 | Q_K_score = F.softmax(Q_K, dim=-1) # (batch_size, max_r_words, max_u_words) 39 | V_att = Q_K_score.bmm(V) 40 | 41 | if self.is_layer_norm: 42 | X = self.layer_morm(Q + V_att) # (batch_size, max_r_words, embedding_dim) 43 | output = self.layer_morm(self.FFN(X) + X) 44 | else: 45 | X = Q + V_att 46 | output = self.FFN(X) + X 47 | 48 | return output 49 | 50 | 51 | class Attention(nn.Module): 52 | def __init__(self, input_size, hidden_size): 53 | super(Attention, self).__init__() 54 | self.linear1 = nn.Linear(in_features=input_size, out_features=hidden_size) 55 | self.linear2 = nn.Linear(in_features=hidden_size, out_features=1) 56 | self.init_weights() 57 | 58 | def init_weights(self): 59 | init.xavier_normal_(self.linear1.weight) 60 | init.xavier_normal_(self.linear2.weight) 61 | 62 | def forward(self, X, mask=None): 63 | ''' 64 | :param X: 65 | :param mask: http://juditacs.github.io/2018/12/27/masked-attention.html 66 | :return: 67 | ''' 68 | M = F.tanh(self.linear1(X)) # (batch_size, max_u_words, embedding_dim) 69 | M = self.linear2(M) 70 | M[~mask] = float('-inf') 71 | score = F.softmax(M, dim=1) # (batch_size, max_u_words, 1) 72 | 73 | output = (score * X).sum(dim=1) # (batch_size, embedding_dim) 74 | return output 75 | 76 | 77 | 78 | class MSN(NeuralNetwork): 79 | ''' 80 | A pytorch 
version of Sequential Matching Network which is proposed in 81 | "Sequential Matching Network: A New Architecture for Multi-turn Response Selection in Retrieval-based Chatbots" 82 | ''' 83 | def __init__(self, word_embeddings, args): 84 | self.args = args 85 | super(MSN, self).__init__() 86 | 87 | self.word_embedding = nn.Embedding(num_embeddings=len(word_embeddings), embedding_dim=200, padding_idx=0, 88 | _weight=torch.FloatTensor(word_embeddings)) 89 | 90 | self.alpha = 0.5 91 | self.gamma = 0.3 92 | self.selector_transformer = TransformerBlock(input_size=200) 93 | self.W_word = nn.Parameter(data=torch.Tensor(200, 200, 10)) 94 | self.v = nn.Parameter(data=torch.Tensor(10, 1)) 95 | self.linear_word = nn.Linear(2*50, 1) 96 | self.linear_score = nn.Linear(in_features=3, out_features=1) 97 | 98 | self.transformer_utt = TransformerBlock(input_size=200) 99 | self.transformer_res = TransformerBlock(input_size=200) 100 | self.transformer_ur = TransformerBlock(input_size=200) 101 | self.transformer_ru = TransformerBlock(input_size=200) 102 | 103 | self.A1 = nn.Parameter(data=torch.Tensor(200, 200)) 104 | self.A2 = nn.Parameter(data=torch.Tensor(200, 200)) 105 | self.A3 = nn.Parameter(data=torch.Tensor(200, 200)) 106 | 107 | self.cnn_2d_1 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=(3,3)) 108 | self.maxpooling1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)) 109 | 110 | self.cnn_2d_2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3,3)) 111 | self.maxpooling2 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)) 112 | 113 | self.cnn_2d_3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3,3)) 114 | self.maxpooling3 = nn.MaxPool2d(kernel_size=(3, 3), stride=(3, 3)) 115 | 116 | self.affine2 = nn.Linear(in_features=3*3*64, out_features=300) 117 | 118 | self.gru_acc = nn.GRU(input_size=300, hidden_size=args.gru_hidden, batch_first=True) 119 | # self.attention = Attention(input_size=300, hidden_size=300) 120 | self.affine_out = 
nn.Linear(in_features=args.gru_hidden, out_features=1) 121 | 122 | self.tanh = nn.Tanh() 123 | self.relu = nn.ReLU() 124 | self.dropout = nn.Dropout(0.2) 125 | self.init_weights() 126 | print(self) 127 | 128 | 129 | def init_weights(self): 130 | init.uniform_(self.W_word) 131 | init.uniform_(self.v) 132 | init.uniform_(self.linear_word.weight) 133 | init.uniform_(self.linear_score.weight) 134 | 135 | init.xavier_normal_(self.A1) 136 | init.xavier_normal_(self.A2) 137 | init.xavier_normal_(self.A3) 138 | init.xavier_normal_(self.cnn_2d_1.weight) 139 | init.xavier_normal_(self.cnn_2d_2.weight) 140 | init.xavier_normal_(self.cnn_2d_3.weight) 141 | init.xavier_normal_(self.affine2.weight) 142 | init.xavier_normal_(self.affine_out.weight) 143 | for weights in [self.gru_acc.weight_hh_l0, self.gru_acc.weight_ih_l0]: 144 | init.orthogonal_(weights) 145 | 146 | 147 | def word_selector(self, key, context): 148 | ''' 149 | :param key: (bsz, max_u_words, d) 150 | :param context: (bsz, max_u_words, d) 151 | :return: score: 152 | ''' 153 | dk = torch.sqrt(torch.Tensor([200])).cuda() 154 | A = torch.tanh(torch.einsum("blrd,ddh,bud->blruh", context, self.W_word, key)/dk) 155 | A = torch.einsum("blruh,hp->blrup", A, self.v).squeeze() # b x l x u x u 156 | 157 | a = torch.cat([A.max(dim=2)[0], A.max(dim=3)[0]], dim=-1) # b x l x 2u 158 | s1 = torch.softmax(self.linear_word(a).squeeze(), dim=-1) # b x l 159 | return s1 160 | 161 | def utterance_selector(self, key, context): 162 | ''' 163 | :param key: (bsz, max_u_words, d) 164 | :param context: (bsz, max_u_words, d) 165 | :return: score: 166 | ''' 167 | key = key.mean(dim=1) 168 | context = context.mean(dim=2) 169 | s2 = torch.einsum("bud,bd->bu", context, key)/(1e-6 + torch.norm(context, dim=-1)*torch.norm(key, dim=-1, keepdim=True) ) 170 | return s2 171 | 172 | def distance(self, A, B, C, epsilon=1e-6): 173 | M1 = torch.einsum("bud,dd,brd->bur", [A, B, C]) 174 | 175 | A_norm = A.norm(dim=-1) 176 | C_norm = C.norm(dim=-1) 177 | M2 = 
torch.einsum("bud,brd->bur", [A, C]) / (torch.einsum("bu,br->bur", A_norm, C_norm) + epsilon) 178 | return M1, M2 179 | 180 | def context_selector(self, context, hop=[1, 2, 3]): 181 | ''' 182 | :param context: (batch_size, max_utterances, max_u_words, embedding_dim) 183 | :param key: (batch_size, max_u_words, embedding_dim) 184 | :return: 185 | ''' 186 | su1, su2, su3, su4 = context.size() 187 | context_ = context.view(-1, su3, su4) # (batch_size*max_utterances, max_u_words, embedding_dim) 188 | context_ = self.selector_transformer(context_, context_, context_) 189 | context_ = context_.view(su1, su2, su3, su4) 190 | 191 | multi_match_score = [] 192 | for hop_i in hop: 193 | key = context[:, 10-hop_i:, :, :].mean(dim=1) 194 | key = self.selector_transformer(key, key, key) 195 | 196 | s1 = self.word_selector(key, context_) 197 | s2 = self.utterance_selector(key, context_) 198 | s = self.alpha * s1 + (1 - self.alpha) * s2 199 | multi_match_score.append(s) 200 | 201 | multi_match_score = torch.stack(multi_match_score, dim=-1) 202 | match_score = self.linear_score(multi_match_score).squeeze() 203 | mask = (match_score.sigmoid() >= self.gamma).float() 204 | match_score = match_score * mask 205 | context = context * match_score.unsqueeze(dim=-1).unsqueeze(dim=-1) 206 | return context 207 | 208 | def get_Matching_Map(self, bU_embedding, bR_embedding): 209 | ''' 210 | :param bU_embedding: (batch_size*max_utterances, max_u_words, embedding_dim) 211 | :param bR_embedding: (batch_size*max_utterances, max_r_words, embedding_dim) 212 | :return: E: (bsz*max_utterances, max_u_words, max_r_words) 213 | ''' 214 | # M1 = torch.einsum("bud,dd,brd->bur", bU_embedding, self.A1, bR_embedding) # (bsz*max_utterances, max_u_words, max_r_words) 215 | M1, M2 = self.distance(bU_embedding, self.A1, bR_embedding) 216 | 217 | Hu = self.transformer_utt(bU_embedding, bU_embedding, bU_embedding) 218 | Hr = self.transformer_res(bR_embedding, bR_embedding, bR_embedding) 219 | # M2 = 
torch.einsum("bud,dd,brd->bur", [Hu, self.A2, Hr]) 220 | M3, M4 = self.distance(Hu, self.A2, Hr) 221 | 222 | Hur = self.transformer_ur(bU_embedding, bR_embedding, bR_embedding) 223 | Hru = self.transformer_ru(bR_embedding, bU_embedding, bU_embedding) 224 | # M3 = torch.einsum("bud,dd,brd->bur", [Hur, self.A3, Hru]) 225 | M5, M6 = self.distance(Hur, self.A3, Hru) 226 | 227 | M = torch.stack([M1, M2, M3, M4, M5, M6], dim=1) # (bsz*max_utterances, channel, max_u_words, max_r_words) 228 | return M 229 | 230 | 231 | def UR_Matching(self, bU_embedding, bR_embedding): 232 | ''' 233 | :param bU_embedding: (batch_size*max_utterances, max_u_words, embedding_dim) 234 | :param bR_embedding: (batch_size*max_utterances, max_r_words, embedding_dim) 235 | :return: (bsz*max_utterances, (max_u_words - width)/stride + 1, (max_r_words -height)/stride + 1, channel) 236 | ''' 237 | M = self.get_Matching_Map(bU_embedding, bR_embedding) 238 | 239 | Z = self.relu(self.cnn_2d_1(M)) 240 | Z = self.maxpooling1(Z) 241 | 242 | Z = self.relu(self.cnn_2d_2(Z)) 243 | Z =self.maxpooling2(Z) 244 | 245 | Z = self.relu(self.cnn_2d_3(Z)) 246 | Z =self.maxpooling3(Z) 247 | 248 | Z = Z.view(Z.size(0), -1) # (bsz*max_utterances, *) 249 | 250 | V = self.tanh(self.affine2(Z)) # (bsz*max_utterances, 50) 251 | return V 252 | 253 | def forward(self, bU, bR): 254 | ''' 255 | :param bU: batch utterance, size: (batch_size, max_utterances, max_u_words) 256 | :param bR: batch responses, size: (batch_size, max_r_words) 257 | :return: scores, size: (batch_size, ) 258 | ''' 259 | # u_mask = (bU != 0).unsqueeze(dim=-1).float() 260 | # u_mask_sent = ((bU != 0).sum(dim=-1) !=0 ).unsqueeze(dim=-1) 261 | # r_mask = (bR != 0).unsqueeze(dim=-1).float() 262 | 263 | bU_embedding = self.word_embedding(bU) # + self.position_embedding(bU_pos) # * u_mask 264 | bR_embedding = self.word_embedding(bR) # + self.position_embedding(bR_pos) # * r_mask 265 | multi_context = self.context_selector(bU_embedding, hop=[1, 2, 3]) 266 | 267 | 
su1, su2, su3, su4 = multi_context.size() 268 | multi_context = multi_context.view(-1, su3, su4) # (batch_size*max_utterances, max_u_words, embedding_dim) 269 | 270 | sr1, sr2, sr3= bR_embedding.size() # (batch_size, max_r_words, embedding_dim) 271 | bR_embedding = bR_embedding.unsqueeze(dim=1).repeat(1, su2, 1, 1) # (batch_size, max_utterances, max_r_words, embedding_dim) 272 | bR_embedding = bR_embedding.view(-1, sr2, sr3) # (batch_size*max_utterances, max_r_words, embedding_dim) 273 | 274 | V = self.UR_Matching(multi_context, bR_embedding) 275 | V = V.view(su1, su2, -1) # (bsz, max_utterances, 300) 276 | 277 | H, _ = self.gru_acc(V) # (bsz, max_utterances, rnn2_hidden) 278 | # L = self.attention(V, u_mask_sent) 279 | L = self.dropout(H[:,-1,:]) 280 | 281 | output = torch.sigmoid(self.affine_out(L)) 282 | return output.squeeze() 283 | 284 | 285 | -------------------------------------------------------------------------------- /Metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | np.random.seed(0) 3 | 4 | 5 | class Metrics(object): 6 | 7 | def __init__(self, score_file_path:str): 8 | super(Metrics, self).__init__() 9 | self.score_file_path = score_file_path 10 | self.segment = 10 11 | 12 | def __read_socre_file(self, score_file_path): 13 | sessions = [] 14 | one_sess = [] 15 | with open(score_file_path, 'r') as infile: 16 | i = 0 17 | for line in infile.readlines(): 18 | i += 1 19 | tokens = line.strip().split('\t') 20 | one_sess.append((float(tokens[0]), int(tokens[1]))) 21 | if i % self.segment == 0: 22 | one_sess_tmp = np.array(one_sess) 23 | if one_sess_tmp[:, 1].sum() > 0: 24 | sessions.append(one_sess) 25 | one_sess = [] 26 | return sessions 27 | 28 | 29 | def __mean_average_precision(self, sort_data): 30 | #to do 31 | count_1 = 0 32 | sum_precision = 0 33 | for index in range(len(sort_data)): 34 | if sort_data[index][1] == 1: 35 | count_1 += 1 36 | sum_precision += 1.0 * count_1 / (index+1) 
37 | return sum_precision / count_1 38 | 39 | 40 | def __mean_reciprocal_rank(self, sort_data): 41 | sort_lable = [s_d[1] for s_d in sort_data] 42 | assert 1 in sort_lable 43 | return 1.0 / (1 + sort_lable.index(1)) 44 | 45 | def __precision_at_position_1(self, sort_data): 46 | if sort_data[0][1] == 1: 47 | return 1 48 | else: 49 | return 0 50 | 51 | def __recall_at_position_k_in_10(self, sort_data, k): 52 | sort_label = [s_d[1] for s_d in sort_data] 53 | select_label = sort_label[:k] 54 | return 1.0 * select_label.count(1) / sort_label.count(1) 55 | 56 | 57 | def evaluation_one_session(self, data): 58 | ''' 59 | :param data: one conversion session, which layout is [(score1, label1), (score2, label2), ..., (score10, label10)]. 60 | :return: all kinds of metrics used in paper. 61 | ''' 62 | np.random.shuffle(data) 63 | sort_data = sorted(data, key=lambda x: x[0], reverse=True) 64 | m_a_p = self.__mean_average_precision(sort_data) 65 | m_r_r = self.__mean_reciprocal_rank(sort_data) 66 | p_1 = self.__precision_at_position_1(sort_data) 67 | r_1 = self.__recall_at_position_k_in_10(sort_data, 1) 68 | r_2 = self.__recall_at_position_k_in_10(sort_data, 2) 69 | r_5 = self.__recall_at_position_k_in_10(sort_data, 5) 70 | return m_a_p, m_r_r, p_1, r_1, r_2, r_5 71 | 72 | 73 | def evaluate_all_metrics(self): 74 | sum_m_a_p = 0 75 | sum_m_r_r = 0 76 | sum_p_1 = 0 77 | sum_r_1 = 0 78 | sum_r_2 = 0 79 | sum_r_5 = 0 80 | 81 | sessions = self.__read_socre_file(self.score_file_path) 82 | total_s = len(sessions) 83 | for session in sessions: 84 | m_a_p, m_r_r, p_1, r_1, r_2, r_5 = self.evaluation_one_session(session) 85 | sum_m_a_p += m_a_p 86 | sum_m_r_r += m_r_r 87 | sum_p_1 += p_1 88 | sum_r_1 += r_1 89 | sum_r_2 += r_2 90 | sum_r_5 += r_5 91 | 92 | return (sum_m_a_p/total_s, 93 | sum_m_r_r/total_s, 94 | sum_p_1/total_s, 95 | sum_r_1/total_s, 96 | sum_r_2/total_s, 97 | sum_r_5/total_s) 98 | 99 | 100 | if __name__ == '__main__': 101 | metric = 
Metrics('dataset/E_commerce/score_file.txt') 102 | result = metric.evaluate_all_metrics(is_test=True) 103 | for r in result: 104 | print(r) 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | -------------------------------------------------------------------------------- /NeuralNetwork.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.utils as utils 4 | import torch.optim as optim 5 | from torch.utils.data import DataLoader 6 | from DialogueDataset import DialogueDataset 7 | from Metrics import Metrics 8 | 9 | torch.backends.cudnn.benchmark = True 10 | torch.manual_seed(0) 11 | torch.cuda.manual_seed_all(0) 12 | 13 | class NeuralNetwork(nn.Module): 14 | 15 | def __init__(self): 16 | super(NeuralNetwork, self).__init__() 17 | self.patience = 0 18 | self.init_clip_max_norm = 5.0 19 | self.optimizer = None 20 | self.best_result = [0, 0, 0, 0, 0, 0] 21 | self.metrics = Metrics(self.args.score_file_path) 22 | self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 23 | 24 | def forward(self): 25 | raise NotImplementedError 26 | 27 | 28 | def train_step(self, i, data): 29 | with torch.no_grad(): 30 | batch_u, batch_r, batch_y = (item.cuda(device=self.device) for item in data) 31 | 32 | self.optimizer.zero_grad() 33 | logits = self.forward(batch_u, batch_r) 34 | loss = self.loss_func(logits, target=batch_y) 35 | loss.backward() 36 | self.optimizer.step() 37 | print('Batch[{}] - loss: {:.6f} batch_size:{}'.format(i, loss.item(), batch_y.size(0)) ) # , accuracy, corrects 38 | return loss 39 | 40 | 41 | def fit(self, X_train_utterances, X_train_responses, y_train, 42 | X_dev_utterances, X_dev_responses, y_dev): 43 | 44 | if torch.cuda.is_available(): self.cuda() 45 | 46 | dataset = DialogueDataset(X_train_utterances, X_train_responses, y_train) 47 | dataloader = DataLoader(dataset, batch_size=self.args.batch_size, shuffle=True) 48 | 49 | 
self.loss_func = nn.BCELoss() 50 | self.optimizer = optim.Adam(self.parameters(), lr=self.args.learning_rate, weight_decay=self.args.l2_reg) 51 | 52 | for epoch in range(self.args.epochs): 53 | print("\nEpoch ", epoch+1, "/", self.args.epochs) 54 | avg_loss = 0 55 | 56 | self.train() 57 | for i, data in enumerate(dataloader): 58 | loss = self.train_step(i, data) 59 | 60 | if i > 0 and i % 500 == 0: 61 | self.evaluate(X_dev_utterances, X_dev_responses, y_dev) 62 | self.train() 63 | 64 | if epoch >= 2 and self.patience >= 3: 65 | print("Reload the best model...") 66 | self.load_state_dict(torch.load(self.args.save_path)) 67 | self.adjust_learning_rate() 68 | self.patience = 0 69 | 70 | if self.init_clip_max_norm is not None: 71 | utils.clip_grad_norm_(self.parameters(), max_norm=self.init_clip_max_norm) 72 | 73 | avg_loss += loss.item() 74 | cnt = len(y_train) // self.args.batch_size + 1 75 | print("Average loss:{:.6f} ".format(avg_loss/cnt)) 76 | self.evaluate(X_dev_utterances, X_dev_responses, y_dev) 77 | 78 | 79 | def adjust_learning_rate(self, decay_rate=.5): 80 | for param_group in self.optimizer.param_groups: 81 | param_group['lr'] = param_group['lr'] * decay_rate 82 | self.args.learning_rate = param_group['lr'] 83 | print("Decay learning rate to: ", self.args.learning_rate) 84 | 85 | 86 | def evaluate(self, X_dev_utterances, X_dev_responses, y_dev, is_test=False): 87 | y_pred = self.predict(X_dev_utterances, X_dev_responses) 88 | with open(self.args.score_file_path, 'w') as output: 89 | for score, label in zip(y_pred, y_dev): 90 | output.write( 91 | str(score) + '\t' + 92 | str(label) + '\n' 93 | ) 94 | 95 | result = self.metrics.evaluate_all_metrics() 96 | print("Evaluation Result: \n", 97 | "MAP:", result[0], "\t", 98 | "MRR:", result[1], "\t", 99 | "P@1:", result[2], "\t", 100 | "R1:", result[3], "\t", 101 | "R2:", result[4], "\t", 102 | "R5:", result[5]) 103 | 104 | if not is_test and result[3] + result[4] + result[5] > self.best_result[3] + 
self.best_result[4] + self.best_result[5]: 105 | print("Best Result: \n", 106 | "MAP:", self.best_result[0], "\t", 107 | "MRR:", self.best_result[1], "\t", 108 | "P@1:", self.best_result[2], "\t", 109 | "R1:", self.best_result[3], "\t", 110 | "R2:", self.best_result[4], "\t", 111 | "R5:", self.best_result[5]) 112 | self.patience = 0 113 | self.best_result = result 114 | torch.save(self.state_dict(), self.args.save_path) 115 | print("save model!!!\n") 116 | else: 117 | self.patience += 1 118 | 119 | 120 | def predict(self, X_dev_utterances, X_dev_responses): 121 | self.eval() 122 | y_pred = [] 123 | dataset = DialogueDataset(X_dev_utterances, X_dev_responses) 124 | dataloader = DataLoader(dataset, batch_size=100) 125 | 126 | for i, data in enumerate(dataloader): 127 | with torch.no_grad(): 128 | batch_u, batch_r = (item.cuda() for item in data) 129 | 130 | logits = self.forward(batch_u, batch_r) 131 | y_pred += logits.data.cpu().numpy().tolist() 132 | return y_pred 133 | 134 | 135 | def load_model(self, path): 136 | self.load_state_dict(state_dict=torch.load(path)) 137 | if torch.cuda.is_available(): self.cuda() 138 | 139 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Paper of the source codes released: 2 | This repository contains the source code and datasets for the model Multi-hop Selector Network.
3 | 4 | 5 | ## Dependencies 6 | Python 3.x
7 | PyTorch 1.1.0 8 | 9 | ## Datasets and Trained Models 10 | You can download the processed datasets and the checkpoints of the trained models to reproduce the experimental results in the paper from the following URL:
11 | https://drive.google.com/drive/folders/1pJKIppcbjuTZxbTc8ye5mfnC2ygR2xTo?usp=sharing 12 | 13 | Unzip the dataset.rar file into the ```dataset``` folder and the checkpoint.rar file into the ```checkpoint``` folder.
14 | The ```log/``` directory already contains the training and evaluation logs for each dataset. 15 | 16 | ## Reproduce the experimental results with the pre-trained model 17 | ``` 18 | cd ./Dialogue/ 19 | python ./run.py --task "ubuntu" 20 | python ./run.py --task "douban" 21 | python ./run.py --task "alime" 22 | ``` 23 | 24 | ## Train a new model 25 | ``` 26 | cd ./Dialogue/ 27 | python ./run.py --task "ubuntu" --is_training True 28 | python ./run.py --task "douban" --is_training True 29 | python ./run.py --task "alime" --is_training True 30 | ``` 31 | The training process is recorded in the ```log/[ubuntu|douban|alime].msn.log``` file. 32 | 33 | 34 | ## Citation 35 | If you find this code useful in your research, please cite our paper: 36 | ``` 37 | @inproceedings{yuan2019multi, 38 | title={Multi-hop Selector Network for Multi-turn Response Selection in Retrieval-based Chatbots}, 39 | author={Yuan, Chunyuan and Zhou, Wei and Li, Mingming and Lv, Shangwen and Zhu, Fuqing and Han, Jizhong and Hu, Songlin}, 40 | booktitle={Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)}, 41 | pages={111--120}, 42 | year={2019} 43 | } 44 | ``` 45 | -------------------------------------------------------------------------------- /checkpoint/README.md: -------------------------------------------------------------------------------- 1 | Put the trained model in this directory. -------------------------------------------------------------------------------- /dataset/DoubanConversaionCorpus/README.md: -------------------------------------------------------------------------------- 1 | Put the processed douban dataset in this directory.
-------------------------------------------------------------------------------- /dataset/E_commerce/README.md: -------------------------------------------------------------------------------- 1 | Put the processed E_commerce dataset in this directory. -------------------------------------------------------------------------------- /dataset/ubuntu_data/README.md: -------------------------------------------------------------------------------- 1 | Put the processed ubuntu dataset in this directory. -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import time 2 | import argparse 3 | import pickle 4 | from MSN import MSN 5 | import os 6 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 7 | 8 | task_dic = { 9 | 'ubuntu':'./dataset/ubuntu_data/', 10 | 'douban':'./dataset/DoubanConversaionCorpus/', 11 | 'alime':'./dataset/E_commerce/' 12 | } 13 | data_batch_size = { 14 | "ubuntu": 200, 15 | "douban": 150, 16 | "alime": 200 17 | } 18 | 19 | ## Required parameters 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("--task", 22 | default='ubuntu', 23 | type=str, 24 | help="The dataset used for training and test.") 25 | parser.add_argument("--is_training", 26 | default=False, 27 | type=bool, 28 | help="Training model or evaluating model?") 29 | parser.add_argument("--max_utterances", 30 | default=10, 31 | type=int, 32 | help="The maximum number of utterances.") 33 | parser.add_argument("--max_words", 34 | default=50, 35 | type=int, 36 | help="The maximum number of words for each utterance.") 37 | parser.add_argument("--batch_size", 38 | default=0, 39 | type=int, 40 | help="The batch size.") 41 | parser.add_argument("--gru_hidden", 42 | default=300, 43 | type=int, 44 | help="The hidden size of GRU in layer 1") 45 | parser.add_argument("--learning_rate", 46 | default=1e-3, 47 | type=float, 48 | help="The initial learning rate for Adam.") 49 | 
parser.add_argument("--l2_reg", 50 | default=0.0, 51 | type=float, 52 | help="The l2 regularization.") 53 | parser.add_argument("--epochs", 54 | default=5, 55 | type=float, 56 | help="Total number of training epochs to perform.") 57 | parser.add_argument("--save_path", 58 | default="./checkpoint/", 59 | type=str, 60 | help="The path to save model.") 61 | parser.add_argument("--score_file_path", 62 | default="score_file.txt", 63 | type=str, 64 | help="The path to save model.") 65 | args = parser.parse_args() 66 | args.batch_size = data_batch_size[args.task] 67 | args.save_path += args.task + '.' + MSN.__name__ + ".pt" 68 | args.score_file_path = task_dic[args.task] + args.score_file_path 69 | 70 | print(args) 71 | print("Task: ", args.task) 72 | 73 | 74 | def train_model(): 75 | path = task_dic[args.task] 76 | X_train_utterances, X_train_responses, y_train = pickle.load(file=open(path+"train.pkl", 'rb')) 77 | X_dev_utterances, X_dev_responses, y_dev = pickle.load(file=open(path+"test.pkl", 'rb')) 78 | vocab, word_embeddings = pickle.load(file=open(path + "vocab_and_embeddings.pkl", 'rb')) 79 | 80 | model = MSN(word_embeddings, args=args) 81 | model.fit( 82 | X_train_utterances, X_train_responses, y_train, 83 | X_dev_utterances, X_dev_responses, y_dev 84 | ) 85 | 86 | 87 | def test_model(): 88 | path = task_dic[args.task] 89 | X_test_utterances, X_test_responses, y_test = pickle.load(file=open(path+"test.pkl", 'rb')) 90 | vocab, word_embeddings = pickle.load(file=open(path + "vocab_and_embeddings.pkl", 'rb')) 91 | 92 | model = MSN(word_embeddings, args=args) 93 | model.load_model(args.save_path) 94 | model.evaluate(X_test_utterances, X_test_responses, y_test, is_test=True) 95 | 96 | def test_adversarial(): 97 | path = task_dic[args.task] 98 | vocab, word_embeddings = pickle.load(file=open(path + "vocab_and_embeddings.pkl", 'rb')) 99 | model = MSN(word_embeddings, args=args) 100 | model.load_model(args.save_path) 101 | print("adversarial test set (k=1): ") 102 | 
X_test_utterances, X_test_responses, y_test = pickle.load(file=open(path+"test_adversarial_k_1.pkl", 'rb')) 103 | model.evaluate(X_test_utterances, X_test_responses, y_test, is_test=True) 104 | print("adversarial test set (k=2): ") 105 | X_test_utterances, X_test_responses, y_test = pickle.load(file=open(path+"test_adversarial_k_2.pkl", 'rb')) 106 | model.evaluate(X_test_utterances, X_test_responses, y_test, is_test=True) 107 | print("adversarial test set (k=3): ") 108 | X_test_utterances, X_test_responses, y_test = pickle.load(file=open(path+"test_adversarial_k_3.pkl", 'rb')) 109 | model.evaluate(X_test_utterances, X_test_responses, y_test, is_test=True) 110 | 111 | 112 | if __name__ == '__main__': 113 | start = time.time() 114 | if args.is_training: 115 | train_model() 116 | test_model() 117 | else: 118 | test_model() 119 | # test_adversarial() 120 | end = time.time() 121 | print("use time: ", (end-start)/60, " min") 122 | 123 | 124 | 125 | 126 | --------------------------------------------------------------------------------