├── .gitignore ├── picture ├── train_acc.png ├── train_loss.png └── valid_acc.png ├── README.md ├── make_tfrecords.py ├── train.py └── modules.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | data/ 3 | model/ 4 | logs/ 5 | WebQA.v1.0/ -------------------------------------------------------------------------------- /picture/train_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YeliangLi/Chinese-QASystem/HEAD/picture/train_acc.png -------------------------------------------------------------------------------- /picture/train_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YeliangLi/Chinese-QASystem/HEAD/picture/train_loss.png -------------------------------------------------------------------------------- /picture/valid_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YeliangLi/Chinese-QASystem/HEAD/picture/valid_acc.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Chinese-QASystem 2 | Chinese question answering system based on BLSTM and CRF. 3 |
4 |
5 | 6 | Requirements 7 | ------- 8 | tensorflow 1.5
9 | numpy
10 | thulac
11 | scikit-learn
12 | matplotlib
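The packages above can be installed with pip, for example as below (the exact version pins other than TensorFlow 1.5 are assumptions; a GPU build of TensorFlow may be preferable for training):

```
pip3 install tensorflow==1.5.0 numpy thulac scikit-learn matplotlib
```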

13 | 14 | Dataset 15 | ----- 16 | Baidu's Chinese question answering dataset [WebQA](https://www.spaces.ac.cn/archives/4338/). Many thanks to the author of the linked page for organizing the data.

17 | 18 | How to get started? 19 | ----- 20 | 1. Download the raw data and extract it to the folder where the source code is located. 21 | 2. Run `python3 make_tfrecords.py` to process the raw data into the TFRecord files used for training and validation. 22 | In this experiment, 200,000 examples were used for training and the model's accuracy was validated on 5,000 examples. 23 | 3. Run `python3 train.py`. The training results are shown below (the commands are also collected in the snippet after the plots); the model eventually 24 | achieved an accuracy of 0.6050 on the validation set. 25 | 26 | ![](https://github.com/YeliangLi/Chinese-QASystem/raw/master/picture/train_loss.png)
27 | ![](https://github.com/YeliangLi/Chinese-QASystem/raw/master/picture/train_acc.png)
28 | ![](https://github.com/YeliangLi/Chinese-QASystem/raw/master/picture/valid_acc.png)
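For reference, the full sequence of commands from the steps above, run from the repository root (downloading and extracting WebQA.v1.0 is a manual step):

```
# 1. Download WebQA.v1.0 and extract it next to the source files
python3 make_tfrecords.py   # writes data/trainData.tfrecords, data/validData.tfrecords and data/lang.pkl
python3 train.py            # trains the model, saving checkpoints to model/ and summaries to logs/
```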
29 | 30 | Note 31 | ---- 32 | In the future, I will write a blog post to introduce this work, where you will learn how to use TensorFlow's tf.while_loop interface to implement conditional random field training and Viterbi decoding; a minimal sketch of the core idea is given below.
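The sketch below shows, under simplified assumptions (a dense batch with no per-example sequence lengths or TensorArrays, unlike the full implementation in `modules.py`), how the CRF forward recursion can be written with `tf.while_loop`; the tag count and the log-sum-exp trick are taken from this repository, while the function names and shapes are illustrative only:

```python
import tensorflow as tf

NUM_TAGS = 6  # len(tag_to_ix) in modules.py

def log_sum_exp(x):
    # x: (batch, NUM_TAGS, NUM_TAGS); log-sum-exp over the previous-tag axis (axis=1)
    m = tf.reduce_max(x, axis=1, keepdims=True)
    return tf.squeeze(m, 1) + tf.log(tf.reduce_sum(tf.exp(x - m), axis=1))

def crf_forward(emissions, transitions, init_alphas, max_len):
    """Forward-algorithm recursion written with tf.while_loop.

    emissions:   (batch, max_len, NUM_TAGS) emission scores from the evidence LSTMs
    transitions: (NUM_TAGS, NUM_TAGS) transition scores
    init_alphas: (batch, NUM_TAGS) initial log-potentials (0 for START, ~-1e8 elsewhere)
    Returns the final forward variables, shape (batch, NUM_TAGS).
    """
    def body(i, alphas):
        emit = tf.expand_dims(emissions[:, i, :], 1)   # (batch, 1, NUM_TAGS)
        prev = tf.expand_dims(alphas, 2)               # (batch, NUM_TAGS, 1)
        # broadcast to (batch, NUM_TAGS, NUM_TAGS): previous tag -> next tag
        return i + 1, log_sum_exp(prev + transitions + emit)

    _, final_alphas = tf.while_loop(lambda i, _: i < max_len,
                                    body,
                                    loop_vars=[tf.constant(0, tf.int32), init_alphas])
    return final_alphas
```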

33 | 34 | 35 | References 36 | ----- 37 | [Li P, Li W, He Z, et al. Dataset and Neural Recurrent Sequence Labeling Model for Open-Domain Factoid Question Answering[J]. 2016.](https://arxiv.org/abs/1607.06275) 38 | 39 | 40 | -------------------------------------------------------------------------------- /make_tfrecords.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | """ 4 | Author: Yeliang Li 5 | Blog: http://blog.yeliangli.com/ 6 | Created: 2018/2/14 7 | """ 8 | 9 | import tensorflow as tf 10 | from sklearn.externals import joblib 11 | from sklearn.model_selection import train_test_split 12 | import os 13 | import json 14 | import thulac 15 | import re 16 | import modules 17 | 18 | 19 | path = "./WebQA.v1.0" 20 | #{"B":0,"I":1,"O1":2,"O2":3,"START":4,"STOP":5} 21 | tag_to_ix = modules.tag_to_ix 22 | evidences_sampling_rate = 0.8 23 | 24 | def write(lang,file,writer,data_size): 25 | thul = thulac.thulac(seg_only=True) 26 | if not os.path.isdir("data"): 27 | os.mkdir("data") 28 | with open(file,'rt',encoding='utf-8') as file: 29 | dict = file.read() 30 | dict = json.loads(dict,encoding='utf-8') 31 | count = 0 32 | number = 0 33 | for key in dict: 34 | if count == data_size: 35 | break 36 | ques = [] 37 | for word in thul.cut(dict[key]['question'],text=True).split(): 38 | index = lang.addWord(word) 39 | ques.append(index) 40 | evidences = [] 41 | evidences_tags = [] 42 | for record in dict[key]['evidences']: 43 | number += 1 44 | print("number:%d question number:%s evidence number:%s" %(number,key,record)) 45 | answer = dict[key]['evidences'][record]['answer'][0] 46 | evidence = dict[key]['evidences'][record]['evidence'] 47 | if answer != 'no_answer': 48 | evidence = re.sub(answer,"XXX",evidence) 49 | answer = thul.cut(answer,text=True).split() 50 | answer_tags = [] 51 | answer_indices = [] 52 | for i in range(len(answer)): 53 | if i != 0: 54 | answer_tags.append(tag_to_ix['I']) 55 | else: 56 | answer_tags.append(tag_to_ix['B']) 57 | answer_indices.append(lang.addWord(answer[i])) 58 | evidence = thul.cut(evidence,text=True).split() 59 | evidence_indices = [] 60 | evidence_tags = [] 61 | before_answer = True 62 | for word in evidence: 63 | if word != "XXX": 64 | evidence_indices.append(lang.addWord(word)) 65 | else: 66 | evidence_indices += answer_indices 67 | evidence_tags += answer_tags 68 | before_answer = False 69 | continue 70 | if before_answer: 71 | evidence_tags.append(tag_to_ix["O1"]) 72 | else: 73 | evidence_tags.append(tag_to_ix["O2"]) 74 | evidences.append(evidence_indices) 75 | evidences_tags.append(evidence_tags) 76 | else: 77 | evidence_indices = [] 78 | evidence_tags = [] 79 | for word in thul.cut(evidence,text=True).split(): 80 | evidence_indices.append(lang.addWord(word)) 81 | evidence_tags.append(tag_to_ix["O1"]) 82 | evidences.append(evidence_indices) 83 | evidences_tags.append(evidence_tags) 84 | selected_evidences,rest_evidences,selected_evidences_tags,rest_evidences_tags = train_test_split(evidences, 85 | evidences_tags, 86 | test_size=1-evidences_sampling_rate, 87 | random_state=0 88 | ) 89 | count += len(selected_evidences) 90 | if count > data_size: 91 | count -= len(selected_evidences) 92 | selected_evidences = selected_evidences[0:(data_size - count)] 93 | count = data_size 94 | for i in range(len(selected_evidences)): 95 | e_e_comm_fea = [] 96 | q_e_comm_fea = [] 97 | for index in selected_evidences[i]: 98 | if index in ques: 99 | q_e_comm_fea.append(1) 100 | else: 101 | 
q_e_comm_fea.append(0) 102 | comm_tag = False 103 | for evidence in rest_evidences: 104 | if index in evidence: 105 | e_e_comm_fea.append(1) 106 | comm_tag = True 107 | break 108 | if not comm_tag: 109 | e_e_comm_fea.append(0) 110 | feas = {} 111 | feas['question'] = tf.train.Feature(int64_list=tf.train.Int64List(value=ques)) 112 | feas['evidence'] = tf.train.Feature(int64_list=tf.train.Int64List(value=selected_evidences[i])) 113 | feas['evidence_tags'] = tf.train.Feature(int64_list=tf.train.Int64List(value=selected_evidences_tags[i])) 114 | feas['q_e_comm'] = tf.train.Feature(int64_list=tf.train.Int64List(value=q_e_comm_fea)) 115 | feas['e_e_comm'] = tf.train.Feature(int64_list=tf.train.Int64List(value=e_e_comm_fea)) 116 | feas['question_length'] = tf.train.Feature(int64_list=tf.train.Int64List(value=[len(ques)])) 117 | feas['evidence_length'] = tf.train.Feature(int64_list=tf.train.Int64List(value=[len(selected_evidences[i])])) 118 | features_to_write = tf.train.Example(features=tf.train.Features(feature=feas)) 119 | writer.write(features_to_write.SerializeToString()) 120 | writer.close() 121 | return count 122 | 123 | 124 | if __name__ == "__main__": 125 | lang = modules.Lang("Chi") 126 | if not os.path.isdir("data"): 127 | os.mkdir("data") 128 | 129 | writer = tf.python_io.TFRecordWriter("./data/trainData.tfrecords") 130 | train_data_size = write(lang,os.path.join(path,"me_train.json"),writer,200000) 131 | joblib.dump(train_data_size,"./data/trainDataSize.pkl") 132 | 133 | writer = tf.python_io.TFRecordWriter("./data/validData.tfrecords") 134 | valid_data_size = write(lang,os.path.join(path,"me_validation.ir.json"),writer,5000) 135 | joblib.dump(valid_data_size,"./data/validDataSize.pkl") 136 | 137 | joblib.dump(lang,"./data/lang.pkl") 138 | 139 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | """ 4 | Author: Yeliang Li 5 | Blog: http://blog.yeliangli.com/ 6 | Created: 2018/2/14 7 | """ 8 | 9 | import tensorflow as tf 10 | import numpy as np 11 | from sklearn.externals import joblib 12 | import os 13 | import modules 14 | 15 | tf.app.flags.DEFINE_integer("batch_size",100,"batch size for training") 16 | tf.app.flags.DEFINE_integer("epoches",17,"epoches for training") 17 | tf.app.flags.DEFINE_integer("buffer_size",500000,"representing the number of bytes in the read buffer") 18 | tf.app.flags.DEFINE_integer("num_parallel_calls",100," representing the number elements to process in parallel") 19 | tf.app.flags.DEFINE_integer("word_embedding",64,"word embedding") 20 | tf.app.flags.DEFINE_integer("feature_embedding",2,"feature embedding") 21 | tf.app.flags.DEFINE_integer("cell_size",64,"cell size of LSTM") 22 | tf.app.flags.DEFINE_float("dropout_rate",0.05,"dropout rate for output of all the LSTM layers") 23 | tf.app.flags.DEFINE_bool("training",True,"whether to train the model") 24 | tf.app.flags.DEFINE_float("lr",0.005,"learning rate for training") 25 | tf.app.flags.DEFINE_float("decay_steps",8000,"decay steps for learning rate") 26 | tf.app.flags.DEFINE_float("decay_rate",0.75,"decay rate for learning rate") 27 | FLAGS = tf.app.flags.FLAGS 28 | 29 | def parse(serialized): 30 | features = {} 31 | features["question"] = tf.VarLenFeature(tf.int64) 32 | features["evidence"] = tf.VarLenFeature(tf.int64) 33 | features["evidence_tags"] = tf.VarLenFeature(tf.int64) 34 | features["q_e_comm"] = 
tf.VarLenFeature(tf.int64) 35 | features["e_e_comm"] = tf.VarLenFeature(tf.int64) 36 | features["question_length"] = tf.FixedLenFeature([1],tf.int64) 37 | features['evidence_length'] = tf.FixedLenFeature([1],tf.int64) 38 | features = tf.parse_single_example(serialized,features) 39 | question = tf.sparse_tensor_to_dense(features["question"]) 40 | evidence = tf.sparse_tensor_to_dense(features["evidence"]) 41 | evidence_tags = tf.sparse_tensor_to_dense(features["evidence_tags"]) 42 | q_e_comm = tf.sparse_tensor_to_dense(features["q_e_comm"]) 43 | e_e_comm = tf.sparse_tensor_to_dense(features["e_e_comm"]) 44 | return question,evidence,evidence_tags,q_e_comm,e_e_comm,features["question_length"],features["evidence_length"] 45 | 46 | 47 | def main(_): 48 | lang = joblib.load("./data/lang.pkl") 49 | files = ["./data/trainData.tfrecords","./data/validData.tfrecords"] 50 | padded_shapes = tuple([[-1]]*5+[[1]]*2) 51 | padding_values = tuple([np.int64(0)]*2+[np.int64(-1)]+[np.int64(0)]*4) 52 | questions_list = [] 53 | evidences_list = [] 54 | targets_tags_list = [] 55 | q_e_comm_feas_list = [] 56 | e_e_comm_feas_list = [] 57 | question_length_list = [] 58 | evidence_length_list = [] 59 | 60 | for i in range(2): 61 | dataset = tf.data.TFRecordDataset(files[i],buffer_size=FLAGS.buffer_size).map(parse,FLAGS.num_parallel_calls).padded_batch(FLAGS.batch_size,padded_shapes,padding_values).repeat(FLAGS.epoches) 62 | iterator = dataset.make_one_shot_iterator() 63 | input_elements = iterator.get_next() 64 | questions_list.append(input_elements[0]) 65 | evidences_list.append(input_elements[1]) 66 | targets_tags_list.append(input_elements[2]) 67 | q_e_comm_feas_list.append(input_elements[3]) 68 | e_e_comm_feas_list.append(input_elements[4]) 69 | question_length_list.append(tf.reshape(input_elements[5],[-1,])) 70 | evidence_length_list.append(tf.reshape(input_elements[6],[-1,])) 71 | 72 | #model definition 73 | training = tf.placeholder(tf.bool) 74 | word_embedding = modules.Embedding([lang.n_words,FLAGS.word_embedding],"word_embedding") 75 | q_e_comm_feas_embedding = modules.Embedding([2,FLAGS.feature_embedding], 76 | "q_e_comm_feas_embedding") 77 | e_e_comm_feas_embedding = modules.Embedding([2,FLAGS.feature_embedding], 78 | "e_e_comm_feas_embedding") 79 | question_LSTM = modules.QuestionLSTM(FLAGS.cell_size, 80 | FLAGS.dropout_rate, 81 | training) 82 | evidence_LSTMs = modules.EvidenceLSTMs(FLAGS.cell_size,FLAGS.dropout_rate,training) 83 | crf = modules.CRF() 84 | 85 | with tf.device("/gpu:0"): 86 | reuse = False 87 | decoded_list = [] 88 | for i in range(2): 89 | with tf.variable_scope("model",reuse=reuse): 90 | questions = word_embedding(questions_list[i]) #(batch_size,max_ques_length,FLAGS.word_embedding) 91 | evidences = word_embedding(evidences_list[i]) #(batch_size,max_evidences_length,FLAGS.word_embedding) 92 | q_e_comm_feas = q_e_comm_feas_embedding(q_e_comm_feas_list[i]) #(batch_size,max_q_e_comm_feas_length,FLAGS.feature_embedding) 93 | e_e_comm_feas = e_e_comm_feas_embedding(e_e_comm_feas_list[i]) #(batch_size,max_e_e_comm_feas_length,FLAGS.feature_embedding) 94 | q_r = question_LSTM(questions,question_length_list[i]) #(batch_size,FLAGS.cell_size) 95 | q_r = tf.tile(tf.expand_dims(q_r,axis=1),[1,tf.shape(evidences)[1],1]) 96 | x = tf.concat([evidences,q_r,q_e_comm_feas,e_e_comm_feas],axis=2) 97 | outputs = evidence_LSTMs(x,evidence_length_list[i]) 98 | crf.build() 99 | decoded = crf.viterbi_decode(outputs,evidence_length_list[i]) #(batch_size,max_evidences_length) 100 | decoded_list.append(decoded) 
101 | if i == 0: 102 | logits = crf.neg_log_likelihood(outputs,evidence_length_list[i],targets_tags_list[i]) #(batch_size,) 103 | loss = tf.reduce_sum(logits) 104 | loss /= tf.to_float(tf.shape(evidences)[0]) 105 | with tf.name_scope("optimizer"): 106 | global_step = tf.Variable(tf.constant(0,tf.int32),trainable=False) 107 | lr = tf.train.exponential_decay(FLAGS.lr, 108 | global_step, 109 | FLAGS.decay_steps, 110 | FLAGS.decay_rate, 111 | staircase=True) 112 | optimizer = tf.train.RMSPropOptimizer(lr) 113 | tvars = tf.trainable_variables() 114 | grads = tf.gradients(loss,tvars) 115 | train_op = optimizer.apply_gradients(zip(grads,tvars),global_step) 116 | reuse = True 117 | 118 | record_loss = tf.placeholder(tf.float32) 119 | record_accuracy = tf.placeholder(tf.float32) 120 | train_merged = [] 121 | train_merged.append(tf.summary.scalar("train_loss",record_loss)) 122 | train_merged.append(tf.summary.scalar("train_accuracy",record_accuracy)) 123 | train_merged = tf.summary.merge(train_merged) 124 | valid_summary = tf.summary.scalar("valid_accuracy",record_accuracy) 125 | 126 | saver = tf.train.Saver(var_list=tvars) 127 | log_device_placement=True 128 | allow_soft_placement=True 129 | config = tf.ConfigProto(log_device_placement=True,allow_soft_placement=True) 130 | config.gpu_options.allow_growth = True 131 | with tf.Session(config=config) as sess: 132 | sess.run(tf.global_variables_initializer()) 133 | train_fileWriter = tf.summary.FileWriter("logs/train",sess.graph) 134 | valid_fileWriter = tf.summary.FileWriter("logs/valid",sess.graph) 135 | train_size = joblib.load("./data/trainDataSize.pkl") 136 | valid_size = joblib.load("./data/validDataSize.pkl") 137 | assert train_size % FLAGS.batch_size == 0,"train_size can't be divisible by batch." 138 | assert valid_size % FLAGS.batch_size == 0,"valid_size can't be divisible by batch." 
139 | try: 140 | for i in range(FLAGS.epoches): 141 | total_loss = 0.0 142 | total_accuracy = 0.0 143 | for j in range(train_size // FLAGS.batch_size): 144 | _,cost,hypothesis,truth,step = sess.run([train_op,loss,decoded_list[0],targets_tags_list[0],global_step],{training:FLAGS.training}) 145 | accuracy,_,_ = crf.compute_accuracy(hypothesis,truth) 146 | total_loss += cost 147 | total_accuracy += accuracy 148 | print("global_step:%d epoch:%d batch:%d train_loss:%f train_accuracy:%f" 149 | %(step,i+1,j+1,cost,accuracy)) 150 | if step % 500 == 0: 151 | saver.save(sess,"./model/QASystem",step) 152 | total_loss /= (train_size / FLAGS.batch_size) 153 | total_accuracy /= (train_size / FLAGS.batch_size) 154 | summary = sess.run(train_merged,{record_loss:total_loss,record_accuracy:total_accuracy}) 155 | train_fileWriter.add_summary(summary,i+1) 156 | 157 | total_loss = 0.0 158 | total_accuracy = 0.0 159 | print("\n") 160 | for j in range(valid_size // FLAGS.batch_size): 161 | hypothesis,truth = sess.run([decoded_list[1],targets_tags_list[1]],{training:False}) 162 | accuracy,_,_ = crf.compute_accuracy(hypothesis,truth) 163 | total_accuracy += accuracy 164 | print("epoch:%d batch:%d valid_accuracy:%f" %(i+1,j+1,accuracy)) 165 | total_accuracy /= (valid_size / FLAGS.batch_size) 166 | summary = sess.run(valid_summary,{record_accuracy:total_accuracy}) 167 | valid_fileWriter.add_summary(summary,i+1) 168 | print("\n") 169 | except BaseException: 170 | saver.save(sess,"./model/QASystem") 171 | saver.save(sess,"./model/QASystem") 172 | 173 | if __name__ == "__main__": 174 | tf.app.run() 175 | -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | """ 4 | Author: Yeliang Li 5 | Blog: http://blog.yeliangli.com/ 6 | Created: 2018/2/14 7 | """ 8 | 9 | import tensorflow as tf 10 | import tensorflow.contrib.eager as tfe 11 | from tensorflow.python.layers import base 12 | 13 | class Lang(object): 14 | def __init__(self, name): 15 | self.name = name 16 | self.word2index = {} 17 | self.word2count = {} 18 | self.index2word = {0:'PAD',1:'UNK'} 19 | self.n_words = 2 # Count PAD and UNK 20 | 21 | def addWord(self, word): 22 | if word not in self.word2index: 23 | self.word2index[word] = self.n_words 24 | self.word2count[word] = 1 25 | self.index2word[self.n_words] = word 26 | self.n_words += 1 27 | else: 28 | self.word2count[word] += 1 29 | return self.word2index[word] 30 | 31 | class Embedding(base.Layer): 32 | def __init__(self,embedding_shape,name,trainable=True): 33 | super(Embedding,self).__init__(trainable=trainable,name=name) 34 | self.embedding_shape = embedding_shape 35 | 36 | def build(self,_): 37 | assert len(self.embedding_shape) == 2, "The length of embedding_shape is not equal to 2." 38 | self.embedding = self.add_variable("embedding", 39 | self.embedding_shape, 40 | tf.float32 41 | ) 42 | 43 | def call(self, ids): 44 | outputs = tf.nn.embedding_lookup(self.embedding,ids) 45 | return outputs 46 | 47 | class AttentionLayer(base.Layer): 48 | def __init__(self, 49 | num_units, 50 | score_value=float("-inf"), 51 | trainable=True, 52 | name="attentionLayer"): 53 | super(AttentionLayer,self).__init__(trainable=trainable,name=name) 54 | self.num_units = num_units 55 | self.score_value = score_value 56 | 57 | def build(self,input_shape): 58 | assert len(input_shape.as_list()) == 3, "The dimension of inputs is not equal to 3." 
59 | self.W = self.add_variable("W",[input_shape[-1].value,self.num_units],tf.float32) 60 | self.v = self.add_variable("v",[self.num_units,1],tf.float32) 61 | 62 | def call(self,inputs,sequence_length=None): 63 | W = tf.tile(tf.expand_dims(self.W,0),[tf.shape(inputs)[0],1,1]) 64 | v = tf.tile(tf.expand_dims(self.v,0),[tf.shape(inputs)[0],1,1]) 65 | scores = tf.squeeze(tf.matmul(tf.nn.tanh(tf.matmul(inputs,W)),v), 66 | axis=2 67 | ) 68 | if sequence_length != None: 69 | mask = tf.sequence_mask(sequence_length,maxlen=tf.shape(inputs)[1]) 70 | scores = tf.where(mask,scores,tf.ones_like(scores)*self.score_value) 71 | alignments = tf.nn.softmax(scores) 72 | r = tf.squeeze(tf.matmul(tf.expand_dims(alignments,axis=1),inputs),axis=1) 73 | return r 74 | 75 | class QuestionLSTM(): 76 | def __init__(self, 77 | num_units, 78 | dropout_rate=0.0, 79 | training = True, 80 | name="questionLSTM"): 81 | self.dropout_rate = dropout_rate 82 | self.training = training 83 | self.name = name 84 | self.lstm_cell = tf.contrib.rnn.LSTMCell(num_units,use_peepholes=True) 85 | self.attention = AttentionLayer(num_units) 86 | 87 | def __call__(self,inputs,sequence_length): 88 | with tf.variable_scope(self.name): 89 | state = self.lstm_cell.zero_state(tf.shape(inputs)[0],tf.float32) 90 | outputs,_ = tf.nn.dynamic_rnn(self.lstm_cell, 91 | inputs, 92 | sequence_length, 93 | state, 94 | tf.float32) 95 | outputs = tf.layers.dropout(outputs,self.dropout_rate,training=self.training) 96 | q_r = self.attention(outputs,sequence_length) 97 | return q_r 98 | 99 | class EvidenceLSTMs(): 100 | def __init__(self, 101 | num_units, 102 | dropout_rate = 0.0, 103 | training = True, 104 | name="evidenceLSTMs"): 105 | self.layer1 = tf.contrib.rnn.LSTMCell(num_units,use_peepholes=True,name="layer1") 106 | self.layer2 = tf.contrib.rnn.LSTMCell(num_units,use_peepholes=True,name="layer2") 107 | self.layer3 = tf.contrib.rnn.LSTMCell(num_units,use_peepholes=True,name="layer3") 108 | self.output_dense = tf.layers.Dense(len(tag_to_ix)) 109 | self.dropout_rate = dropout_rate 110 | self.training = training 111 | self.name = name 112 | 113 | def __call__(self,inputs,sequence_length): 114 | with tf.variable_scope(self.name): 115 | state = self.layer1.zero_state(tf.shape(inputs)[0],tf.float32) 116 | layer1_outputs,_ = tf.nn.dynamic_rnn(self.layer1, 117 | inputs, 118 | sequence_length, 119 | state, 120 | tf.float32) 121 | layer1_outputs = tf.layers.dropout(layer1_outputs,self.dropout_rate,training=self.training) 122 | layer1_outputs_reversed = tf.reverse_sequence(layer1_outputs,sequence_length,seq_dim=1) 123 | layer2_outputs,_ = tf.nn.dynamic_rnn(self.layer2, 124 | layer1_outputs_reversed, 125 | sequence_length, 126 | state, 127 | tf.float32) 128 | layer2_outputs = tf.reverse_sequence(layer2_outputs,sequence_length,seq_dim=1) 129 | layer2_outputs = tf.layers.dropout(layer2_outputs,self.dropout_rate,training=self.training) 130 | layer3_inputs = tf.concat([layer1_outputs,layer2_outputs],axis=2) 131 | outputs,_ = tf.nn.dynamic_rnn(self.layer3, 132 | layer3_inputs, 133 | sequence_length, 134 | state, 135 | tf.float32) 136 | outputs = tf.layers.dropout(outputs,self.dropout_rate,training=self.training) 137 | outputs = self.output_dense(outputs) 138 | return outputs 139 | 140 | tag_to_ix = {"B":0,"I":1,"O1":2,"O2":3,"START":4,"STOP":5} 141 | 142 | def log_sum_exp(inputs): 143 | max_scores = tf.reduce_max(inputs,axis=1,keepdims=True) 144 | return tf.squeeze(max_scores,1) + tf.log(tf.reduce_sum(tf.exp(inputs - max_scores),axis=1)) 145 | 146 | class CRF(): 147 | 
def __init__(self,name="crf"): 148 | self.name = name 149 | def build(self): 150 | with tf.variable_scope(self.name): 151 | self.transitions = tf.get_variable("transitions", 152 | [len(tag_to_ix),len(tag_to_ix)], 153 | tf.float32) 154 | indices = [] 155 | updates = [] 156 | for i in range(len(tag_to_ix)): 157 | indices.append([i,tag_to_ix["START"]]) 158 | indices.append([tag_to_ix["STOP"],i]) 159 | updates += [-1e+8,-1e+8] 160 | 161 | self.transitions = tf.scatter_nd_update(self.transitions,indices,updates) 162 | self.init_alphas = tf.get_variable("init_alphas", 163 | [len(tag_to_ix)], 164 | tf.float32, 165 | tf.constant_initializer(-1e+8), 166 | trainable=False) 167 | self.init_alphas = tf.scatter_update(self.init_alphas, 168 | [tag_to_ix["START"]], 169 | [tf.constant(0.0)]) 170 | 171 | def neg_log_likelihood(self,inputs,sequence_length,targets): 172 | sequence_length = tf.to_int32(sequence_length) 173 | targets = tf.to_int32(targets) 174 | i0 = tf.constant(0,tf.int32) 175 | alphas_0 = tf.tile(tf.expand_dims(self.init_alphas,0),[tf.shape(inputs)[0],1]) 176 | max_seq_len = tf.reduce_max(sequence_length) 177 | alphas_array = tf.TensorArray(tf.float32,size=max_seq_len) 178 | 179 | scores_array = tf.TensorArray(tf.float32,size=max_seq_len) 180 | initial_scores = alphas_0[:,tag_to_ix["START"]] 181 | targets = tf.concat([tf.ones([tf.shape(inputs)[0],1],tf.int32)*tag_to_ix["START"], 182 | targets], 183 | axis=1) 184 | def body(i,alphas_t,ta1,scores,ta2): 185 | inp = inputs[:,i,:] #(batch_size,len(tag_to_ix)) 186 | emit_scores = tf.tile(tf.expand_dims(inp,1),[1,len(tag_to_ix),1]) 187 | forward_vars = tf.tile(tf.expand_dims(alphas_t,2),[1,1,len(tag_to_ix)]) 188 | next_tag_vars = forward_vars + self.transitions + emit_scores 189 | alphas_t_plus_1 = log_sum_exp(next_tag_vars) #(batch_size,len(tag_to_ix)) 190 | 191 | indices = tf.stack([tf.range(0,tf.shape(inputs)[0],dtype=tf.int32), 192 | targets[:,i+1]]) 193 | indices = tf.transpose(indices,[1,0]) 194 | scores += tf.gather_nd(self.transitions,targets[:,i:i+2]) +\ 195 | tf.gather_nd(inp,indices) 196 | scores = tf.reshape(scores,[tf.shape(inputs)[0]]) 197 | return i+1,alphas_t_plus_1,ta1.write(i,alphas_t_plus_1),scores,ta2.write(i,scores) 198 | _,_,ta1_final,_,ta2_final = tf.while_loop(lambda i,alphas_t,ta1,scores,ta2: i < max_seq_len, 199 | body, 200 | loop_vars=[i0,alphas_0,alphas_array,initial_scores,scores_array]) 201 | 202 | ta1_final_result = ta1_final.stack() 203 | indices1 = tf.stack([sequence_length-1, 204 | tf.range(0,tf.shape(inputs)[0],dtype=tf.int32)]) 205 | indices1 = tf.transpose(indices1,[1,0]) 206 | terminal_vars = tf.gather_nd(ta1_final_result,indices1) + self.transitions[:,tag_to_ix["STOP"]] 207 | forward_scores = log_sum_exp(terminal_vars) #(batch_size) 208 | 209 | ta2_final_result = ta2_final.stack() 210 | indices2 = tf.stack([tf.gather_nd(tf.transpose(targets,[1,0]),indices1), 211 | tf.ones([tf.shape(inputs)[0]],tf.int32)*tag_to_ix["STOP"]]) 212 | indices2 = tf.transpose(indices2,[1,0]) 213 | gold_scores = tf.gather_nd(ta2_final_result,indices1) +\ 214 | tf.gather_nd(self.transitions,indices2) 215 | 216 | return forward_scores - gold_scores # (batch_size,) 217 | 218 | def viterbi_decode(self,inputs,sequence_length): 219 | sequence_length = tf.to_int32(sequence_length) 220 | alphas_0 = tf.tile(tf.expand_dims(self.init_alphas,0), 221 | [tf.shape(inputs)[0],1]) #(batch_size,len(tag_to_ix)) 222 | i0 = tf.constant(0,tf.int32) 223 | max_seq_len = tf.reduce_max(sequence_length) 224 | best_tag_ids_array = 
tf.TensorArray(tf.int32,size=max_seq_len) 225 | alphas_array = tf.TensorArray(tf.float32,size=max_seq_len) 226 | 227 | def body1(i,alphas_t,ta1,ta2): 228 | forward_vars = tf.tile(tf.expand_dims(alphas_t,2),[1,1,len(tag_to_ix)]) 229 | next_tag_vars = forward_vars + self.transitions 230 | best_tag_ids = tf.argmax(next_tag_vars,axis=1,output_type=tf.int32) #(batch_size,len(tag_to_ix)) 231 | alphas_t_plus_1 = tf.reduce_max(next_tag_vars,axis=1) + inputs[:,i,:] #(batch_size,len(tag_to_ix)) 232 | return i+1,alphas_t_plus_1,ta1.write(i,best_tag_ids),ta2.write(i,alphas_t_plus_1) 233 | _,_,ta1_final,ta2_final = tf.while_loop(lambda i,alpha_t,ta1,ta2: i < max_seq_len, 234 | body1, 235 | loop_vars=[i0,alphas_0,best_tag_ids_array,alphas_array]) 236 | 237 | bptrs = ta1_final.stack() 238 | 239 | ta2_final_result = ta2_final.stack() 240 | indices = tf.stack([sequence_length-1, 241 | tf.range(0,tf.shape(inputs)[0],dtype=tf.int32)]) 242 | indices = tf.transpose(indices,[1,0]) 243 | terminal_vars = tf.gather_nd(ta2_final_result,indices) + self.transitions[:,tag_to_ix["STOP"]] 244 | terminal_id = tf.argmax(terminal_vars,axis=1,output_type=tf.int32) #(batch_size) 245 | 246 | path_array = tf.TensorArray(tf.int32,size=tf.shape(inputs)[0]) 247 | def body2(i,ta): 248 | def sub_body(j,past_id,best_path): 249 | best_tag_id = tf.gather_nd(bptrs,[[j,i]]) #(1,len(tag_to_ix)) 250 | next_id = tf.gather_nd(best_tag_id,[[0,past_id]]) 251 | best_path = tf.concat([next_id,best_path],axis=0) 252 | return j-1,next_id[0],best_path 253 | _,_,best_path = tf.while_loop(lambda j,past_id,best_path: j >= 0, 254 | sub_body, 255 | loop_vars=[sequence_length[i] - 1,terminal_id[i],terminal_id[i:i+1]], 256 | shape_invariants=[tf.TensorShape([]),tf.TensorShape([]),tf.TensorShape([None])] 257 | ) 258 | best_path = tf.pad(best_path,[[0,max_seq_len-sequence_length[i]]],constant_values=-1) 259 | return i+1,ta.write(i,best_path) 260 | _,ta_final = tf.while_loop(lambda i,ta: i < tf.shape(inputs)[0], 261 | body2, 262 | loop_vars=[i0,path_array]) 263 | 264 | best_paths = ta_final.stack() #shape=(batch_size,1+max_seq_len),padding_value=-1 265 | return best_paths[:,1:] 266 | 267 | @classmethod 268 | def compute_accuracy(cls,predicts,targets): 269 | ''' 270 | predicts and targets are instances of numpy. 271 | 272 | predicts shape:(batch_size,time_step) 273 | targets shape: (batch_size,time_step) 274 | ''' 275 | assert predicts.shape == targets.shape, "The dimensions of predicts and targets don't match." 276 | assert len(predicts.shape) == 2, "The dimension of predicts don't equal to 2." 
277 | count = 0 278 | truth_ans_pos = [] 279 | hypothesis_ans_pos = [] 280 | for i in range(targets.shape[0]): 281 | target_answer_pos = [] 282 | predict_answer_pos = [] 283 | for j in range(targets.shape[1]): 284 | if targets[i,j] == tag_to_ix["B"]: 285 | target_answer_pos.append(j) 286 | for k in range(j+1,targets.shape[1]): 287 | if targets[i,k] != tag_to_ix["I"]: 288 | target_answer_pos.append(k-1) 289 | break 290 | break 291 | for j in range(predicts.shape[1]): 292 | if predicts[i,j] == tag_to_ix["B"]: 293 | predict_answer_pos.append(j) 294 | for k in range(j+1,predicts.shape[1]): 295 | if predicts[i,k] != tag_to_ix["I"]: 296 | predict_answer_pos.append(k-1) 297 | break 298 | break 299 | if target_answer_pos == predict_answer_pos: 300 | count += 1 301 | hypothesis_ans_pos.append(predict_answer_pos) 302 | truth_ans_pos.append(target_answer_pos) 303 | acc = count / targets.shape[0] 304 | return acc,hypothesis_ans_pos,truth_ans_pos --------------------------------------------------------------------------------