├── photo
│   ├── model.png
│   └── loss0.5.png
├── README.md
├── LICENSE
├── model_utils.py
├── data_helper.py
├── model.py
├── data_preprocess.py
└── qacnn.py

/photo/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WenRichard/CNN-in-Answer-selection/HEAD/photo/model.png
--------------------------------------------------------------------------------
/photo/loss0.5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WenRichard/CNN-in-Answer-selection/HEAD/photo/loss0.5.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # WikiQA on QACNN
2 | ## Reproduction of the paper "APPLYING DEEP LEARNING TO ANSWER SELECTION: A STUDY AND AN OPEN TASK"
3 | This project runs the best-performing model from the paper on the WikiQA dataset; results on insuranceQA will be uploaded later.
4 | The model architecture is shown below:
5 | ![model](https://github.com/WenRichard/QACNN/raw/master/photo/model.png)
6 | **Experimental results**:
7 |
8 | |Model|CNN share|Dropout|Parameters|Margin|Epoch|MAP|MRR|
9 | |-|-|-|-|-|-|-|-|
10 | |QACNN|No|0.5|2115200|0.5|100|0.655|0.673|
11 | |QACNN|Yes|0.5|481664|0.5|100|0.684|0.697|
12 | |QACNN|Yes|0.5|481664|0.25|100|0.668|0.674|
13 | |QACNN|Yes|0.5|481664|0.2|100|0.690|0.695|
14 |
15 | **Loss**:
16 | ![Pairwise Loss](https://github.com/WenRichard/QACNN/raw/master/photo/loss0.5.png)
17 |
18 | **The QA experiments will be updated as time allows. If you are interested, feel free to follow this repo; Forks and Stars are also welcome!**
19 | **For questions, please open an Issue or email xiezhengwen2013@163.com**
20 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 WenRichard
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /model_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/3/19 22:11 3 | # @Author : Alan 4 | # @Email : xiezhengwen2013@163.com 5 | # @File : model_utils.py 6 | # @Software: PyCharm 7 | 8 | import tensorflow as tf 9 | import numpy as np 10 | 11 | 12 | def eval_map_mrr(qids, aids, preds, labels): 13 | # 衡量map指标和mrr指标 14 | dic = dict() 15 | pre_dic = dict() 16 | for qid, aid, pred, label in zip(qids, aids, preds, labels): 17 | pre_dic.setdefault(qid, []) 18 | pre_dic[qid].append([aid, pred, label]) 19 | for qid in pre_dic: 20 | dic[qid] = sorted(pre_dic[qid], key=lambda k: k[1], reverse=True) 21 | aid2rank = {aid: [label, rank] for (rank, (aid, pred, label)) in enumerate(dic[qid])} 22 | dic[qid] = aid2rank 23 | # correct = 0 24 | # total = 0 25 | # for qid in dic: 26 | # cur_correct = 0 27 | # for aid in dic[qid]: 28 | # if dic[qid][aid][0] == 1: 29 | # cur_correct += 1 30 | # if cur_correct > 0: 31 | # correct += 1 32 | # total += 1 33 | # print(correct * 1. / total) 34 | 35 | MAP = 0.0 36 | MRR = 0.0 37 | useful_q_len = 0 38 | for q_id in dic: 39 | sort_rank = sorted(dic[q_id].items(), key=lambda k: k[1][1], reverse=False) 40 | correct = 0 41 | total = 0 42 | AP = 0.0 43 | mrr_mark = False 44 | for i in range(len(sort_rank)): 45 | if sort_rank[i][1][0] == 1: 46 | correct += 1 47 | if correct == 0: 48 | continue 49 | useful_q_len += 1 50 | correct = 0 51 | for i in range(len(sort_rank)): 52 | # compute MRR 53 | if sort_rank[i][1][0] == 1 and mrr_mark == False: 54 | MRR += 1.0 / float(i + 1) 55 | mrr_mark = True 56 | # compute MAP 57 | total += 1 58 | if sort_rank[i][1][0] == 1: 59 | correct += 1 60 | AP += float(correct) / float(total) 61 | 62 | AP /= float(correct) 63 | MAP += AP 64 | 65 | MAP /= useful_q_len 66 | MRR /= useful_q_len 67 | return MAP, MRR 68 | 69 | 70 | # print tensor shape 71 | def print_shape(varname, var): 72 | """ 73 | :param varname: tensor name 74 | :param var: tensor variable 75 | """ 76 | try: 77 | print('{0} : {1}'.format(varname, var.get_shape())) 78 | except: 79 | print('{0} : {1}'.format(varname, np.shape(var))) 80 | 81 | 82 | # count the number of trainable parameters in model 83 | def count_parameters(): 84 | totalParams = 0 85 | for variable in tf.trainable_variables(): 86 | shape = variable.get_shape() 87 | variableParams = 1 88 | for dim in shape: 89 | variableParams *= dim.value 90 | totalParams += variableParams 91 | return totalParams 92 | 93 | 94 | # 余弦相似度计算 95 | def feature2cos_sim(feat_q, feat_a): 96 | # feat_q: 2D:(bz, hz) 97 | norm_q = tf.sqrt(tf.reduce_sum(tf.multiply(feat_q, feat_q), 1)) 98 | norm_a = tf.sqrt(tf.reduce_sum(tf.multiply(feat_a, feat_a), 1)) 99 | mul_q_a = tf.reduce_sum(tf.multiply(feat_q, feat_a), 1) 100 | cos_sim_q_a = tf.div(mul_q_a, tf.multiply(norm_q, norm_a)) 101 | return tf.clip_by_value(cos_sim_q_a, 1e-5, 0.99999) 102 | 103 | -------------------------------------------------------------------------------- /data_helper.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/3/19 16:10 3 | # @Author : Alan 4 | # @Email : xiezhengwen2013@163.com 5 | # @File : data_helper.py 6 | # @Software: PyCharm 7 | 8 | import sys 9 | import numpy as np 10 | import random 11 | from collections import namedtuple 12 | import pickle 13 | 14 | random.seed(1337) 15 | np.random.seed(1337) 16 | 17 | def 
load_embedding(dstPath): 18 | with open(dstPath, 'rb') as fin: 19 | _embeddings = pickle.load(fin) 20 | print("load embedding finish! embedding shape:{}".format(np.shape(_embeddings))) 21 | return _embeddings 22 | 23 | 24 | class Batch: 25 | # batch类,里面包含了encoder输入,decoder输入,decoder标签,decoder样本长度mask 26 | def __init__(self): 27 | self.quest_id = [] 28 | self.ans_id = [] 29 | self.quest = [] 30 | self.ans = [] 31 | self.quest_mask = [] 32 | self.ans_mask = [] 33 | self.label = [] 34 | 35 | 36 | def transform(fin_path, vocab, unk_id=1): 37 | word2id = {} 38 | transformed_corpus = [] 39 | with open(vocab, 'r', encoding='utf-8') as f1: 40 | for line in f1: 41 | word = line.strip().split('\t')[1].lower() 42 | id = int(line.strip().split('\t')[0]) 43 | word2id[word] = id 44 | with open(fin_path, 'r', encoding='utf-8') as fin: 45 | fin.readline() 46 | for line in fin: 47 | qid, q, aid, a, label = line.strip().split('\t') 48 | q = [word2id.get(w.lower(), unk_id) for w in q.split()] 49 | a = [word2id.get(w.lower(), unk_id) for w in a.split()] 50 | transformed_corpus.append([qid, q, aid, a, int(label)]) 51 | return transformed_corpus 52 | 53 | 54 | def transform_train(fin_path, vocab, unk_id=1): 55 | word2id = {} 56 | transformed_corpus = [] 57 | with open(vocab, 'r', encoding='utf-8') as f1: 58 | for line in f1: 59 | word = line.strip().split('\t')[1].lower() 60 | id = int(line.strip().split('\t')[0]) 61 | word2id[word] = id 62 | with open(fin_path, 'r', encoding='utf-8') as fin: 63 | fin.readline() 64 | for line in fin: 65 | q, a_pos, a_neg = line.strip().split('\t') 66 | q = [word2id.get(w.lower(), unk_id) for w in q.split()] 67 | a_pos = [word2id.get(w.lower(), unk_id) for w in a_pos.split()] 68 | a_neg = [word2id.get(w.lower(), unk_id) for w in a_neg.split()] 69 | transformed_corpus.append([q, a_pos, a_neg]) 70 | return transformed_corpus 71 | 72 | 73 | def padding(sent, sequence_len): 74 | """ 75 | convert sentence to index array 76 | """ 77 | if len(sent) > sequence_len: 78 | sent = sent[:sequence_len] 79 | padding = sequence_len - len(sent) 80 | sent2idx = sent + [0]*padding 81 | return sent2idx, len(sent) 82 | 83 | 84 | def load_train_data(transformed_corpus, ques_len, ans_len): 85 | """ 86 | load train data 87 | """ 88 | pairwise_corpus = [] 89 | for sample in transformed_corpus: 90 | q, a_pos, a_neg = sample 91 | q_pad, q_len = padding(q, ques_len) 92 | a_pos_pad, a_pos_len = padding(a_pos, ans_len) 93 | a_neg_pad, a_neg_len = padding(a_neg, ans_len) 94 | pairwise_corpus.append((q_pad, a_pos_pad, a_neg_pad, q_len, a_pos_len, a_neg_len)) 95 | return pairwise_corpus 96 | 97 | 98 | def load_data(transformed_corpus, ques_len, ans_len, keep_ids=False): 99 | """ 100 | load test data 101 | """ 102 | pairwise_corpus = [] 103 | for sample in transformed_corpus: 104 | qid, q, aid, a, label = sample 105 | q_pad, q_len = padding(q, ques_len) 106 | a_pad, a_len = padding(a, ans_len) 107 | if keep_ids: 108 | pairwise_corpus.append((qid, q_pad, aid, a_pad, q_len, a_len, label)) 109 | else: 110 | pairwise_corpus.append((q_pad, a_pad, q_len, a_len, label)) 111 | return pairwise_corpus 112 | 113 | 114 | class Iterator(object): 115 | """ 116 | 数据迭代器 117 | """ 118 | def __init__(self, x): 119 | self.x = x 120 | self.sample_num = len(self.x) 121 | 122 | def next_batch(self, batch_size, shuffle=True): 123 | # produce X, Y_out, Y_in, X_len, Y_in_len, Y_out_len 124 | if shuffle: 125 | np.random.shuffle(self.x) 126 | l = np.random.randint(0, self.sample_num - batch_size + 1) 127 | r = l + batch_size 128 | x_part 
= self.x[l:r] 129 | return x_part 130 | 131 | def next(self, batch_size, shuffle=False): 132 | if shuffle: 133 | np.random.shuffle(self.x) 134 | l = 0 135 | while l < self.sample_num: 136 | r = min(l + batch_size, self.sample_num) 137 | batch_size = r - l 138 | x_part = self.x[l:r] 139 | l += batch_size 140 | yield x_part 141 | 142 | 143 | if __name__ == '__main__': 144 | train = '../data/WikiQA/processed/pointwise/WikiQA-train.tsv' 145 | vocab = '../data/WikiQA/processed/pointwise/wiki_clean_vocab.txt' 146 | transform(train, vocab) 147 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/3/19 16:21 3 | # @Author : Alan 4 | # @Email : xiezhengwen2013@163.com 5 | # @File : model.py 6 | # @Software: PyCharm 7 | 8 | import tensorflow as tf 9 | from model_utils import feature2cos_sim 10 | 11 | 12 | class SiameseQACNN(object): 13 | def __init__(self, config): 14 | self.ques_len = config.ques_length 15 | self.ans_len = config.ans_length 16 | self.hidden_size = config.hidden_size 17 | self.output_size = config.output_size 18 | self.pos_weight = config.pos_weight 19 | self.learning_rate = config.learning_rate 20 | self.optimizer = config.optimizer 21 | self.l2_lambda = config.l2_lambda 22 | self.clip_value = config.clip_value 23 | self.embeddings = config.embeddings 24 | self.window_sizes = config.window_sizes 25 | self.n_filters = config.n_filters 26 | self.margin = config.margin 27 | 28 | self._placeholder_init_pointwise() 29 | self.q_a_cosine, self.q_aneg_cosine = self._build(self.embeddings) 30 | # 损失和精确度 31 | self.total_loss, self.accu = self._add_loss_op(self.q_a_cosine, self.q_aneg_cosine, self.l2_lambda) 32 | # 训练节点 33 | self.train_op = self._add_train_op(self.total_loss) 34 | 35 | def _placeholder_init_pointwise(self): 36 | self._ques = tf.placeholder(tf.int32, [None, self.ques_len], name='ques_point') 37 | self._ans = tf.placeholder(tf.int32, [None, self.ans_len], name='ans_point') 38 | self._ans_neg = tf.placeholder(tf.int32, [None, self.ans_len], name='ans_point') 39 | self.dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob") 40 | self.batch_size, self.list_size = tf.shape(self._ans)[0], tf.shape(self._ans)[1] 41 | 42 | def _HL_layer(self, bottom, n_weight, name): 43 | """ 44 | 全连接层 45 | """ 46 | assert len(bottom.get_shape()) == 3 47 | n_prev_weight = bottom.get_shape()[-1] 48 | max_len = bottom.get_shape()[1] 49 | initer = tf.truncated_normal_initializer(stddev=0.01) 50 | W = tf.get_variable(name + 'W', dtype=tf.float32, shape=[n_prev_weight, n_weight], 51 | initializer=tf.uniform_unit_scaling_initializer()) 52 | b = tf.get_variable(name + 'b', dtype=tf.float32, 53 | initializer=tf.constant(0.1, shape=[n_weight], dtype=tf.float32)) 54 | bottom_2 = tf.reshape(bottom, [-1, n_prev_weight]) 55 | hl = tf.nn.bias_add(tf.matmul(bottom_2, W), b) 56 | hl_tanh = tf.nn.tanh(hl) 57 | HL = tf.reshape(hl_tanh, [-1, max_len, n_weight]) 58 | return HL 59 | 60 | def fc_layer(self, bottom, n_weight, name): 61 | """ 62 | 全连接层 63 | """ 64 | assert len(bottom.get_shape()) == 2 65 | n_prev_weight = bottom.get_shape()[1] 66 | initer = tf.truncated_normal_initializer(stddev=0.01) 67 | W = tf.get_variable(name + 'W', dtype=tf.float32, shape=[n_prev_weight, n_weight], initializer=initer) 68 | b = tf.get_variable(name + 'b', dtype=tf.float32, 69 | initializer=tf.constant(0.01, shape=[n_weight], dtype=tf.float32)) 70 | fc = 
tf.nn.bias_add(tf.matmul(bottom, W), b) 71 | return fc 72 | 73 | def _network(self, x): 74 | """ 75 | 核心网络 76 | """ 77 | fc1 = self.fc_layer(x, self.hidden_size, "fc1") 78 | ac1 = tf.nn.relu(fc1) 79 | fc2 = self.fc_layer(ac1, self.hidden_size, "fc2") 80 | return fc2 81 | 82 | def _cnn_layer(self, input): 83 | """ 84 | 卷积层 85 | """ 86 | all = [] 87 | max_len = input.get_shape()[1] 88 | for i, filter_size in enumerate(self.window_sizes): 89 | with tf.variable_scope('filter{}'.format(filter_size)): 90 | # 卷积 91 | cnn_out = tf.layers.conv1d(input, self.n_filters, filter_size, padding='valid', 92 | activation=tf.nn.relu, name='q_conv_' + str(i)) 93 | # 池化 94 | pool_out = tf.reduce_max(cnn_out, axis=1, keepdims=True) 95 | tanh_out = tf.nn.tanh(pool_out) 96 | all.append(tanh_out) 97 | cnn_outs = tf.concat(all, axis=-1) 98 | dim = cnn_outs.get_shape()[-1] 99 | cnn_outs = tf.reshape(cnn_outs, [-1, dim]) 100 | return cnn_outs 101 | 102 | def _build(self, embeddings): 103 | self.Embedding = tf.Variable(tf.to_float(embeddings), trainable=False, name='Embedding') 104 | self.q_embed = tf.nn.dropout(tf.nn.embedding_lookup(self.Embedding, self._ques), keep_prob=self.dropout_keep_prob) 105 | self.a_embed = tf.nn.dropout(tf.nn.embedding_lookup(self.Embedding, self._ans), keep_prob=self.dropout_keep_prob) 106 | self.a_neg_embed = tf.nn.dropout(tf.nn.embedding_lookup(self.Embedding, self._ans_neg), keep_prob=self.dropout_keep_prob) 107 | 108 | with tf.variable_scope('siamese') as scope: 109 | # 计算隐藏和卷积层 110 | hl_q = self._HL_layer(self.q_embed, self.hidden_size, 'HL_layer') 111 | conv1_q = self._cnn_layer(hl_q) 112 | scope.reuse_variables() 113 | hl_a = self._HL_layer(self.a_embed, self.hidden_size, 'HL_layer') 114 | hl_a_neg = self._HL_layer(self.a_neg_embed, self.hidden_size, 'HL_layer') 115 | conv1_a = self._cnn_layer(hl_a) 116 | conv1_a_neg = self._cnn_layer(hl_a_neg) 117 | 118 | # 计算余弦相似度 119 | # q_a_cosine = feature2cos_sim(tf.nn.l2_normalize(conv1_q, dim=1), tf.nn.l2_normalize(conv1_a, dim=1)) 120 | # q_aneg_cosine = feature2cos_sim(tf.nn.l2_normalize(conv1_q, dim=1), tf.nn.l2_normalize(conv1_a_neg, dim=1)) 121 | q_a_cosine = tf.reduce_sum(tf.multiply(tf.nn.l2_normalize(conv1_q, dim=1), tf.nn.l2_normalize(conv1_a, dim=1)), 1) 122 | q_aneg_cosine = tf.reduce_sum(tf.multiply(tf.nn.l2_normalize(conv1_q, dim=1), tf.nn.l2_normalize(conv1_a_neg, dim=1)), 1) 123 | return q_a_cosine, q_aneg_cosine 124 | 125 | def _margin_loss(self, pos_sim, neg_sim): 126 | original_loss = self.margin - pos_sim + neg_sim 127 | l = tf.maximum(tf.zeros_like(original_loss), original_loss) 128 | loss = tf.reduce_sum(l) 129 | return loss, l 130 | 131 | def _add_loss_op(self, p_sim, n_sim, l2_lambda=0.0001): 132 | """ 133 | 损失节点 134 | """ 135 | loss, l = self._margin_loss(p_sim, n_sim) 136 | accu = tf.reduce_mean(tf.cast(tf.equal(0., l), tf.float32)) 137 | reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) 138 | l2_loss = sum(reg_losses) * l2_lambda 139 | pairwise_loss = loss + l2_loss 140 | tf.summary.scalar('pairwise_loss', pairwise_loss) 141 | self.summary_op = tf.summary.merge_all() 142 | return pairwise_loss, accu 143 | 144 | def _add_train_op(self, loss): 145 | """ 146 | 训练节点 147 | """ 148 | with tf.name_scope('train_op'): 149 | # 记录训练步骤 150 | self.global_step = tf.Variable(0, name='global_step', trainable=False) 151 | opt = tf.train.AdamOptimizer(self.learning_rate) 152 | train_op = opt.minimize(loss, self.global_step) 153 | return train_op 154 | 
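Note on the loss defined above: `_margin_loss` implements the standard pairwise hinge loss max(0, margin - cos(q, a+) + cos(q, a-)), summed over the batch, and `accu` is the fraction of (q, a+, a-) triplets whose hinge term is already zero. The short NumPy sketch below is an illustration only (not part of the repository) and restates the same computation outside the TensorFlow graph:

import numpy as np

def pairwise_hinge_loss(pos_sim, neg_sim, margin=0.5):
    """NumPy restatement of SiameseQACNN._margin_loss / accu, for illustration."""
    pos_sim = np.asarray(pos_sim, dtype=np.float32)
    neg_sim = np.asarray(neg_sim, dtype=np.float32)
    # per-triplet hinge term: max(0, margin - cos(q, a+) + cos(q, a-))
    per_pair = np.maximum(0.0, margin - pos_sim + neg_sim)
    loss = float(per_pair.sum())                # summed over the batch, as in the model
    accuracy = float((per_pair == 0.0).mean())  # fraction of triplets already satisfying the margin
    return loss, accuracy

# toy cosine similarities for three triplets
pos = [0.90, 0.40, 0.75]
neg = [0.10, 0.55, 0.30]
print(pairwise_hinge_loss(pos, neg, margin=0.5))  # -> roughly (0.70, 0.333)

With margin = 0.5 (the best configuration in the README), a triplet stops contributing to the loss once the positive answer beats the negative one by at least 0.5 in cosine similarity.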
-------------------------------------------------------------------------------- /data_preprocess.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/3/19 15:09 3 | # @Author : Alan 4 | # @Email : xiezhengwen2013@163.com 5 | # @File : data_preprocess.py 6 | # @Software: PyCharm 7 | 8 | import nltk 9 | import codecs 10 | import logging 11 | import numpy as np 12 | import re 13 | from collections import defaultdict 14 | import pickle 15 | import os 16 | from collections import Counter 17 | import copy 18 | import random 19 | #nltk.download('wordnet') 20 | 21 | raw_data_path = '../data/WikiQA/raw' 22 | lemmatized_data_path = '../data/WikiQA/lemmatized' 23 | processed_data_path = '../data/WikiQA/processed' 24 | glove_path = '../glove//glove.6B.300d.txt' 25 | 26 | processed_data_path_pointwise = '../data/WikiQA/processed/pointwise' 27 | processed_data_path_pairwise = '../data/WikiQA/processed/pairwise' 28 | 29 | if not os.path.exists(lemmatized_data_path): 30 | os.mkdir(lemmatized_data_path) 31 | 32 | if not os.path.exists(processed_data_path): 33 | os.mkdir(processed_data_path) 34 | 35 | if not os.path.exists(processed_data_path_pointwise): 36 | os.mkdir(processed_data_path_pointwise) 37 | 38 | if not os.path.exists(processed_data_path_pairwise): 39 | os.mkdir(processed_data_path_pairwise) 40 | 41 | class QaSample(object): 42 | def __init__(self, q_id, question, a_id, answer, label=None, score=0): 43 | self.q_id = q_id 44 | self.question = question 45 | self.a_id = a_id 46 | self.answer = answer 47 | self.label = int(label) 48 | self.score = float(score) 49 | 50 | 51 | def load_qa_data(fname): 52 | with open(fname, 'r', encoding='utf-8') as fin: 53 | for line in fin: 54 | try: 55 | q_id, question, a_id, answer, label = line.strip().split('\t') 56 | except ValueError: 57 | q_id, question, a_id, answer = line.strip().split('\t') 58 | label = 0 59 | yield QaSample(q_id, question, a_id, answer, label) 60 | 61 | 62 | def lemmatize(): 63 | wn_lemmatizer = nltk.stem.WordNetLemmatizer() 64 | data_sets = ['WikiQA-train.tsv', 'WikiQA-dev.tsv', 'WikiQA-test.tsv'] 65 | for set_name in data_sets: 66 | fin_path = os.path.join(raw_data_path, set_name) 67 | fout_path = os.path.join(lemmatized_data_path, set_name) 68 | with open(fin_path, 'r', encoding='utf-8') as fin, open(fout_path, 'w', encoding='utf-8') as fout: 69 | fin.readline() 70 | for line in fin: 71 | line_info = line.strip().split('\t') 72 | q_id = line_info[0] 73 | question = line_info[1] 74 | a_id = line_info[4] 75 | answer = line_info[5] 76 | question = ' '.join(map(lambda x: wn_lemmatizer.lemmatize(x), nltk.word_tokenize(question))) 77 | answer = ' '.join(map(lambda x: wn_lemmatizer.lemmatize(x), nltk.word_tokenize(answer))) 78 | if set_name != 'test': 79 | label = line_info[6] 80 | fout.write('\t'.join([q_id, question, a_id, answer, label]) + '\n') 81 | else: 82 | fout.write('\t'.join([q_id, question, a_id, answer]) + '\n') 83 | 84 | 85 | def gen_train_triplets(same_q_sample_group): 86 | question = same_q_sample_group[0].question 87 | pos_answers = [sample.answer for sample in same_q_sample_group if sample.label == 1] 88 | neg_answers = [sample.answer for sample in same_q_sample_group if sample.label == 0] 89 | if len(pos_answers) != 0: 90 | for pos_answer in pos_answers: 91 | for neg_answer in neg_answers: 92 | yield question, pos_answer, neg_answer 93 | 94 | 95 | # 获取clean的dev和test数据,写入文件 96 | def gen_clean_test(filename): 97 | f_in = 
os.path.join(lemmatized_data_path, filename) 98 | f_out = os.path.join(processed_data_path_pointwise, filename) 99 | qa_samples = load_qa_data(f_in) 100 | dic = {} 101 | dic2 = {} 102 | for qasa in qa_samples: 103 | if qasa.q_id not in dic: 104 | dic[qasa.q_id] = [qasa.label] 105 | dic2[qasa.q_id] = [qasa] 106 | else: 107 | dic[qasa.q_id].append(qasa.label) 108 | dic2[qasa.q_id].append(qasa) 109 | q = [] 110 | for k, v in dic.items(): 111 | if sum(v) != 0: 112 | q.append(k) 113 | print('所有label有效(不全为0)的数据为:{}'.format(len(q))) 114 | with open(f_out, 'w', encoding='utf-8') as fout: 115 | for t in q: 116 | same_q_samples = dic2[t] 117 | for r in same_q_samples: 118 | fout.write('{}\t{}\t{}\t{}\t{}\n'.format(r.q_id, r.question, r.a_id, r.answer, r.label)) 119 | 120 | 121 | # 获得train、dev、test中所有的词,目前采用lemmatized的,并不是clean后的,但是没有什么影响 122 | def gen_vocab(): 123 | words = [] 124 | data_sets = ['WikiQA-train.tsv', 'WikiQA-dev.tsv', 'WikiQA-test.tsv'] 125 | for set_name in data_sets: 126 | fin_path = os.path.join(lemmatized_data_path, set_name) 127 | with open(fin_path, 'r', encoding='utf-8') as fin: 128 | fin.readline() 129 | for line in fin: 130 | line_in = line.strip().split('\t') 131 | question = line_in[1].split(' ') 132 | answer = line_in[3].split(' ') 133 | for r1 in question: 134 | if r1 not in words: 135 | words.append(r1) 136 | for r2 in answer: 137 | if r2 not in words: 138 | words.append(r2) 139 | fout_path = os.path.join(processed_data_path_pointwise, 'wiki_vocab.txt') 140 | with open(fout_path, 'w', encoding='utf-8') as f: 141 | for i, j in enumerate(words): 142 | f.write('{}\t{}\n'.format(i, j)) 143 | 144 | 145 | # 根据词表生成对应的embedding 146 | def data_transform(embedding_size): 147 | file_in = os.path.join(processed_data_path_pointwise, 'wiki_vocab.txt') 148 | clean_vocab_out = os.path.join(processed_data_path_pointwise, 'wiki_clean_vocab.txt') 149 | embedding_out = os.path.join(processed_data_path_pointwise, 'wiki_embedding.pkl') 150 | words = [] 151 | with open(file_in, 'r', encoding='utf-8') as f1: 152 | for line in f1: 153 | word = line.strip().split('\t')[1].lower() 154 | words.append(word) 155 | print('wiki_vocab.txt总共有{}个词'.format(len(words))) 156 | 157 | embedding_dic = {} 158 | rng = np.random.RandomState(None) 159 | pad_embedding = rng.uniform(-0.25, 0.25, size=(1, embedding_size)) 160 | unk_embedding = rng.uniform(-0.25, 0.25, size=(1, embedding_size)) 161 | embeddings = [] 162 | clean_words = ['', ''] 163 | embeddings.append(pad_embedding.reshape(-1).tolist()) 164 | embeddings.append(unk_embedding.reshape(-1).tolist()) 165 | print('uniform_init...') 166 | with open(glove_path, 'r', encoding='utf-8') as fin: 167 | for line in fin: 168 | try: 169 | line_info = line.strip().split() 170 | word = line_info[0] 171 | embedding = [float(val) for val in line_info[1:]] 172 | embedding_dic[word] = embedding 173 | if word in words: 174 | clean_words.append(word) 175 | embeddings.append(embedding) 176 | except: 177 | print('Error while loading line: {}'.format(line.strip())) 178 | print('目前词表总共有{}个词'.format(len(clean_words))) 179 | print('embeddings总共有{}个词'.format(len(embeddings))) 180 | print('embeddings的shape为: {}'.format(np.shape(embeddings))) 181 | with open(clean_vocab_out, 'w', encoding='utf-8') as f: 182 | for i, j in enumerate(clean_words): 183 | f.write('{}\t{}\n'.format(i, j)) 184 | with open(embedding_out, 'wb') as f2: 185 | pickle.dump(embeddings, f2) 186 | 187 | 188 | if __name__ == '__main__': 189 | # 1.nltk分词 190 | # 2.获取clean的train, dev和test数据,写入文件 191 | # 3.获取词表 192 | # 
4.生成相应的embedding 193 | 194 | type = 'pointwise' 195 | # 1.nltk分词 196 | # lemmatize() 197 | 198 | # 2.获取clean的train, dev和test数据,写入文件 199 | train_file = 'WikiQA-train.tsv' 200 | dev_file = 'WikiQA-dev.tsv' 201 | test_file = 'WikiQA-test.tsv' 202 | gen_clean_test(train_file) 203 | gen_clean_test(dev_file) 204 | gen_clean_test(test_file) 205 | 206 | # 3.获取对应的词表 207 | gen_vocab() 208 | 209 | # 4.生成相应的embedding 210 | data_transform(300) -------------------------------------------------------------------------------- /qacnn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/3/19 20:19 3 | # @Author : Alan 4 | # @Email : xiezhengwen2013@163.com 5 | # @File : train.py 6 | # @Software: PyCharm 7 | 8 | 9 | import time 10 | import logging 11 | import numpy as np 12 | import tensorflow as tf 13 | import os 14 | import tqdm 15 | import sys 16 | from copy import deepcopy 17 | stdout = sys.stdout 18 | 19 | from data_helper import * 20 | from model import SiameseQACNN 21 | from model_utils import * 22 | 23 | # 创建一个logger 24 | logger = logging.getLogger('mylogger') 25 | logger.setLevel(logging.DEBUG) 26 | 27 | # 创建一个handler,用于写入日志文件 28 | timestamp = str(int(time.time())) 29 | fh = logging.FileHandler('./log/log_' + timestamp +'.txt') 30 | fh.setLevel(logging.DEBUG) 31 | 32 | # 定义handler的输出格式 33 | formatter = logging.Formatter('[%(asctime)s][%(levelname)s] ## %(message)s') 34 | fh.setFormatter(formatter) 35 | # ch.setFormatter(formatter) 36 | 37 | # 给logger添加handler 38 | logger.addHandler(fh) 39 | # logger.addHandler(ch) 40 | 41 | 42 | class NNConfig(object): 43 | def __init__(self, embeddings=None): 44 | # 输入问题(句子)长度 45 | self.ques_length = 25 46 | # 输入答案长度 47 | self.ans_length = 90 48 | # 循环数 49 | self.num_epochs = 100 50 | # batch大小 51 | self.batch_size = 128 52 | # 不同类型的filter,对应不同的尺寸 53 | self.window_sizes = [1, 2, 3, 5, 7, 9] 54 | # 隐层大小 55 | self.hidden_size = 128 56 | self.output_size = 128 57 | self.keep_prob = 0.5 58 | # 每种filter的数量 59 | self.n_filters = 128 60 | # margin大小 61 | self.margin = 0.5 62 | # 词向量大小 63 | self.embeddings = np.array(embeddings).astype(np.float32) 64 | # 学习率 65 | self.learning_rate = 0.001 66 | # contrasive loss 中的 positive loss部分的权重 67 | self.pos_weight = 0.25 68 | # 优化器 69 | self.optimizer = 'adam' 70 | self.clip_value = 5 71 | self.l2_lambda = 0.0001 72 | # 评测 73 | self.eval_batch = 100 74 | 75 | # self.cf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False) 76 | # self.cf.gpu_options.per_process_gpu_memory_fraction = 0.2 77 | 78 | 79 | def evaluate(sess, model, corpus, config): 80 | iterator = Iterator(corpus) 81 | 82 | count = 0 83 | total_qids = [] 84 | total_aids = [] 85 | total_pred = [] 86 | total_labels = [] 87 | total_loss = 0. 
88 | Acc = [] 89 | for batch_x in iterator.next(config.batch_size, shuffle=False): 90 | batch_qids, batch_q, batch_aids, batch_a, batch_qmask, batch_amask, labels = zip(*batch_x) 91 | batch_q = np.asarray(batch_q) 92 | batch_a = np.asarray(batch_a) 93 | q_ap_cosine, loss, acc = sess.run([model.q_a_cosine, model.total_loss, model.accu], 94 | feed_dict={model._ques: batch_q, 95 | model._ans: batch_a, 96 | model._ans_neg: batch_a, 97 | model.dropout_keep_prob: 1.0}) 98 | total_loss += loss 99 | Acc.append(acc) 100 | count += 1 101 | total_qids.append(batch_qids) 102 | total_aids.append(batch_aids) 103 | total_pred.append(q_ap_cosine) 104 | total_labels.append(labels) 105 | 106 | # print(batch_qids[0], [id2word[_] for _ in batch_q[0]], 107 | # batch_aids[0], [id2word[_] for _ in batch_ap[0]]) 108 | total_qids = np.concatenate(total_qids, axis=0) 109 | total_aids = np.concatenate(total_aids, axis=0) 110 | total_pred = np.concatenate(total_pred, axis=0) 111 | total_labels = np.concatenate(total_labels, axis=0) 112 | MAP, MRR = eval_map_mrr(total_qids, total_aids, total_pred, total_labels) 113 | acc_ = np.sum(Acc)/count 114 | ave_loss = total_loss/count 115 | # print('Eval loss:{}'.format(total_loss / count)) 116 | return MAP, MRR, ave_loss, acc_ 117 | 118 | 119 | def test(corpus, config): 120 | with tf.Session() as sess: 121 | model = SiameseQACNN(config) 122 | saver = tf.train.Saver() 123 | saver.restore(sess, tf.train.latest_checkpoint(best_path)) 124 | test_MAP, test_MRR, _, acc = evaluate(sess, model, corpus, config) 125 | print('start test...............') 126 | print("-- test MAP %.5f -- test MRR %.5f" % (test_MAP, test_MRR)) 127 | 128 | 129 | def train(train_corpus, val_corpus, test_corpus, config, eval_train_corpus=None): 130 | iterator = Iterator(train_corpus) 131 | if not os.path.exists(save_path): 132 | os.makedirs(save_path) 133 | if not os.path.exists(best_path): 134 | os.makedirs(best_path) 135 | 136 | with tf.Session() as sess: 137 | # training 138 | print('Start training and evaluating ...') 139 | start_time = time.time() 140 | 141 | model = SiameseQACNN(config) 142 | saver = tf.train.Saver(tf.global_variables(), max_to_keep=10) 143 | best_saver = tf.train.Saver(tf.global_variables(), max_to_keep=5) 144 | ckpt = tf.train.get_checkpoint_state(save_path) 145 | print('Configuring TensorBoard and Saver ...') 146 | summary_writer = tf.summary.FileWriter(save_path, graph=sess.graph) 147 | if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path): 148 | print('Reloading model parameters..') 149 | saver.restore(sess, ckpt.model_checkpoint_path) 150 | else: 151 | print('Created new model parameters..') 152 | sess.run(tf.global_variables_initializer()) 153 | 154 | # count trainable parameters 155 | total_parameters = count_parameters() 156 | print('Total trainable parameters : {}'.format(total_parameters)) 157 | 158 | current_step = 0 159 | best_map_val = 0.0 160 | best_mrr_val = 0.0 161 | last_dev_map = 0.0 162 | last_dev_mrr = 0.0 163 | for epoch in range(config.num_epochs): 164 | print("----- Epoch {}/{} -----".format(epoch + 1, config.num_epochs)) 165 | count = 0 166 | for batch_x in iterator.next(config.batch_size, shuffle=True): 167 | 168 | batch_q, batch_a_pos, batch_a_neg, batch_qmask, batch_a_pos_mask, batch_a_neg_mask = zip(*batch_x) 169 | batch_q = np.asarray(batch_q) 170 | batch_a_pos = np.asarray(batch_a_pos) 171 | batch_a_neg = np.asarray(batch_a_neg) 172 | _, loss, summary, train_acc = sess.run([model.train_op, model.total_loss, model.summary_op, model.accu], 173 | 
feed_dict={model._ques: batch_q, 174 | model._ans: batch_a_pos, 175 | model._ans_neg: batch_a_neg, 176 | model.dropout_keep_prob: config.keep_prob}) 177 | count += 1 178 | current_step += 1 179 | if count % 100 == 0: 180 | print('[epoch {}, batch {}]Loss:{}'.format(epoch, count, loss)) 181 | summary_writer.add_summary(summary, current_step) 182 | if eval_train_corpus is not None: 183 | train_MAP, train_MRR, train_Loss, train_acc_ = evaluate(sess, model, eval_train_corpus, config) 184 | print("--- epoch %d -- train Loss %.5f -- train Acc %.5f -- train MAP %.5f -- train MRR %.5f" % ( 185 | epoch+1, train_Loss, train_acc_, train_MAP, train_MRR)) 186 | if val_corpus is not None: 187 | dev_MAP, dev_MRR, dev_Loss, dev_acc = evaluate(sess, model, val_corpus, config) 188 | print("--- epoch %d -- dev Loss %.5f -- dev Acc %.5f --dev MAP %.5f -- dev MRR %.5f" % ( 189 | epoch + 1, dev_Loss, dev_acc, dev_MAP, dev_MRR)) 190 | logger.info("\nEvaluation:") 191 | logger.info("--- epoch %d -- dev Loss %.5f -- dev Acc %.5f --dev MAP %.5f -- dev MRR %.5f" % ( 192 | epoch + 1, dev_Loss, dev_acc, dev_MAP, dev_MRR)) 193 | 194 | test_MAP, test_MRR, test_Loss, test_acc= evaluate(sess, model, test_corpus, config) 195 | print("--- epoch %d -- test Loss %.5f -- test Acc %.5f --test MAP %.5f -- test MRR %.5f" % ( 196 | epoch + 1, test_Loss, test_acc, test_MAP, test_MRR)) 197 | logger.info("\nTest:") 198 | logger.info("--- epoch %d -- test Loss %.5f -- dev Acc %.5f --test MAP %.5f -- test MRR %.5f" % ( 199 | epoch + 1, test_Loss, test_acc, test_MAP, test_MRR)) 200 | 201 | checkpoint_path = os.path.join(save_path, 'map{:.5f}_{}.ckpt'.format(test_MAP, current_step)) 202 | bestcheck_path = os.path.join(best_path, 'map{:.5f}_{}.ckpt'.format(test_MAP, current_step)) 203 | saver.save(sess, checkpoint_path, global_step=epoch) 204 | if test_MAP > best_map_val or test_MRR > best_mrr_val: 205 | best_map_val = test_MAP 206 | best_mrr_val = test_MRR 207 | best_saver.save(sess, bestcheck_path, global_step=epoch) 208 | last_dev_map = test_MAP 209 | last_dev_mrr = test_MRR 210 | logger.info("\nBest and Last:") 211 | logger.info('--- best_MAP %.4f -- best_MRR %.4f -- last_MAP %.4f -- last_MRR %.4f'% ( 212 | best_map_val, best_mrr_val, last_dev_map, last_dev_mrr)) 213 | print('--- best_MAP %.4f -- best_MRR %.4f -- last_MAP %.4f -- last_MRR %.4f' % ( 214 | best_map_val, best_mrr_val, last_dev_map, last_dev_mrr)) 215 | 216 | 217 | def main(args): 218 | max_q_length = 25 219 | max_a_length = 90 220 | processed_data_path_pairwise = '../data/WikiQA/processed/pairwise' 221 | train_file = os.path.join(processed_data_path_pairwise, 'WikiQA-train-triplets.tsv') 222 | dev_file = os.path.join(processed_data_path_pairwise, 'WikiQA-dev.tsv') 223 | test_file = os.path.join(processed_data_path_pairwise, 'WikiQA-test.tsv') 224 | vocab = os.path.join(processed_data_path_pairwise, 'wiki_clean_vocab.txt') 225 | embeddings_file = os.path.join(processed_data_path_pairwise, 'wiki_embedding.pkl') 226 | _embeddings = load_embedding(embeddings_file) 227 | train_transform = transform_train(train_file, vocab) 228 | dev_transform = transform(dev_file, vocab) 229 | test_transform = transform(test_file, vocab) 230 | train_corpus = load_train_data(train_transform, max_q_length, max_a_length) 231 | dev_corpus = load_data(dev_transform, max_q_length, max_a_length, keep_ids=True) 232 | test_corpus = load_data(test_transform, max_q_length, max_a_length, keep_ids=True) 233 | 234 | config = NNConfig(embeddings=_embeddings) 235 | config.ques_length = max_q_length 236 | 
config.ans_length = max_a_length 237 | if args.train: 238 | train(deepcopy(train_corpus), dev_corpus, test_corpus, config) 239 | elif args.test: 240 | test(test_corpus, config) 241 | 242 | 243 | if __name__ == '__main__': 244 | import argparse 245 | parser = argparse.ArgumentParser() 246 | parser.add_argument("--train", help="whether to train", action='store_true') 247 | parser.add_argument("--test", help="whether to test", action='store_true') 248 | args = parser.parse_args() 249 | 250 | save_path = "./model/checkpoint" 251 | best_path = "./model/bestval" 252 | main(args) 253 | --------------------------------------------------------------------------------
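Usage sketch (illustration only, not part of the repository): once the processed WikiQA files and the embedding pickle referenced in `main()` are in place, training and evaluation are launched through the argparse flags defined above, i.e. `python qacnn.py --train` or `python qacnn.py --test`. The toy example below shows how `eval_map_mrr` from `model_utils.py` turns (qid, aid, score, label) tuples into the MAP/MRR numbers reported in the README; the question/answer ids and scores are made up, and running it assumes the repository's dependencies (TensorFlow 1.x, NumPy) are installed, since `model_utils` imports TensorFlow.

from model_utils import eval_map_mrr

# two questions, each with three candidate answers; scores are hypothetical model outputs
qids   = ['Q1', 'Q1', 'Q1', 'Q2', 'Q2', 'Q2']
aids   = ['A1', 'A2', 'A3', 'A4', 'A5', 'A6']
preds  = [0.9,  0.2,  0.6,  0.3,  0.8,  0.1]
labels = [1,    0,    0,    0,    1,    0]

MAP, MRR = eval_map_mrr(qids, aids, preds, labels)
print('MAP = {:.3f}, MRR = {:.3f}'.format(MAP, MRR))
# -> MAP = 1.000, MRR = 1.000: the relevant answer ranks first for both questions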