├── DialogueDataset.py
├── MSN.py
├── Metrics.py
├── NeuralNetwork.py
├── README.md
├── checkpoint
│   └── README.md
├── dataset
│   ├── DoubanConversaionCorpus
│   │   └── README.md
│   ├── E_commerce
│   │   └── README.md
│   └── ubuntu_data
│       └── README.md
├── log
│   ├── alime.msn.log
│   ├── douban.msn.log
│   └── ubuntu.msn.log
└── run.py
/DialogueDataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import TensorDataset
3 |
4 |
5 |
class DialogueDataset(TensorDataset):
    """Dataset of (utterances, response[, label]) examples stored as tensors.

    Utterances and responses are converted to ``LongTensor`` and the optional
    labels to ``FloatTensor``. Indexing returns a tuple holding the selected
    entry of every stored tensor, in that order.
    """

    def __init__(self, X_utterances, X_responses, y_labels=None):
        super(DialogueDataset, self).__init__()
        utterances = torch.LongTensor(X_utterances)
        responses = torch.LongTensor(X_responses)
        print("X_utterances: ", utterances.size())
        print("X_responses: ", responses.size())

        columns = [utterances, responses]
        if y_labels is not None:
            labels = torch.FloatTensor(y_labels)
            print("y_labels: ", labels.size())
            columns.append(labels)
        # Labels are only present when they were supplied (training / scoring).
        self.tensors = columns

    def __getitem__(self, index):
        row = []
        for column in self.tensors:
            row.append(column[index])
        return tuple(row)

    def __len__(self):
        # Every stored tensor shares the first (example) dimension.
        first_column = self.tensors[0]
        return len(first_column)
28 |
29 |
--------------------------------------------------------------------------------
/MSN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 | from NeuralNetwork import NeuralNetwork
6 |
7 |
class TransformerBlock(nn.Module):
    """Single-head scaled dot-product attention followed by a two-layer
    feed-forward network, each wrapped in a residual connection.

    When ``is_layer_norm`` is True the same LayerNorm module is applied after
    both residual additions.
    """

    def __init__(self, input_size, is_layer_norm=False):
        super(TransformerBlock, self).__init__()
        self.is_layer_norm = is_layer_norm
        if is_layer_norm:
            # NOTE: the attribute name (sic) is kept unchanged so that
            # state_dict keys of existing checkpoints keep loading.
            self.layer_morm = nn.LayerNorm(normalized_shape=input_size)

        self.relu = nn.ReLU()
        self.linear1 = nn.Linear(input_size, input_size)
        self.linear2 = nn.Linear(input_size, input_size)
        self.init_weights()

    def init_weights(self):
        init.xavier_normal_(self.linear1.weight)
        init.xavier_normal_(self.linear2.weight)

    def FFN(self, X):
        """Position-wise feed-forward network: linear -> ReLU -> linear."""
        return self.linear2(self.relu(self.linear1(X)))

    def forward(self, Q, K, V, episilon=1e-8):
        '''
        :param Q: (batch_size, max_r_words, embedding_dim)
        :param K: (batch_size, max_u_words, embedding_dim)
        :param V: (batch_size, max_u_words, embedding_dim)
        :param episilon: small constant added to the scaling factor
        :return: output: (batch_size, max_r_words, embedding_dim) same size as Q
        '''
        # Build the scaling constant on Q's device/dtype instead of forcing
        # CUDA, so the block also runs on CPU-only machines.
        dk = Q.new_tensor([max(1.0, Q.size(-1))])

        Q_K = Q.bmm(K.permute(0, 2, 1)) / (torch.sqrt(dk) + episilon)
        Q_K_score = F.softmax(Q_K, dim=-1)  # (batch_size, max_r_words, max_u_words)
        V_att = Q_K_score.bmm(V)

        if self.is_layer_norm:
            X = self.layer_morm(Q + V_att)  # (batch_size, max_r_words, embedding_dim)
            output = self.layer_morm(self.FFN(X) + X)
        else:
            X = Q + V_att
            output = self.FFN(X) + X

        return output
49 |
50 |
class Attention(nn.Module):
    """Additive attention pooling over the sequence dimension (dim=1)."""

    def __init__(self, input_size, hidden_size):
        super(Attention, self).__init__()
        self.linear1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.linear2 = nn.Linear(in_features=hidden_size, out_features=1)
        self.init_weights()

    def init_weights(self):
        init.xavier_normal_(self.linear1.weight)
        init.xavier_normal_(self.linear2.weight)

    def forward(self, X, mask=None):
        '''
        :param X: (batch_size, seq_len, input_size)
        :param mask: optional boolean mask, True for positions to keep, shaped
                     (batch_size, seq_len, 1) or (batch_size, seq_len);
                     see http://juditacs.github.io/2018/12/27/masked-attention.html
        :return: (batch_size, input_size) attention-weighted sum of X over dim=1
        '''
        M = torch.tanh(self.linear1(X))  # (batch_size, seq_len, hidden_size)
        M = self.linear2(M)              # (batch_size, seq_len, 1)
        if mask is not None:
            if mask.dim() == M.dim() - 1:
                mask = mask.unsqueeze(dim=-1)
            # Out-of-place fill: masked positions get zero weight after softmax
            # without mutating a tensor autograd may still need.
            M = M.masked_fill(~mask, float('-inf'))
        score = F.softmax(M, dim=1)  # (batch_size, seq_len, 1)

        output = (score * X).sum(dim=1)  # (batch_size, input_size)
        return output
75 |
76 |
77 |
class MSN(NeuralNetwork):
    '''
    A PyTorch implementation of the Multi-hop Selector Network (MSN) proposed in
    "Multi-hop Selector Network for Multi-turn Response Selection in Retrieval-based Chatbots".

    The model re-weights the context utterances with a multi-hop selector,
    builds six utterance-response matching maps per utterance, encodes them
    with a three-layer CNN, accumulates the per-utterance vectors with a GRU
    and outputs a matching probability.

    The embedding size (200) and the number of words per utterance/response
    (50, see ``linear_word`` and ``affine2``) are hard-coded in the layers.
    '''
    def __init__(self, word_embeddings, args):
        # NeuralNetwork.__init__ reads self.args, so it has to be assigned
        # before the parent constructor runs.
        self.args = args
        super(MSN, self).__init__()

        self.word_embedding = nn.Embedding(num_embeddings=len(word_embeddings), embedding_dim=200, padding_idx=0,
                                           _weight=torch.FloatTensor(word_embeddings))

        # alpha mixes the word-level and utterance-level selector scores;
        # gamma is the threshold below which an utterance is zeroed out.
        self.alpha = 0.5
        self.gamma = 0.3
        self.selector_transformer = TransformerBlock(input_size=200)
        self.W_word = nn.Parameter(data=torch.Tensor(200, 200, 10))
        self.v = nn.Parameter(data=torch.Tensor(10, 1))
        self.linear_word = nn.Linear(2*50, 1)
        self.linear_score = nn.Linear(in_features=3, out_features=1)

        self.transformer_utt = TransformerBlock(input_size=200)
        self.transformer_res = TransformerBlock(input_size=200)
        self.transformer_ur = TransformerBlock(input_size=200)
        self.transformer_ru = TransformerBlock(input_size=200)

        self.A1 = nn.Parameter(data=torch.Tensor(200, 200))
        self.A2 = nn.Parameter(data=torch.Tensor(200, 200))
        self.A3 = nn.Parameter(data=torch.Tensor(200, 200))

        self.cnn_2d_1 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=(3,3))
        self.maxpooling1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        self.cnn_2d_2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3,3))
        self.maxpooling2 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        self.cnn_2d_3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3,3))
        self.maxpooling3 = nn.MaxPool2d(kernel_size=(3, 3), stride=(3, 3))

        # A 50x50 matching map shrinks to 3x3 after the three conv/pool stages.
        self.affine2 = nn.Linear(in_features=3*3*64, out_features=300)

        self.gru_acc = nn.GRU(input_size=300, hidden_size=args.gru_hidden, batch_first=True)
        self.affine_out = nn.Linear(in_features=args.gru_hidden, out_features=1)

        self.tanh = nn.Tanh()
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        self.init_weights()
        print(self)


    def init_weights(self):
        init.uniform_(self.W_word)
        init.uniform_(self.v)
        init.uniform_(self.linear_word.weight)
        init.uniform_(self.linear_score.weight)

        init.xavier_normal_(self.A1)
        init.xavier_normal_(self.A2)
        init.xavier_normal_(self.A3)
        init.xavier_normal_(self.cnn_2d_1.weight)
        init.xavier_normal_(self.cnn_2d_2.weight)
        init.xavier_normal_(self.cnn_2d_3.weight)
        init.xavier_normal_(self.affine2.weight)
        init.xavier_normal_(self.affine_out.weight)
        for weights in [self.gru_acc.weight_hh_l0, self.gru_acc.weight_ih_l0]:
            init.orthogonal_(weights)


    def word_selector(self, key, context):
        '''
        :param key: (bsz, max_u_words, d)
        :param context: (bsz, max_utterances, max_u_words, d)
        :return: s1: (bsz, max_utterances) softmax-normalised word-level scores
        '''
        # Created on the input's device so the model also runs without CUDA.
        dk = torch.sqrt(key.new_tensor([200.0]))
        # NOTE(review): the repeated index in "ddh" makes einsum use only the
        # diagonal of W_word's first two axes - confirm this is intended.
        A = torch.tanh(torch.einsum("blrd,ddh,bud->blruh", context, self.W_word, key)/dk)
        # Squeeze only the trailing singleton axis; a bare squeeze() would also
        # drop the batch axis when the batch holds a single example.
        A = torch.einsum("blruh,hp->blrup", A, self.v).squeeze(dim=-1)   # b x l x u x u

        a = torch.cat([A.max(dim=2)[0], A.max(dim=3)[0]], dim=-1)        # b x l x 2u
        s1 = torch.softmax(self.linear_word(a).squeeze(dim=-1), dim=-1)  # b x l
        return s1

    def utterance_selector(self, key, context):
        '''
        :param key: (bsz, max_u_words, d)
        :param context: (bsz, max_utterances, max_u_words, d)
        :return: s2: (bsz, max_utterances) cosine similarity between the
                 mean-pooled key and each mean-pooled utterance
        '''
        key = key.mean(dim=1)
        context = context.mean(dim=2)
        s2 = torch.einsum("bud,bd->bu", context, key)/(1e-6 + torch.norm(context, dim=-1)*torch.norm(key, dim=-1, keepdim=True) )
        return s2

    def distance(self, A, B, C, epsilon=1e-6):
        '''
        :param A: (bsz, u, d)
        :param B: (d, d) weight matrix
        :param C: (bsz, r, d)
        :return: M1: weighted dot-product map (bsz, u, r);
                 M2: cosine-similarity map (bsz, u, r)
        '''
        # NOTE(review): the repeated index in "dd" makes einsum use only the
        # diagonal of B - confirm this is intended.
        M1 = torch.einsum("bud,dd,brd->bur", [A, B, C])

        A_norm = A.norm(dim=-1)
        C_norm = C.norm(dim=-1)
        M2 = torch.einsum("bud,brd->bur", [A, C]) / (torch.einsum("bu,br->bur", A_norm, C_norm) + epsilon)
        return M1, M2

    def context_selector(self, context, hop=(1, 2, 3)):
        '''
        :param context: (batch_size, max_utterances, max_u_words, embedding_dim)
        :param hop: for every value k, the mean of the last k utterances is used as selection key
        :return: context re-weighted per utterance, same shape as the input
        '''
        su1, su2, su3, su4 = context.size()
        context_ = context.view(-1, su3, su4)   # (batch_size*max_utterances, max_u_words, embedding_dim)
        context_ = self.selector_transformer(context_, context_, context_)
        context_ = context_.view(su1, su2, su3, su4)

        multi_match_score = []
        for hop_i in hop:
            # Take the last hop_i utterances of the actual context length
            # instead of assuming exactly 10 utterances.
            key = context[:, max(su2 - hop_i, 0):, :, :].mean(dim=1)
            key = self.selector_transformer(key, key, key)

            s1 = self.word_selector(key, context_)
            s2 = self.utterance_selector(key, context_)
            s = self.alpha * s1 + (1 - self.alpha) * s2
            multi_match_score.append(s)

        multi_match_score = torch.stack(multi_match_score, dim=-1)
        match_score = self.linear_score(multi_match_score).squeeze(dim=-1)
        # Utterances whose gated score falls below gamma are zeroed out.
        mask = (match_score.sigmoid() >= self.gamma).float()
        match_score = match_score * mask
        context = context * match_score.unsqueeze(dim=-1).unsqueeze(dim=-1)
        return context

    def get_Matching_Map(self, bU_embedding, bR_embedding):
        '''
        :param bU_embedding: (batch_size*max_utterances, max_u_words, embedding_dim)
        :param bR_embedding: (batch_size*max_utterances, max_r_words, embedding_dim)
        :return: M: (bsz*max_utterances, 6, max_u_words, max_r_words)
        '''
        # Raw embeddings.
        M1, M2 = self.distance(bU_embedding, self.A1, bR_embedding)

        # Self-attended representations.
        Hu = self.transformer_utt(bU_embedding, bU_embedding, bU_embedding)
        Hr = self.transformer_res(bR_embedding, bR_embedding, bR_embedding)
        M3, M4 = self.distance(Hu, self.A2, Hr)

        # Cross-attended representations.
        Hur = self.transformer_ur(bU_embedding, bR_embedding, bR_embedding)
        Hru = self.transformer_ru(bR_embedding, bU_embedding, bU_embedding)
        M5, M6 = self.distance(Hur, self.A3, Hru)

        M = torch.stack([M1, M2, M3, M4, M5, M6], dim=1)  # (bsz*max_utterances, channel, max_u_words, max_r_words)
        return M


    def UR_Matching(self, bU_embedding, bR_embedding):
        '''
        :param bU_embedding: (batch_size*max_utterances, max_u_words, embedding_dim)
        :param bR_embedding: (batch_size*max_utterances, max_r_words, embedding_dim)
        :return: V: (bsz*max_utterances, 300) matching vector per utterance
        '''
        M = self.get_Matching_Map(bU_embedding, bR_embedding)

        Z = self.relu(self.cnn_2d_1(M))
        Z = self.maxpooling1(Z)

        Z = self.relu(self.cnn_2d_2(Z))
        Z = self.maxpooling2(Z)

        Z = self.relu(self.cnn_2d_3(Z))
        Z = self.maxpooling3(Z)

        Z = Z.view(Z.size(0), -1)  # (bsz*max_utterances, *)

        V = self.tanh(self.affine2(Z))   # (bsz*max_utterances, 300)
        return V

    def forward(self, bU, bR):
        '''
        :param bU: batch utterance, size: (batch_size, max_utterances, max_u_words)
        :param bR: batch responses, size: (batch_size, max_r_words)
        :return: scores, size: (batch_size, )
        '''
        bU_embedding = self.word_embedding(bU)
        bR_embedding = self.word_embedding(bR)
        multi_context = self.context_selector(bU_embedding, hop=[1, 2, 3])

        su1, su2, su3, su4 = multi_context.size()
        multi_context = multi_context.view(-1, su3, su4)   # (batch_size*max_utterances, max_u_words, embedding_dim)

        sr1, sr2, sr3= bR_embedding.size()   # (batch_size, max_r_words, embedding_dim)
        bR_embedding = bR_embedding.unsqueeze(dim=1).repeat(1, su2, 1, 1)  # (batch_size, max_utterances, max_r_words, embedding_dim)
        bR_embedding = bR_embedding.view(-1, sr2, sr3)   # (batch_size*max_utterances, max_r_words, embedding_dim)

        V = self.UR_Matching(multi_context, bR_embedding)
        V = V.view(su1, su2, -1)  # (bsz, max_utterances, 300)

        H, _ = self.gru_acc(V)  # (bsz, max_utterances, rnn2_hidden)
        L = self.dropout(H[:,-1,:])

        output = torch.sigmoid(self.affine_out(L))
        # Keep the batch axis even for a batch of one so the result still
        # lines up with the label tensor and converts to a list of scores.
        return output.squeeze(dim=-1)
283 |
284 |
285 |
--------------------------------------------------------------------------------
/Metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | np.random.seed(0)
3 |
4 |
class Metrics(object):
    """Ranking metrics (MAP, MRR, P@1, R10@1/2/5) computed from a score file.

    The score file holds one ``score<TAB>label`` pair per line; every
    ``self.segment`` consecutive lines form one session (the candidate
    responses of one context).
    """

    def __init__(self, score_file_path:str):
        super(Metrics, self).__init__()
        self.score_file_path = score_file_path
        self.segment = 10

    def __read_socre_file(self, score_file_path):
        """Read the score file and return the sessions that contain at least
        one positive label; blank lines are ignored."""
        sessions = []
        one_sess = []
        with open(score_file_path, 'r') as infile:
            for line in infile:
                line = line.strip()
                if not line:
                    continue
                tokens = line.split('\t')
                one_sess.append((float(tokens[0]), int(float(tokens[1]))))
                if len(one_sess) == self.segment:
                    # Sessions without a positive candidate cannot be ranked.
                    if sum(label for _, label in one_sess) > 0:
                        sessions.append(one_sess)
                    one_sess = []
        return sessions


    def __mean_average_precision(self, sort_data):
        """Average precision of one session sorted by descending score."""
        count_1 = 0
        sum_precision = 0
        for index, (_, label) in enumerate(sort_data):
            if label == 1:
                count_1 += 1
                sum_precision += 1.0 * count_1 / (index+1)
        return sum_precision / count_1


    def __mean_reciprocal_rank(self, sort_data):
        """Reciprocal rank of the first positive candidate."""
        sort_lable = [s_d[1] for s_d in sort_data]
        assert 1 in sort_lable
        return 1.0 / (1 + sort_lable.index(1))

    def __precision_at_position_1(self, sort_data):
        """1 if the top-ranked candidate is positive, else 0."""
        if sort_data[0][1] == 1:
            return 1
        else:
            return 0

    def __recall_at_position_k_in_10(self, sort_data, k):
        """Fraction of the positive candidates ranked within the top k."""
        sort_label = [s_d[1] for s_d in sort_data]
        select_label = sort_label[:k]
        return 1.0 * select_label.count(1) / sort_label.count(1)


    def evaluation_one_session(self, data):
        '''
        :param data: one conversation session, laid out as [(score1, label1), (score2, label2), ..., (score10, label10)].
                     The list is shuffled in place.
        :return: (MAP, MRR, P@1, R@1, R@2, R@5) of the session.
        '''
        # Shuffle before the stable sort so tied scores are ordered randomly
        # instead of by their position in the file.
        np.random.shuffle(data)
        sort_data = sorted(data, key=lambda x: x[0], reverse=True)
        m_a_p = self.__mean_average_precision(sort_data)
        m_r_r = self.__mean_reciprocal_rank(sort_data)
        p_1 = self.__precision_at_position_1(sort_data)
        r_1 = self.__recall_at_position_k_in_10(sort_data, 1)
        r_2 = self.__recall_at_position_k_in_10(sort_data, 2)
        r_5 = self.__recall_at_position_k_in_10(sort_data, 5)
        return m_a_p, m_r_r, p_1, r_1, r_2, r_5


    def evaluate_all_metrics(self):
        """Average the per-session metrics over ``self.score_file_path``.

        :return: (MAP, MRR, P@1, R@1, R@2, R@5); all zeros when the file
                 contains no complete session with a positive label.
        """
        sessions = self.__read_socre_file(self.score_file_path)
        total_s = len(sessions)
        if total_s == 0:
            # Nothing to rank: report zeros instead of dividing by zero.
            return (0.0, 0.0, 0.0, 0.0, 0.0, 0.0)

        sums = [0, 0, 0, 0, 0, 0]
        for session in sessions:
            for position, value in enumerate(self.evaluation_one_session(session)):
                sums[position] += value

        return tuple(total / total_s for total in sums)
98 |
99 |
if __name__ == '__main__':
    # Manual check: print the metrics of an existing E_commerce score file.
    # evaluate_all_metrics() takes no arguments, so none are passed
    # (the former is_test=True keyword raised a TypeError).
    metric = Metrics('dataset/E_commerce/score_file.txt')
    result = metric.evaluate_all_metrics()
    for r in result:
        print(r)
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
--------------------------------------------------------------------------------
/NeuralNetwork.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.utils as utils
4 | import torch.optim as optim
5 | from torch.utils.data import DataLoader
6 | from DialogueDataset import DialogueDataset
7 | from Metrics import Metrics
8 |
# Seed the CPU generator and every CUDA generator for repeatable runs.
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
# Enable cuDNN's benchmark mode (auto-selects convolution algorithms).
torch.backends.cudnn.benchmark = True
12 |
class NeuralNetwork(nn.Module):
    """Base class bundling the training, evaluation and prediction loops.

    Subclasses implement ``forward(batch_u, batch_r)`` and must assign
    ``self.args`` *before* calling ``super().__init__()``, because the
    constructor reads ``self.args.score_file_path``. The methods also use
    ``args.batch_size``, ``args.learning_rate``, ``args.l2_reg``,
    ``args.epochs`` and ``args.save_path``.
    """

    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.patience = 0               # evaluations since the last improvement
        self.init_clip_max_norm = 5.0   # max gradient norm; None disables clipping
        self.optimizer = None
        self.best_result = [0, 0, 0, 0, 0, 0]
        self.metrics = Metrics(self.args.score_file_path)
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    def forward(self):
        raise NotImplementedError


    def train_step(self, i, data):
        """Run one optimisation step on a batch and return its loss tensor.

        :param i: batch index, only used for logging
        :param data: (utterances, responses, labels) tensors of one batch
        """
        # .to(device) works on both CPU and GPU hosts, unlike .cuda().
        batch_u, batch_r, batch_y = (item.to(self.device) for item in data)

        self.optimizer.zero_grad()
        logits = self.forward(batch_u, batch_r)
        loss = self.loss_func(logits, target=batch_y)
        loss.backward()
        # Clip between backward() and step(); clipping after the step (as was
        # done before) has no effect on the update.
        if self.init_clip_max_norm is not None:
            utils.clip_grad_norm_(self.parameters(), max_norm=self.init_clip_max_norm)
        self.optimizer.step()
        print('Batch[{}] - loss: {:.6f}  batch_size:{}'.format(i, loss.item(), batch_y.size(0)) )
        return loss


    def fit(self, X_train_utterances, X_train_responses, y_train,
            X_dev_utterances, X_dev_responses, y_dev):
        """Train with Adam and BCE loss, evaluating on the dev data every 500
        batches and at the end of each epoch.

        From the third epoch on, once three consecutive evaluations have
        brought no improvement, the best checkpoint is reloaded and the
        learning rate is halved.
        """
        self.to(self.device)

        dataset = DialogueDataset(X_train_utterances, X_train_responses, y_train)
        dataloader = DataLoader(dataset, batch_size=self.args.batch_size, shuffle=True)

        self.loss_func = nn.BCELoss()
        self.optimizer = optim.Adam(self.parameters(), lr=self.args.learning_rate, weight_decay=self.args.l2_reg)

        # args.epochs may arrive as a float from the command line.
        for epoch in range(int(self.args.epochs)):
            print("\nEpoch ", epoch+1, "/", self.args.epochs)
            avg_loss = 0

            self.train()
            for i, data in enumerate(dataloader):
                loss = self.train_step(i, data)

                if i > 0 and i % 500 == 0:
                    self.evaluate(X_dev_utterances, X_dev_responses, y_dev)
                    self.train()

                if epoch >= 2 and self.patience >= 3:
                    print("Reload the best model...")
                    self.load_state_dict(torch.load(self.args.save_path, map_location=self.device))
                    self.adjust_learning_rate()
                    self.patience = 0

                avg_loss += loss.item()
            # Actual number of batches in the epoch.
            cnt = max(1, len(dataloader))
            print("Average loss:{:.6f} ".format(avg_loss/cnt))
            self.evaluate(X_dev_utterances, X_dev_responses, y_dev)


    def adjust_learning_rate(self, decay_rate=.5):
        """Multiply the learning rate of every parameter group by ``decay_rate``."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * decay_rate
            self.args.learning_rate = param_group['lr']
        print("Decay learning rate to: ", self.args.learning_rate)


    def evaluate(self, X_dev_utterances, X_dev_responses, y_dev, is_test=False):
        """Score the given data, write ``score<TAB>label`` lines to
        ``args.score_file_path`` and print the ranking metrics.

        Unless ``is_test`` is set, a better R@1+R@2+R@5 than the best so far
        resets the patience counter and saves the model to ``args.save_path``;
        otherwise the patience counter is increased.
        """
        y_pred = self.predict(X_dev_utterances, X_dev_responses)
        with open(self.args.score_file_path, 'w') as output:
            for score, label in zip(y_pred, y_dev):
                output.write(
                    str(score) + '\t' +
                    str(label) + '\n'
                )

        result = self.metrics.evaluate_all_metrics()
        print("Evaluation Result: \n",
              "MAP:", result[0], "\t",
              "MRR:", result[1], "\t",
              "P@1:", result[2], "\t",
              "R1:",  result[3], "\t",
              "R2:",  result[4], "\t",
              "R5:",  result[5])

        if not is_test and result[3] + result[4] + result[5] > self.best_result[3] + self.best_result[4] + self.best_result[5]:
            self.patience = 0
            # Update first so the message reports the new best, not the old one.
            self.best_result = result
            print("Best Result: \n",
                  "MAP:", self.best_result[0], "\t",
                  "MRR:", self.best_result[1], "\t",
                  "P@1:", self.best_result[2], "\t",
                  "R1:",  self.best_result[3], "\t",
                  "R2:",  self.best_result[4], "\t",
                  "R5:",  self.best_result[5])
            torch.save(self.state_dict(), self.args.save_path)
            print("save model!!!\n")
        else:
            self.patience += 1


    def predict(self, X_dev_utterances, X_dev_responses):
        """Return the list of matching scores for every (context, response) pair."""
        self.eval()
        y_pred = []
        dataset = DialogueDataset(X_dev_utterances, X_dev_responses)
        dataloader = DataLoader(dataset, batch_size=100)

        # The forward pass runs inside no_grad so no autograd graph is kept.
        with torch.no_grad():
            for data in dataloader:
                batch_u, batch_r = (item.to(self.device) for item in data)
                logits = self.forward(batch_u, batch_r)
                # reshape(-1) keeps a one-example batch a list of one score.
                y_pred += logits.detach().cpu().reshape(-1).numpy().tolist()
        return y_pred


    def load_model(self, path):
        """Load a saved state_dict and move the model to the active device."""
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        self.load_state_dict(state_dict=torch.load(path, map_location=self.device))
        self.to(self.device)
138 |
139 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Source code released with the paper
2 | This repository contains the source code and datasets for the model Multi-hop Selector Network.
3 |
4 |
5 | ## Dependencies
6 | Python 3.x
7 | Pytorch 1.1.0
8 |
9 | ## Datasets and Trained Models
10 | You can download the processed datasets and the checkpoints of the trained models needed to reproduce the experimental results in the paper from the following URL:
11 | https://drive.google.com/drive/folders/1pJKIppcbjuTZxbTc8ye5mfnC2ygR2xTo?usp=sharing
12 |
13 | Unzip the dataset.rar file to the folder of ```dataset``` and unzip the checkpoint.rar file to the folder of ```checkpoint```.
14 | The ```log/``` directory already contains the training and evaluation logs of each dataset.
15 |
16 | ## Reproduce the experimental results with the pre-trained model
17 | ```
18 | cd ./Dialogue/
19 | python ./run.py --task "ubuntu"
20 | python ./run.py --task "douban"
21 | python ./run.py --task "alime"
22 | ```
23 |
24 | ## Train a new model
25 | ```
26 | cd ./Dialogue/
27 | python ./run.py --task "ubuntu" --is_training True
28 | python ./run.py --task "douban" --is_training True
29 | python ./run.py --task "alime" --is_training True
30 | ```
31 | The training process is recorded in ```log/[ubuntu|douban|alime].msn.log``` file.
32 |
33 |
34 | ## Citation
35 | If you find this code useful in your research, please cite our paper:
36 | ```
37 | @inproceedings{yuan2019multi,
38 | title={Multi-hop Selector Network for Multi-turn Response Selection in Retrieval-based Chatbots},
39 | author={Yuan, Chunyuan and Zhou, Wei and Li, Mingming and Lv, Shangwen and Zhu, Fuqing and Han, Jizhong and Hu, Songlin},
40 | booktitle={Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)},
41 | pages={111--120},
42 | year={2019}
43 | }
44 | ```
45 |
--------------------------------------------------------------------------------
/checkpoint/README.md:
--------------------------------------------------------------------------------
1 | Put the trained model in this directory.
--------------------------------------------------------------------------------
/dataset/DoubanConversaionCorpus/README.md:
--------------------------------------------------------------------------------
1 | Put the processed douban dataset in this directory.
--------------------------------------------------------------------------------
/dataset/E_commerce/README.md:
--------------------------------------------------------------------------------
1 | Put the processed E_commerce dataset in this directory.
--------------------------------------------------------------------------------
/dataset/ubuntu_data/README.md:
--------------------------------------------------------------------------------
1 | Put the processed ubuntu dataset in this directory.
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import time
2 | import argparse
3 | import pickle
4 | from MSN import MSN
5 | import os
6 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
7 |
# Dataset directory of every supported task.
task_dic = {
    'ubuntu':'./dataset/ubuntu_data/',
    'douban':'./dataset/DoubanConversaionCorpus/',
    'alime':'./dataset/E_commerce/'
}
# Default training batch size of every task (used when --batch_size is not given).
data_batch_size = {
    "ubuntu": 200,
    "douban": 150,
    "alime": 200
}


def _str2bool(value):
    """Parse a command-line boolean such as "True", "false", "1" or "no".

    argparse's ``type=bool`` treats every non-empty string (including
    "False") as True, so the text is interpreted explicitly here.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


## Command-line parameters (all optional)
parser = argparse.ArgumentParser()
parser.add_argument("--task",
                    default='ubuntu',
                    type=str,
                    choices=sorted(task_dic),
                    help="The dataset used for training and test.")
parser.add_argument("--is_training",
                    default=False,
                    type=_str2bool,
                    help="Training model or evaluating model?")
parser.add_argument("--max_utterances",
                    default=10,
                    type=int,
                    help="The maximum number of utterances.")
parser.add_argument("--max_words",
                    default=50,
                    type=int,
                    help="The maximum number of words for each utterance.")
parser.add_argument("--batch_size",
                    default=0,
                    type=int,
                    help="The batch size (0 selects the default of the chosen task).")
parser.add_argument("--gru_hidden",
                    default=300,
                    type=int,
                    help="The hidden size of GRU in layer 1")
parser.add_argument("--learning_rate",
                    default=1e-3,
                    type=float,
                    help="The initial learning rate for Adam.")
parser.add_argument("--l2_reg",
                    default=0.0,
                    type=float,
                    help="The l2 regularization.")
parser.add_argument("--epochs",
                    default=5,
                    type=int,
                    help="Total number of training epochs to perform.")
parser.add_argument("--save_path",
                    default="./checkpoint/",
                    type=str,
                    help="The path to save model.")
parser.add_argument("--score_file_path",
                    default="score_file.txt",
                    type=str,
                    help="The score file name, created inside the dataset directory of the task.")
args = parser.parse_args()
# Only fall back to the per-task default when no batch size was requested.
if args.batch_size <= 0:
    args.batch_size = data_batch_size[args.task]
args.save_path += args.task + '.' + MSN.__name__ + ".pt"
args.score_file_path = task_dic[args.task] + args.score_file_path

print(args)
print("Task: ", args.task)
72 |
73 |
def train_model():
    """Load the training data of the selected task and fit a new MSN model."""
    path = task_dic[args.task]
    # NOTE: pickle executes arbitrary code while loading - only use trusted dataset files.
    with open(path+"train.pkl", 'rb') as train_file:
        X_train_utterances, X_train_responses, y_train = pickle.load(file=train_file)
    # NOTE(review): the test split doubles as the dev set for model selection.
    with open(path+"test.pkl", 'rb') as dev_file:
        X_dev_utterances, X_dev_responses, y_dev = pickle.load(file=dev_file)
    with open(path + "vocab_and_embeddings.pkl", 'rb') as vocab_file:
        vocab, word_embeddings = pickle.load(file=vocab_file)

    model = MSN(word_embeddings, args=args)
    model.fit(
        X_train_utterances, X_train_responses, y_train,
        X_dev_utterances, X_dev_responses, y_dev
    )
85 |
86 |
def test_model():
    """Load the saved checkpoint and evaluate it on the test split of the task."""
    path = task_dic[args.task]
    # NOTE: pickle executes arbitrary code while loading - only use trusted dataset files.
    with open(path+"test.pkl", 'rb') as test_file:
        X_test_utterances, X_test_responses, y_test = pickle.load(file=test_file)
    with open(path + "vocab_and_embeddings.pkl", 'rb') as vocab_file:
        vocab, word_embeddings = pickle.load(file=vocab_file)

    model = MSN(word_embeddings, args=args)
    model.load_model(args.save_path)
    model.evaluate(X_test_utterances, X_test_responses, y_test, is_test=True)
95 |
def test_adversarial():
    """Evaluate the saved checkpoint on the adversarial test sets k=1, 2, 3."""
    path = task_dic[args.task]
    # NOTE: pickle executes arbitrary code while loading - only use trusted dataset files.
    with open(path + "vocab_and_embeddings.pkl", 'rb') as vocab_file:
        vocab, word_embeddings = pickle.load(file=vocab_file)
    model = MSN(word_embeddings, args=args)
    model.load_model(args.save_path)
    for k in (1, 2, 3):
        print("adversarial test set (k={}): ".format(k))
        with open(path+"test_adversarial_k_{}.pkl".format(k), 'rb') as test_file:
            X_test_utterances, X_test_responses, y_test = pickle.load(file=test_file)
        model.evaluate(X_test_utterances, X_test_responses, y_test, is_test=True)
110 |
111 |
if __name__ == '__main__':
    # Optionally train, then always evaluate, and report the wall-clock time.
    started_at = time.time()
    if args.is_training:
        train_model()
    test_model()
    # test_adversarial()
    finished_at = time.time()
    print("use time: ", (finished_at-started_at)/60, " min")
122 |
123 |
124 |
125 |
126 |
--------------------------------------------------------------------------------