├── .gitignore
├── picture
│   ├── train_acc.png
│   ├── train_loss.png
│   └── valid_acc.png
├── README.md
├── make_tfrecords.py
├── train.py
└── modules.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | data/
3 | model/
4 | logs/
5 | WebQA.v1.0/
--------------------------------------------------------------------------------
/picture/train_acc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeliangLi/Chinese-QASystem/HEAD/picture/train_acc.png
--------------------------------------------------------------------------------
/picture/train_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeliangLi/Chinese-QASystem/HEAD/picture/train_loss.png
--------------------------------------------------------------------------------
/picture/valid_acc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeliangLi/Chinese-QASystem/HEAD/picture/valid_acc.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Chinese-QASystem
2 | Chinese question answering system based on BLSTM and CRF.
3 |
4 |
5 |
6 | Requirements
7 | -------
8 | tensorflow 1.5
9 | numpy
10 | thulac
11 | scikit-learn
12 | matplotlib
13 |
14 | Dataset
15 | -----
16 | Baidu's Chinese question answering dataset [WebQA](https://www.spaces.ac.cn/archives/4338/). Many thanks to the author of the linked post for organizing the data.
17 |
18 | How to get started?
19 | -----
20 | 1. Download the raw data and extract it to the folder where the source code is located.
21 | 2. Run `python3 make_tfrecords.py` to process the raw data and generate the TFRecord files for training and validation.
22 | In this experiment, I trained on 200,000 examples and validated the model's accuracy on 5,000 examples.
23 | 3. Run `python3 train.py`. All the training results are shown below; it is not hard to see that the model eventually
24 | achieves an accuracy of 0.6050 on the validation set.
25 |
26 | 
27 | 
28 | 
29 |
30 | Note
31 | ----
32 | In the future, I will write a blog post introducing this work, where you will learn how to use TensorFlow's tf.while_loop interface to implement conditional random field training and Viterbi decoding (a toy sketch of the interface follows this README).
33 |
34 |
35 | References
36 | -----
37 | [Li P, Li W, He Z, et al. Dataset and Neural Recurrent Sequence Labeling Model for Open-Domain Factoid Question Answering[J]. 2016.](https://arxiv.org/abs/1607.06275)
38 |
39 |
40 |
--------------------------------------------------------------------------------
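The Note in the README mentions tf.while_loop. As a frame of reference, here is a minimal toy sketch, not the repository's CRF code, of the tf.while_loop looping pattern that CRF.neg_log_likelihood and CRF.viterbi_decode in modules.py are built on: a loop body that consumes one time step of a padded batch per iteration and accumulates a per-example score while masking padded positions. All names and values in it are hypothetical.

```python
import tensorflow as tf

# Toy example only: accumulate per-step scores over a padded batch with tf.while_loop.
scores = tf.constant([[1.0, 2.0, 3.0],
                      [4.0, 5.0, 0.0]])   # (batch_size=2, max_seq_len=3); second row is padded
lengths = tf.constant([3, 2])             # true sequence lengths
max_len = tf.reduce_max(lengths)

def cond(i, total):
    return i < max_len

def body(i, total):
    step = scores[:, i]                          # scores at time step i for the whole batch
    mask = tf.cast(i < lengths, tf.float32)      # zero out positions beyond each sequence length
    return i + 1, total + step * mask

_, totals = tf.while_loop(cond, body, loop_vars=[tf.constant(0), tf.zeros([2])])

with tf.Session() as sess:
    print(sess.run(totals))                      # [6. 9.]
```
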
/make_tfrecords.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | """
4 | Author: Yeliang Li
5 | Blog: http://blog.yeliangli.com/
6 | Created: 2018/2/14
7 | """
8 |
9 | import tensorflow as tf
10 | from sklearn.externals import joblib
11 | from sklearn.model_selection import train_test_split
12 | import os
13 | import json
14 | import thulac
15 | import re
16 | import modules
17 |
18 |
19 | path = "./WebQA.v1.0"
20 | #{"B":0,"I":1,"O1":2,"O2":3,"START":4,"STOP":5}
21 | tag_to_ix = modules.tag_to_ix
22 | evidences_sampling_rate = 0.8
23 |
24 | def write(lang,file,writer,data_size):
25 | thul = thulac.thulac(seg_only=True)
26 | if not os.path.isdir("data"):
27 | os.mkdir("data")
28 | with open(file,'rt',encoding='utf-8') as file:
29 | dict = file.read()
30 | dict = json.loads(dict,encoding='utf-8')
31 | count = 0
32 | number = 0
33 | for key in dict:
34 | if count == data_size:
35 | break
36 | ques = []
37 | for word in thul.cut(dict[key]['question'],text=True).split():
38 | index = lang.addWord(word)
39 | ques.append(index)
40 | evidences = []
41 | evidences_tags = []
42 | for record in dict[key]['evidences']:
43 | number += 1
44 | print("number:%d question number:%s evidence number:%s" %(number,key,record))
45 | answer = dict[key]['evidences'][record]['answer'][0]
46 | evidence = dict[key]['evidences'][record]['evidence']
47 | if answer != 'no_answer':
48 | evidence = re.sub(re.escape(answer),"XXX",evidence)  # escape the answer in case it contains regex metacharacters
49 | answer = thul.cut(answer,text=True).split()
50 | answer_tags = []
51 | answer_indices = []
52 | for i in range(len(answer)):
53 | if i != 0:
54 | answer_tags.append(tag_to_ix['I'])
55 | else:
56 | answer_tags.append(tag_to_ix['B'])
57 | answer_indices.append(lang.addWord(answer[i]))
58 | evidence = thul.cut(evidence,text=True).split()
59 | evidence_indices = []
60 | evidence_tags = []
61 | before_answer = True
62 | for word in evidence:
63 | if word != "XXX":
64 | evidence_indices.append(lang.addWord(word))
65 | else:
66 | evidence_indices += answer_indices
67 | evidence_tags += answer_tags
68 | before_answer = False
69 | continue
70 | if before_answer:
71 | evidence_tags.append(tag_to_ix["O1"])
72 | else:
73 | evidence_tags.append(tag_to_ix["O2"])
74 | evidences.append(evidence_indices)
75 | evidences_tags.append(evidence_tags)
76 | else:
77 | evidence_indices = []
78 | evidence_tags = []
79 | for word in thul.cut(evidence,text=True).split():
80 | evidence_indices.append(lang.addWord(word))
81 | evidence_tags.append(tag_to_ix["O1"])
82 | evidences.append(evidence_indices)
83 | evidences_tags.append(evidence_tags)
84 | selected_evidences,rest_evidences,selected_evidences_tags,rest_evidences_tags = train_test_split(evidences,
85 | evidences_tags,
86 | test_size=1-evidences_sampling_rate,
87 | random_state=0
88 | )
89 | count += len(selected_evidences)
90 | if count > data_size:
91 | count -= len(selected_evidences)
92 | selected_evidences = selected_evidences[0:(data_size - count)]
93 | count = data_size
94 | for i in range(len(selected_evidences)):
95 | e_e_comm_fea = []
96 | q_e_comm_fea = []
97 | for index in selected_evidences[i]:
98 | if index in ques:
99 | q_e_comm_fea.append(1)
100 | else:
101 | q_e_comm_fea.append(0)
102 | comm_tag = False
103 | for evidence in rest_evidences:
104 | if index in evidence:
105 | e_e_comm_fea.append(1)
106 | comm_tag = True
107 | break
108 | if not comm_tag:
109 | e_e_comm_fea.append(0)
110 | feas = {}
111 | feas['question'] = tf.train.Feature(int64_list=tf.train.Int64List(value=ques))
112 | feas['evidence'] = tf.train.Feature(int64_list=tf.train.Int64List(value=selected_evidences[i]))
113 | feas['evidence_tags'] = tf.train.Feature(int64_list=tf.train.Int64List(value=selected_evidences_tags[i]))
114 | feas['q_e_comm'] = tf.train.Feature(int64_list=tf.train.Int64List(value=q_e_comm_fea))
115 | feas['e_e_comm'] = tf.train.Feature(int64_list=tf.train.Int64List(value=e_e_comm_fea))
116 | feas['question_length'] = tf.train.Feature(int64_list=tf.train.Int64List(value=[len(ques)]))
117 | feas['evidence_length'] = tf.train.Feature(int64_list=tf.train.Int64List(value=[len(selected_evidences[i])]))
118 | features_to_write = tf.train.Example(features=tf.train.Features(feature=feas))
119 | writer.write(features_to_write.SerializeToString())
120 | writer.close()
121 | return count
122 |
123 |
124 | if __name__ == "__main__":
125 | lang = modules.Lang("Chi")
126 | if not os.path.isdir("data"):
127 | os.mkdir("data")
128 |
129 | writer = tf.python_io.TFRecordWriter("./data/trainData.tfrecords")
130 | train_data_size = write(lang,os.path.join(path,"me_train.json"),writer,200000)
131 | joblib.dump(train_data_size,"./data/trainDataSize.pkl")
132 |
133 | writer = tf.python_io.TFRecordWriter("./data/validData.tfrecords")
134 | valid_data_size = write(lang,os.path.join(path,"me_validation.ir.json"),writer,5000)
135 | joblib.dump(valid_data_size,"./data/validDataSize.pkl")
136 |
137 | joblib.dump(lang,"./data/lang.pkl")
138 |
139 |
--------------------------------------------------------------------------------
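A quick way to check what make_tfrecords.py wrote is to read one serialized tf.train.Example back and decode the word indices with the saved vocabulary. The snippet below is a hypothetical sanity check, not part of the repository; it assumes make_tfrecords.py has already produced ./data/trainData.tfrecords and ./data/lang.pkl.

```python
import tensorflow as tf
from sklearn.externals import joblib
import modules  # needed so joblib can unpickle the Lang instance

lang = joblib.load("./data/lang.pkl")

# Read the first serialized record and parse it back into a tf.train.Example.
record = next(tf.python_io.tf_record_iterator("./data/trainData.tfrecords"))
example = tf.train.Example()
example.ParseFromString(record)

feature = example.features.feature
question_ids = feature["question"].int64_list.value
print("question:", " ".join(lang.index2word[i] for i in question_ids))
print("evidence length:", feature["evidence_length"].int64_list.value[0])
print("evidence tags:", list(feature["evidence_tags"].int64_list.value))
```
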
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | """
4 | Author: Yeliang Li
5 | Blog: http://blog.yeliangli.com/
6 | Created: 2018/2/14
7 | """
8 |
9 | import tensorflow as tf
10 | import numpy as np
11 | from sklearn.externals import joblib
12 | import os
13 | import modules
14 |
15 | tf.app.flags.DEFINE_integer("batch_size",100,"batch size for training")
16 | tf.app.flags.DEFINE_integer("epoches",17,"epoches for training")
17 | tf.app.flags.DEFINE_integer("buffer_size",500000,"representing the number of bytes in the read buffer")
18 | tf.app.flags.DEFINE_integer("num_parallel_calls",100," representing the number elements to process in parallel")
19 | tf.app.flags.DEFINE_integer("word_embedding",64,"word embedding")
20 | tf.app.flags.DEFINE_integer("feature_embedding",2,"feature embedding")
21 | tf.app.flags.DEFINE_integer("cell_size",64,"cell size of LSTM")
22 | tf.app.flags.DEFINE_float("dropout_rate",0.05,"dropout rate for output of all the LSTM layers")
23 | tf.app.flags.DEFINE_bool("training",True,"whether to train the model")
24 | tf.app.flags.DEFINE_float("lr",0.005,"learning rate for training")
25 | tf.app.flags.DEFINE_float("decay_steps",8000,"decay steps for learning rate")
26 | tf.app.flags.DEFINE_float("decay_rate",0.75,"decay rate for learning rate")
27 | FLAGS = tf.app.flags.FLAGS
28 |
29 | def parse(serialized):
30 | features = {}
31 | features["question"] = tf.VarLenFeature(tf.int64)
32 | features["evidence"] = tf.VarLenFeature(tf.int64)
33 | features["evidence_tags"] = tf.VarLenFeature(tf.int64)
34 | features["q_e_comm"] = tf.VarLenFeature(tf.int64)
35 | features["e_e_comm"] = tf.VarLenFeature(tf.int64)
36 | features["question_length"] = tf.FixedLenFeature([1],tf.int64)
37 | features['evidence_length'] = tf.FixedLenFeature([1],tf.int64)
38 | features = tf.parse_single_example(serialized,features)
39 | question = tf.sparse_tensor_to_dense(features["question"])
40 | evidence = tf.sparse_tensor_to_dense(features["evidence"])
41 | evidence_tags = tf.sparse_tensor_to_dense(features["evidence_tags"])
42 | q_e_comm = tf.sparse_tensor_to_dense(features["q_e_comm"])
43 | e_e_comm = tf.sparse_tensor_to_dense(features["e_e_comm"])
44 | return question,evidence,evidence_tags,q_e_comm,e_e_comm,features["question_length"],features["evidence_length"]
45 |
46 |
47 | def main(_):
48 | lang = joblib.load("./data/lang.pkl")
49 | files = ["./data/trainData.tfrecords","./data/validData.tfrecords"]
50 | padded_shapes = tuple([[-1]]*5+[[1]]*2)
51 | padding_values = tuple([np.int64(0)]*2+[np.int64(-1)]+[np.int64(0)]*4)
52 | questions_list = []
53 | evidences_list = []
54 | targets_tags_list = []
55 | q_e_comm_feas_list = []
56 | e_e_comm_feas_list = []
57 | question_length_list = []
58 | evidence_length_list = []
59 |
60 | for i in range(2):
61 | dataset = tf.data.TFRecordDataset(files[i],buffer_size=FLAGS.buffer_size).map(parse,FLAGS.num_parallel_calls).padded_batch(FLAGS.batch_size,padded_shapes,padding_values).repeat(FLAGS.epoches)
62 | iterator = dataset.make_one_shot_iterator()
63 | input_elements = iterator.get_next()
64 | questions_list.append(input_elements[0])
65 | evidences_list.append(input_elements[1])
66 | targets_tags_list.append(input_elements[2])
67 | q_e_comm_feas_list.append(input_elements[3])
68 | e_e_comm_feas_list.append(input_elements[4])
69 | question_length_list.append(tf.reshape(input_elements[5],[-1,]))
70 | evidence_length_list.append(tf.reshape(input_elements[6],[-1,]))
71 |
72 | #model definition
73 | training = tf.placeholder(tf.bool)
74 | word_embedding = modules.Embedding([lang.n_words,FLAGS.word_embedding],"word_embedding")
75 | q_e_comm_feas_embedding = modules.Embedding([2,FLAGS.feature_embedding],
76 | "q_e_comm_feas_embedding")
77 | e_e_comm_feas_embedding = modules.Embedding([2,FLAGS.feature_embedding],
78 | "e_e_comm_feas_embedding")
79 | question_LSTM = modules.QuestionLSTM(FLAGS.cell_size,
80 | FLAGS.dropout_rate,
81 | training)
82 | evidence_LSTMs = modules.EvidenceLSTMs(FLAGS.cell_size,FLAGS.dropout_rate,training)
83 | crf = modules.CRF()
84 |
85 | with tf.device("/gpu:0"):
86 | reuse = False
87 | decoded_list = []
88 | for i in range(2):
89 | with tf.variable_scope("model",reuse=reuse):
90 | questions = word_embedding(questions_list[i]) #(batch_size,max_ques_length,FLAGS.word_embedding)
91 | evidences = word_embedding(evidences_list[i]) #(batch_size,max_evidences_length,FLAGS.word_embedding)
92 | q_e_comm_feas = q_e_comm_feas_embedding(q_e_comm_feas_list[i]) #(batch_size,max_q_e_comm_feas_length,FLAGS.feature_embedding)
93 | e_e_comm_feas = e_e_comm_feas_embedding(e_e_comm_feas_list[i]) #(batch_size,max_e_e_comm_feas_length,FLAGS.feature_embedding)
94 | q_r = question_LSTM(questions,question_length_list[i]) #(batch_size,FLAGS.cell_size)
95 | q_r = tf.tile(tf.expand_dims(q_r,axis=1),[1,tf.shape(evidences)[1],1])
96 | x = tf.concat([evidences,q_r,q_e_comm_feas,e_e_comm_feas],axis=2)
97 | outputs = evidence_LSTMs(x,evidence_length_list[i])
98 | crf.build()
99 | decoded = crf.viterbi_decode(outputs,evidence_length_list[i]) #(batch_size,max_evidences_length)
100 | decoded_list.append(decoded)
101 | if i == 0:
102 | logits = crf.neg_log_likelihood(outputs,evidence_length_list[i],targets_tags_list[i]) #(batch_size,)
103 | loss = tf.reduce_sum(logits)
104 | loss /= tf.to_float(tf.shape(evidences)[0])
105 | with tf.name_scope("optimizer"):
106 | global_step = tf.Variable(tf.constant(0,tf.int32),trainable=False)
107 | lr = tf.train.exponential_decay(FLAGS.lr,
108 | global_step,
109 | FLAGS.decay_steps,
110 | FLAGS.decay_rate,
111 | staircase=True)
112 | optimizer = tf.train.RMSPropOptimizer(lr)
113 | tvars = tf.trainable_variables()
114 | grads = tf.gradients(loss,tvars)
115 | train_op = optimizer.apply_gradients(zip(grads,tvars),global_step)
116 | reuse = True
117 |
118 | record_loss = tf.placeholder(tf.float32)
119 | record_accuracy = tf.placeholder(tf.float32)
120 | train_merged = []
121 | train_merged.append(tf.summary.scalar("train_loss",record_loss))
122 | train_merged.append(tf.summary.scalar("train_accuracy",record_accuracy))
123 | train_merged = tf.summary.merge(train_merged)
124 | valid_summary = tf.summary.scalar("valid_accuracy",record_accuracy)
125 |
126 | saver = tf.train.Saver(var_list=tvars)
127 | # allow_soft_placement lets ops without a GPU kernel fall back to the CPU;
128 | # log_device_placement prints the device each op is assigned to.
129 | config = tf.ConfigProto(log_device_placement=True,allow_soft_placement=True)
130 | config.gpu_options.allow_growth = True
131 | with tf.Session(config=config) as sess:
132 | sess.run(tf.global_variables_initializer())
133 | train_fileWriter = tf.summary.FileWriter("logs/train",sess.graph)
134 | valid_fileWriter = tf.summary.FileWriter("logs/valid",sess.graph)
135 | train_size = joblib.load("./data/trainDataSize.pkl")
136 | valid_size = joblib.load("./data/validDataSize.pkl")
137 | assert train_size % FLAGS.batch_size == 0,"train_size must be divisible by batch_size."
138 | assert valid_size % FLAGS.batch_size == 0,"valid_size must be divisible by batch_size."
139 | try:
140 | for i in range(FLAGS.epoches):
141 | total_loss = 0.0
142 | total_accuracy = 0.0
143 | for j in range(train_size // FLAGS.batch_size):
144 | _,cost,hypothesis,truth,step = sess.run([train_op,loss,decoded_list[0],targets_tags_list[0],global_step],{training:FLAGS.training})
145 | accuracy,_,_ = crf.compute_accuracy(hypothesis,truth)
146 | total_loss += cost
147 | total_accuracy += accuracy
148 | print("global_step:%d epoch:%d batch:%d train_loss:%f train_accuracy:%f"
149 | %(step,i+1,j+1,cost,accuracy))
150 | if step % 500 == 0:
151 | saver.save(sess,"./model/QASystem",step)
152 | total_loss /= (train_size / FLAGS.batch_size)
153 | total_accuracy /= (train_size / FLAGS.batch_size)
154 | summary = sess.run(train_merged,{record_loss:total_loss,record_accuracy:total_accuracy})
155 | train_fileWriter.add_summary(summary,i+1)
156 |
157 | total_loss = 0.0
158 | total_accuracy = 0.0
159 | print("\n")
160 | for j in range(valid_size // FLAGS.batch_size):
161 | hypothesis,truth = sess.run([decoded_list[1],targets_tags_list[1]],{training:False})
162 | accuracy,_,_ = crf.compute_accuracy(hypothesis,truth)
163 | total_accuracy += accuracy
164 | print("epoch:%d batch:%d valid_accuracy:%f" %(i+1,j+1,accuracy))
165 | total_accuracy /= (valid_size / FLAGS.batch_size)
166 | summary = sess.run(valid_summary,{record_accuracy:total_accuracy})
167 | valid_fileWriter.add_summary(summary,i+1)
168 | print("\n")
169 | except BaseException:  # save a checkpoint if training is interrupted (e.g. Ctrl+C)
170 | saver.save(sess,"./model/QASystem")
171 | saver.save(sess,"./model/QASystem")
172 |
173 | if __name__ == "__main__":
174 | tf.app.run()
175 |
--------------------------------------------------------------------------------
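For orientation, the flag defaults above give 200,000 / 100 = 2,000 training batches per epoch, so 17 epochs correspond to roughly 34,000 global steps. The staircase exponential decay configured by FLAGS.lr, FLAGS.decay_steps and FLAGS.decay_rate can be reproduced outside the graph with a small helper; the function below is a hypothetical sketch, not part of train.py.

```python
def staircase_lr(step, base_lr=0.005, decay_steps=8000, decay_rate=0.75):
    """Mirror tf.train.exponential_decay(..., staircase=True) for the flag defaults above."""
    return base_lr * decay_rate ** (step // decay_steps)

for step in (0, 8000, 16000, 24000, 32000):
    print(step, staircase_lr(step))
# 0 0.005 | 8000 0.00375 | 16000 0.0028125 | 24000 ~0.00211 | 32000 ~0.00158
```
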
/modules.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | """
4 | Author: Yeliang Li
5 | Blog: http://blog.yeliangli.com/
6 | Created: 2018/2/14
7 | """
8 |
9 | import tensorflow as tf
10 | import tensorflow.contrib.eager as tfe
11 | from tensorflow.python.layers import base
12 |
13 | class Lang(object):
14 | def __init__(self, name):
15 | self.name = name
16 | self.word2index = {}
17 | self.word2count = {}
18 | self.index2word = {0:'PAD',1:'UNK'}
19 | self.n_words = 2 # Count PAD and UNK
20 |
21 | def addWord(self, word):
22 | if word not in self.word2index:
23 | self.word2index[word] = self.n_words
24 | self.word2count[word] = 1
25 | self.index2word[self.n_words] = word
26 | self.n_words += 1
27 | else:
28 | self.word2count[word] += 1
29 | return self.word2index[word]
30 |
31 | class Embedding(base.Layer):
32 | def __init__(self,embedding_shape,name,trainable=True):
33 | super(Embedding,self).__init__(trainable=trainable,name=name)
34 | self.embedding_shape = embedding_shape
35 |
36 | def build(self,_):
37 | assert len(self.embedding_shape) == 2, "The length of embedding_shape is not equal to 2."
38 | self.embedding = self.add_variable("embedding",
39 | self.embedding_shape,
40 | tf.float32
41 | )
42 |
43 | def call(self, ids):
44 | outputs = tf.nn.embedding_lookup(self.embedding,ids)
45 | return outputs
46 |
47 | class AttentionLayer(base.Layer):
48 | def __init__(self,
49 | num_units,
50 | score_value=float("-inf"),
51 | trainable=True,
52 | name="attentionLayer"):
53 | super(AttentionLayer,self).__init__(trainable=trainable,name=name)
54 | self.num_units = num_units
55 | self.score_value = score_value
56 |
57 | def build(self,input_shape):
58 | assert len(input_shape.as_list()) == 3, "The dimension of inputs is not equal to 3."
59 | self.W = self.add_variable("W",[input_shape[-1].value,self.num_units],tf.float32)
60 | self.v = self.add_variable("v",[self.num_units,1],tf.float32)
61 |
62 | def call(self,inputs,sequence_length=None):
63 | W = tf.tile(tf.expand_dims(self.W,0),[tf.shape(inputs)[0],1,1])
64 | v = tf.tile(tf.expand_dims(self.v,0),[tf.shape(inputs)[0],1,1])
65 | scores = tf.squeeze(tf.matmul(tf.nn.tanh(tf.matmul(inputs,W)),v),
66 | axis=2
67 | )
68 | if sequence_length is not None:
69 | mask = tf.sequence_mask(sequence_length,maxlen=tf.shape(inputs)[1])
70 | scores = tf.where(mask,scores,tf.ones_like(scores)*self.score_value)
71 | alignments = tf.nn.softmax(scores)
72 | r = tf.squeeze(tf.matmul(tf.expand_dims(alignments,axis=1),inputs),axis=1)
73 | return r
74 |
75 | class QuestionLSTM():
76 | def __init__(self,
77 | num_units,
78 | dropout_rate=0.0,
79 | training = True,
80 | name="questionLSTM"):
81 | self.dropout_rate = dropout_rate
82 | self.training = training
83 | self.name = name
84 | self.lstm_cell = tf.contrib.rnn.LSTMCell(num_units,use_peepholes=True)
85 | self.attention = AttentionLayer(num_units)
86 |
87 | def __call__(self,inputs,sequence_length):
88 | with tf.variable_scope(self.name):
89 | state = self.lstm_cell.zero_state(tf.shape(inputs)[0],tf.float32)
90 | outputs,_ = tf.nn.dynamic_rnn(self.lstm_cell,
91 | inputs,
92 | sequence_length,
93 | state,
94 | tf.float32)
95 | outputs = tf.layers.dropout(outputs,self.dropout_rate,training=self.training)
96 | q_r = self.attention(outputs,sequence_length)
97 | return q_r
98 |
99 | class EvidenceLSTMs():
100 | def __init__(self,
101 | num_units,
102 | dropout_rate = 0.0,
103 | training = True,
104 | name="evidenceLSTMs"):
105 | self.layer1 = tf.contrib.rnn.LSTMCell(num_units,use_peepholes=True,name="layer1")
106 | self.layer2 = tf.contrib.rnn.LSTMCell(num_units,use_peepholes=True,name="layer2")
107 | self.layer3 = tf.contrib.rnn.LSTMCell(num_units,use_peepholes=True,name="layer3")
108 | self.output_dense = tf.layers.Dense(len(tag_to_ix))
109 | self.dropout_rate = dropout_rate
110 | self.training = training
111 | self.name = name
112 |
113 | def __call__(self,inputs,sequence_length):
114 | with tf.variable_scope(self.name):
115 | state = self.layer1.zero_state(tf.shape(inputs)[0],tf.float32)
116 | layer1_outputs,_ = tf.nn.dynamic_rnn(self.layer1,
117 | inputs,
118 | sequence_length,
119 | state,
120 | tf.float32)
121 | layer1_outputs = tf.layers.dropout(layer1_outputs,self.dropout_rate,training=self.training)
122 | layer1_outputs_reversed = tf.reverse_sequence(layer1_outputs,sequence_length,seq_dim=1)
123 | layer2_outputs,_ = tf.nn.dynamic_rnn(self.layer2,
124 | layer1_outputs_reversed,
125 | sequence_length,
126 | state,
127 | tf.float32)
128 | layer2_outputs = tf.reverse_sequence(layer2_outputs,sequence_length,seq_dim=1)
129 | layer2_outputs = tf.layers.dropout(layer2_outputs,self.dropout_rate,training=self.training)
130 | layer3_inputs = tf.concat([layer1_outputs,layer2_outputs],axis=2)
131 | outputs,_ = tf.nn.dynamic_rnn(self.layer3,
132 | layer3_inputs,
133 | sequence_length,
134 | state,
135 | tf.float32)
136 | outputs = tf.layers.dropout(outputs,self.dropout_rate,training=self.training)
137 | outputs = self.output_dense(outputs)
138 | return outputs
139 |
140 | tag_to_ix = {"B":0,"I":1,"O1":2,"O2":3,"START":4,"STOP":5}
141 |
142 | def log_sum_exp(inputs):
143 | max_scores = tf.reduce_max(inputs,axis=1,keepdims=True)
144 | return tf.squeeze(max_scores,1) + tf.log(tf.reduce_sum(tf.exp(inputs - max_scores),axis=1))
145 |
146 | class CRF():
147 | def __init__(self,name="crf"):
148 | self.name = name
149 | def build(self):
150 | with tf.variable_scope(self.name):
151 | self.transitions = tf.get_variable("transitions",
152 | [len(tag_to_ix),len(tag_to_ix)],
153 | tf.float32)
154 | indices = []
155 | updates = []
156 | for i in range(len(tag_to_ix)):
157 | indices.append([i,tag_to_ix["START"]])
158 | indices.append([tag_to_ix["STOP"],i])
159 | updates += [-1e+8,-1e+8]
160 |
161 | self.transitions = tf.scatter_nd_update(self.transitions,indices,updates)
162 | self.init_alphas = tf.get_variable("init_alphas",
163 | [len(tag_to_ix)],
164 | tf.float32,
165 | tf.constant_initializer(-1e+8),
166 | trainable=False)
167 | self.init_alphas = tf.scatter_update(self.init_alphas,
168 | [tag_to_ix["START"]],
169 | [tf.constant(0.0)])
170 |
171 | def neg_log_likelihood(self,inputs,sequence_length,targets):
172 | sequence_length = tf.to_int32(sequence_length)
173 | targets = tf.to_int32(targets)
174 | i0 = tf.constant(0,tf.int32)
175 | alphas_0 = tf.tile(tf.expand_dims(self.init_alphas,0),[tf.shape(inputs)[0],1])
176 | max_seq_len = tf.reduce_max(sequence_length)
177 | alphas_array = tf.TensorArray(tf.float32,size=max_seq_len)
178 |
179 | scores_array = tf.TensorArray(tf.float32,size=max_seq_len)
180 | initial_scores = alphas_0[:,tag_to_ix["START"]]
181 | targets = tf.concat([tf.ones([tf.shape(inputs)[0],1],tf.int32)*tag_to_ix["START"],
182 | targets],
183 | axis=1)
184 | def body(i,alphas_t,ta1,scores,ta2):
185 | inp = inputs[:,i,:] #(batch_size,len(tag_to_ix))
186 | emit_scores = tf.tile(tf.expand_dims(inp,1),[1,len(tag_to_ix),1])
187 | forward_vars = tf.tile(tf.expand_dims(alphas_t,2),[1,1,len(tag_to_ix)])
188 | next_tag_vars = forward_vars + self.transitions + emit_scores
189 | alphas_t_plus_1 = log_sum_exp(next_tag_vars) #(batch_size,len(tag_to_ix))
190 |
191 | indices = tf.stack([tf.range(0,tf.shape(inputs)[0],dtype=tf.int32),
192 | targets[:,i+1]])
193 | indices = tf.transpose(indices,[1,0])
194 | scores += tf.gather_nd(self.transitions,targets[:,i:i+2]) +\
195 | tf.gather_nd(inp,indices)
196 | scores = tf.reshape(scores,[tf.shape(inputs)[0]])
197 | return i+1,alphas_t_plus_1,ta1.write(i,alphas_t_plus_1),scores,ta2.write(i,scores)
198 | _,_,ta1_final,_,ta2_final = tf.while_loop(lambda i,alphas_t,ta1,scores,ta2: i < max_seq_len,
199 | body,
200 | loop_vars=[i0,alphas_0,alphas_array,initial_scores,scores_array])
201 |
202 | ta1_final_result = ta1_final.stack()
203 | indices1 = tf.stack([sequence_length-1,
204 | tf.range(0,tf.shape(inputs)[0],dtype=tf.int32)])
205 | indices1 = tf.transpose(indices1,[1,0])
206 | terminal_vars = tf.gather_nd(ta1_final_result,indices1) + self.transitions[:,tag_to_ix["STOP"]]
207 | forward_scores = log_sum_exp(terminal_vars) #(batch_size)
208 |
209 | ta2_final_result = ta2_final.stack()
210 | indices2 = tf.stack([tf.gather_nd(tf.transpose(targets,[1,0]),indices1),
211 | tf.ones([tf.shape(inputs)[0]],tf.int32)*tag_to_ix["STOP"]])
212 | indices2 = tf.transpose(indices2,[1,0])
213 | gold_scores = tf.gather_nd(ta2_final_result,indices1) +\
214 | tf.gather_nd(self.transitions,indices2)
215 |
216 | return forward_scores - gold_scores # (batch_size,)
217 |
218 | def viterbi_decode(self,inputs,sequence_length):
219 | sequence_length = tf.to_int32(sequence_length)
220 | alphas_0 = tf.tile(tf.expand_dims(self.init_alphas,0),
221 | [tf.shape(inputs)[0],1]) #(batch_size,len(tag_to_ix))
222 | i0 = tf.constant(0,tf.int32)
223 | max_seq_len = tf.reduce_max(sequence_length)
224 | best_tag_ids_array = tf.TensorArray(tf.int32,size=max_seq_len)
225 | alphas_array = tf.TensorArray(tf.float32,size=max_seq_len)
226 |
227 | def body1(i,alphas_t,ta1,ta2):
228 | forward_vars = tf.tile(tf.expand_dims(alphas_t,2),[1,1,len(tag_to_ix)])
229 | next_tag_vars = forward_vars + self.transitions
230 | best_tag_ids = tf.argmax(next_tag_vars,axis=1,output_type=tf.int32) #(batch_size,len(tag_to_ix))
231 | alphas_t_plus_1 = tf.reduce_max(next_tag_vars,axis=1) + inputs[:,i,:] #(batch_size,len(tag_to_ix))
232 | return i+1,alphas_t_plus_1,ta1.write(i,best_tag_ids),ta2.write(i,alphas_t_plus_1)
233 | _,_,ta1_final,ta2_final = tf.while_loop(lambda i,alpha_t,ta1,ta2: i < max_seq_len,
234 | body1,
235 | loop_vars=[i0,alphas_0,best_tag_ids_array,alphas_array])
236 |
237 | bptrs = ta1_final.stack()
238 |
239 | ta2_final_result = ta2_final.stack()
240 | indices = tf.stack([sequence_length-1,
241 | tf.range(0,tf.shape(inputs)[0],dtype=tf.int32)])
242 | indices = tf.transpose(indices,[1,0])
243 | terminal_vars = tf.gather_nd(ta2_final_result,indices) + self.transitions[:,tag_to_ix["STOP"]]
244 | terminal_id = tf.argmax(terminal_vars,axis=1,output_type=tf.int32) #(batch_size)
245 |
246 | path_array = tf.TensorArray(tf.int32,size=tf.shape(inputs)[0])
247 | def body2(i,ta):
248 | def sub_body(j,past_id,best_path):
249 | best_tag_id = tf.gather_nd(bptrs,[[j,i]]) #(1,len(tag_to_ix))
250 | next_id = tf.gather_nd(best_tag_id,[[0,past_id]])
251 | best_path = tf.concat([next_id,best_path],axis=0)
252 | return j-1,next_id[0],best_path
253 | _,_,best_path = tf.while_loop(lambda j,past_id,best_path: j >= 0,
254 | sub_body,
255 | loop_vars=[sequence_length[i] - 1,terminal_id[i],terminal_id[i:i+1]],
256 | shape_invariants=[tf.TensorShape([]),tf.TensorShape([]),tf.TensorShape([None])]
257 | )
258 | best_path = tf.pad(best_path,[[0,max_seq_len-sequence_length[i]]],constant_values=-1)
259 | return i+1,ta.write(i,best_path)
260 | _,ta_final = tf.while_loop(lambda i,ta: i < tf.shape(inputs)[0],
261 | body2,
262 | loop_vars=[i0,path_array])
263 |
264 | best_paths = ta_final.stack() #shape=(batch_size,1+max_seq_len),padding_value=-1
265 | return best_paths[:,1:]
266 |
267 | @classmethod
268 | def compute_accuracy(cls,predicts,targets):
269 | '''
270 | predicts and targets are numpy arrays.
271 |
272 | predicts shape:(batch_size,time_step)
273 | targets shape: (batch_size,time_step)
274 | '''
275 | assert predicts.shape == targets.shape, "The dimensions of predicts and targets don't match."
276 | assert len(predicts.shape) == 2, "predicts must be 2-dimensional."
277 | count = 0
278 | truth_ans_pos = []
279 | hypothesis_ans_pos = []
280 | for i in range(targets.shape[0]):
281 | target_answer_pos = []
282 | predict_answer_pos = []
283 | for j in range(targets.shape[1]):
284 | if targets[i,j] == tag_to_ix["B"]:
285 | target_answer_pos.append(j)
286 | for k in range(j+1,targets.shape[1]):
287 | if targets[i,k] != tag_to_ix["I"]:
288 | target_answer_pos.append(k-1)
289 | break
290 | break
291 | for j in range(predicts.shape[1]):
292 | if predicts[i,j] == tag_to_ix["B"]:
293 | predict_answer_pos.append(j)
294 | for k in range(j+1,predicts.shape[1]):
295 | if predicts[i,k] != tag_to_ix["I"]:
296 | predict_answer_pos.append(k-1)
297 | break
298 | break
299 | if target_answer_pos == predict_answer_pos:
300 | count += 1
301 | hypothesis_ans_pos.append(predict_answer_pos)
302 | truth_ans_pos.append(target_answer_pos)
303 | acc = count / targets.shape[0]
304 | return acc,hypothesis_ans_pos,truth_ans_pos
--------------------------------------------------------------------------------
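To make the evaluation metric concrete: CRF.compute_accuracy counts an evidence as correct only when the predicted B/I answer span matches the gold span exactly. Below is a minimal usage example with hand-built numpy arrays; the data is hypothetical, and it assumes TensorFlow is installed since modules.py imports it at the top.

```python
import numpy as np
import modules

B, I, O1, O2 = (modules.tag_to_ix[t] for t in ("B", "I", "O1", "O2"))

targets = np.array([[O1, B,  I,  O2, O2],    # gold answer span: positions 1-2
                    [O1, O1, B,  O2, O2]])   # gold answer span: the single token at position 2
predicts = np.array([[O1, B,  I,  O2, O2],   # exact match
                     [O1, B,  I,  O2, O2]])  # wrong span
acc, hyp_pos, gold_pos = modules.CRF.compute_accuracy(predicts, targets)
print(acc)       # 0.5
print(gold_pos)  # [[1, 2], [2, 2]]
print(hyp_pos)   # [[1, 2], [1, 2]]
```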