├── README.md
├── predict.py
└── train.py


/README.md:
--------------------------------------------------------------------------------
# CCL_CMRC2017

Reference model for the First "iFLYTEK Cup" Chinese Machine Reading Comprehension Evaluation (CMRC 2017).

http://kexue.fm/archives/4564/
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
#! -*- coding:utf-8 -*-

import pickle
import re

import numpy as np
import tensorflow as tf

# id2word/word2id map between words and integer ids; embedding_array holds one
# vector per word (word id i sits at row i-1, since id 0 is reserved for padding).
id2word, word2id, embedding_array = pickle.load(open('model.config', 'rb'))
word_size = embedding_array.shape[1]

# Prepend a trainable padding/unknown vector at index 0; the loaded embeddings
# themselves stay fixed as constants.
padding_vec = tf.Variable(tf.random_uniform([1, word_size], -0.05, 0.05))
embeddings = tf.constant(embedding_array, dtype=tf.float32)
embeddings = tf.concat([padding_vec, embeddings], 0)

# The contexts to the left and right of the blank, as padded batches of word ids.
L_context = tf.placeholder(tf.int32, shape=[None, None])
L_context_length = tf.placeholder(tf.int32, shape=[None])
R_context = tf.placeholder(tf.int32, shape=[None, None])
R_context_length = tf.placeholder(tf.int32, shape=[None])

L_context_vec = tf.nn.embedding_lookup(embeddings, L_context)
R_context_vec = tf.nn.embedding_lookup(embeddings, R_context)


def add_brnn(inputs, rnn_size, seq_lens, name):
    """One bidirectional LSTM layer, with weights shared across all inputs."""
    rnn_cell_fw = tf.contrib.rnn.BasicLSTMCell(rnn_size)
    rnn_cell_bw = tf.contrib.rnn.BasicLSTMCell(rnn_size)
    outputs = []
    with tf.variable_scope(name_or_scope=name) as vs:
        for inp, seq_len in zip(inputs, seq_lens):
            outputs.append(tf.nn.bidirectional_dynamic_rnn(
                rnn_cell_fw, rnn_cell_bw, inp,
                sequence_length=seq_len, dtype=tf.float32))
            vs.reuse_variables()  # reuse the same weights for the next input
    return [tf.concat(o[0], 2) for o in outputs], [o[1] for o in outputs]


# Two stacked BiLSTM layers over both contexts.
[L_outputs, R_outputs], [L_final_state, R_final_state] = add_brnn(
    [L_context_vec, R_context_vec], word_size,
    [L_context_length, R_context_length], name='LSTM_1')
[L_outputs, R_outputs], [L_final_state, R_final_state] = add_brnn(
    [L_outputs, R_outputs], word_size,
    [L_context_length, R_context_length], name='LSTM_2')

# Additive masks: padded positions get -1e12 so the softmax ignores them.
L_context_mask = (1 - tf.cast(tf.sequence_mask(L_context_length), tf.float32)) * (-1e12)
R_context_mask = (1 - tf.cast(tf.sequence_mask(R_context_length), tf.float32)) * (-1e12)
context_mask = tf.concat([L_context_mask, R_context_mask], 1)

# Attention: score every context position against the averaged final hidden
# states; the masked softmax gives a distribution over positions.
outputs = tf.concat([L_outputs, R_outputs], 1)
final_state = (tf.concat([L_final_state[0][1], L_final_state[1][1]], 1)
               + tf.concat([R_final_state[0][1], R_final_state[1][1]], 1)) / 2
attention = context_mask + tf.matmul(outputs, tf.expand_dims(final_state, 2))[:, :, 0]
sample_labels = tf.placeholder(tf.float32, shape=[None, None])
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
    labels=sample_labels, logits=attention))
pred = tf.nn.softmax(attention)

train_step = tf.train.AdamOptimizer().minimize(loss)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

saver = tf.train.Saver()
saver.restore(sess, './tk/tk_highest.ckpt')


def split_data(text):
    """Split the word sequence into the parts left and right of the blank."""
    words = re.split('[ \n]+', text)
    idx = words.index('XXXXX')
    return words[:idx], words[idx + 1:]


def cumsum_proba(x, y):
    """Sum the position probabilities per word id; return the best word id."""
    tmp = {}
    for i, j in zip(x, y):
        if i in tmp:
            tmp[i] += j
        else:
            tmp[i] = j
    return tmp.keys()[np.argmax(tmp.values())]


def predict(text):
    # `text` is a single string of space-separated segmented words;
    # the position to fill in is marked by XXXXX.
    text = split_data(text)
    # OOV words fall back to the padding id 0 instead of raising a KeyError.
    text = ([word2id.get(i, 0) for i in text[0]] if text[0] else [0],
            [word2id.get(i, 0) for i in text[1]] if text[1] else [0])
    p = sess.run(pred, feed_dict={L_context: [text[0]],
                                  R_context: [text[1]],
                                  L_context_length: [len(text[0])],
                                  R_context_length: [len(text[1])]})
    return id2word.get(cumsum_proba(text[0] + text[1], p[0]), ' ')


if __name__ == '__main__':

    import codecs
    import os
    import sys

    valid_name = sys.argv[1]
    output_name = sys.argv[2]

    text = codecs.open(valid_name, encoding='utf-8').read()
    valid_x = re.split('', text)  # the split pattern was lost in this dump
    # ... (the code that builds `names` and `valid_result` is missing here) ...
    s = '\n'.join(names[i] + ' ||| ' + id2word.get(j, ' ')
                  for i, j in enumerate(valid_result))
    with codecs.open(output_name, 'w', encoding='utf-8') as f:
        f.write(s)
--------------------------------------------------------------------------------
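For reference, a minimal usage sketch of predict.py's entry point. The sentence below is a made-up, pre-segmented example; it assumes model.config and the ./tk/tk_highest.ckpt checkpoint are present:

    # -*- coding: utf-8 -*-
    # XXXXX marks the blank the model should fill.
    answer = predict(u'今天 天气 很 XXXXX ， 我们 去 公园 散步')
    print answer  # the single word predicted for the blank
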
/train.py:
--------------------------------------------------------------------------------
#! -*- coding:utf-8 -*-

import codecs
import re
import os
import numpy as np


def split_data(text):
    """Split the word sequence into the parts left and right of the blank."""
    words = re.split('[ \n]+', text)
    idx = words.index('XXXXX')
    return words[:idx], words[idx + 1:]


print u'Reading training corpus...'
train_x = codecs.open('../CMRC2017_train/train.doc_query', encoding='utf-8').read()
train_x = re.split('
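train.py is cut off above, before it shows how the corpus is parsed or how the model.config pickle that predict.py loads is produced. As a rough sketch of the structure that pickle must have (the stand-in vocabulary, embedding dimension, and random initialization below are assumptions, not the repository's actual code):

    import pickle
    import numpy as np

    # Stand-in vocabulary; the real one would be collected from the training
    # corpus, and the vectors could come from pretrained word embeddings.
    words = [u'今天', u'天气', u'公园']
    word_size = 128  # assumed embedding dimension

    # Word ids start at 1 because predict.py prepends a padding/unknown
    # vector at index 0 of the embedding matrix.
    id2word = dict((i + 1, w) for i, w in enumerate(words))
    word2id = dict((w, i) for i, w in id2word.items())
    embedding_array = np.random.uniform(-0.05, 0.05, (len(words), word_size))

    with open('model.config', 'wb') as f:
        pickle.dump((id2word, word2id, embedding_array), f)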