├── README.md
├── sentiment.py
├── subject_extract.py
├── relation_extract.py
└── nl2sql_baseline.py
/README.md:
--------------------------------------------------------------------------------
1 | # bert_in_keras
2 | Use Keras to call BERT. This is probably the simplest way to get started with BERT.
3 |
4 | ## Examples
5 | - sentiment.py: sentiment analysis example; details in the write-ups linked below.
6 | - relation_extract.py: relation extraction example; details in the write-ups linked below.
7 | - subject_extract.py: subject extraction example; details in the write-ups linked below.
8 | - nl2sql_baseline.py: NL2SQL example; details in the write-ups linked below.
9 |
10 | ## Detailed write-ups
11 | - https://kexue.fm/archives/6736
12 | - https://kexue.fm/archives/6771
13 |
14 | ## Tested environment
15 | python 2.7 + tensorflow 1.13 + keras 2.2.4
16 |
17 | ## keras_bert
18 | - https://github.com/CyberZHG/keras-bert
19 |
20 | ## Chinese pretrained weights
21 | - Official: https://github.com/google-research/bert
22 | - Harbin Institute of Technology (HIT) version: https://github.com/ymcui/Chinese-BERT-wwm
23 |
24 | ## A serious note
25 | - This code is not for anyone who knows nothing about NLP or Keras!! If you are going to play with BERT, you should have spent at least half a year learning NLP and at least a week learning Keras. Don't expect to get there in one jump; people who only want to call a package and get it running are not welcome, and "I'm short on time" is no excuse.
26 | - Keras being simple does not mean you can skip learning NLP, skip learning Keras, or stop thinking. In a word, please respect your own intelligence.
27 |
28 | ## Online discussion
29 | QQ group: 67729435; for the WeChat group, add the bot WeChat ID spaces_ac_cn.
30 |
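31 | ## Quick start
32 | All four scripts follow the same pattern: load a pretrained Chinese BERT checkpoint with keras_bert, build a small task-specific Keras model on top of it, and fine-tune the whole thing end to end. Below is a minimal sketch of that shared pattern; the local path is an assumption, so point it at whichever checkpoint you downloaded.
33 |
34 | ```python
35 | import codecs
36 | from keras_bert import load_trained_model_from_checkpoint, Tokenizer
37 |
38 | base = 'chinese_L-12_H-768_A-12'  # hypothetical path to the unpacked pretrained weights
39 | bert = load_trained_model_from_checkpoint(
40 |     base + '/bert_config.json', base + '/bert_model.ckpt', seq_len=None
41 | )
42 |
43 | # build the vocab -> id mapping exactly as the scripts do
44 | token_dict = {}
45 | with codecs.open(base + '/vocab.txt', 'r', 'utf8') as reader:
46 |     for line in reader:
47 |         token_dict[line.strip()] = len(token_dict)
48 |
49 | tokenizer = Tokenizer(token_dict)
50 | x1, x2 = tokenizer.encode(u'这是一个测试')  # token ids and segment ids for BERT
51 | ```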
--------------------------------------------------------------------------------
/sentiment.py:
--------------------------------------------------------------------------------
1 | #! -*- coding:utf-8 -*-
2 |
3 | import json
4 | import numpy as np
5 | import pandas as pd
6 | from random import choice
7 | from keras_bert import load_trained_model_from_checkpoint, Tokenizer
8 | import re, os
9 | import codecs
10 |
11 |
12 | maxlen = 100
13 | config_path = '../bert/chinese_L-12_H-768_A-12/bert_config.json'
14 | checkpoint_path = '../bert/chinese_L-12_H-768_A-12/bert_model.ckpt'
15 | dict_path = '../bert/chinese_L-12_H-768_A-12/vocab.txt'
16 |
17 |
18 | token_dict = {}
19 |
20 | with codecs.open(dict_path, 'r', 'utf8') as reader:
21 | for line in reader:
22 | token = line.strip()
23 | token_dict[token] = len(token_dict)
24 |
25 |
26 | class OurTokenizer(Tokenizer):
27 | def _tokenize(self, text):
28 | R = []
29 | for c in text:
30 | if c in self._token_dict:
31 | R.append(c)
32 | elif self._is_space(c):
33 |             R.append('[unused1]') # spaces are mapped to the untrained [unused1] token
34 |         else:
35 |             R.append('[UNK]') # all remaining characters are mapped to [UNK]
36 | return R
37 |
38 | tokenizer = OurTokenizer(token_dict)
39 |
40 |
41 | neg = pd.read_excel('neg.xls', header=None)
42 | pos = pd.read_excel('pos.xls', header=None)
43 |
44 | data = []
45 |
46 | for d in neg[0]:
47 | data.append((d, 0))
48 |
49 | for d in pos[0]:
50 | data.append((d, 1))
51 |
52 |
53 | # split into training and validation sets at a 9:1 ratio
54 | random_order = range(len(data))
55 | np.random.shuffle(random_order)
56 | train_data = [data[j] for i, j in enumerate(random_order) if i % 10 != 0]
57 | valid_data = [data[j] for i, j in enumerate(random_order) if i % 10 == 0]
58 |
59 |
60 | def seq_padding(X, padding=0):
61 | L = [len(x) for x in X]
62 | ML = max(L)
63 | return np.array([
64 | np.concatenate([x, [padding] * (ML - len(x))]) if len(x) < ML else x for x in X
65 | ])
66 |
67 |
68 | class data_generator:
69 | def __init__(self, data, batch_size=32):
70 | self.data = data
71 | self.batch_size = batch_size
72 | self.steps = len(self.data) // self.batch_size
73 | if len(self.data) % self.batch_size != 0:
74 | self.steps += 1
75 | def __len__(self):
76 | return self.steps
77 | def __iter__(self):
78 | while True:
79 | idxs = range(len(self.data))
80 | np.random.shuffle(idxs)
81 | X1, X2, Y = [], [], []
82 | for i in idxs:
83 | d = self.data[i]
84 | text = d[0][:maxlen]
85 | x1, x2 = tokenizer.encode(first=text)
86 | y = d[1]
87 | X1.append(x1)
88 | X2.append(x2)
89 | Y.append([y])
90 | if len(X1) == self.batch_size or i == idxs[-1]:
91 | X1 = seq_padding(X1)
92 | X2 = seq_padding(X2)
93 | Y = seq_padding(Y)
94 | yield [X1, X2], Y
95 |                     X1, X2, Y = [], [], []
96 |
97 |
98 | from keras.layers import *
99 | from keras.models import Model
100 | import keras.backend as K
101 | from keras.optimizers import Adam
102 |
103 |
104 | bert_model = load_trained_model_from_checkpoint(config_path, checkpoint_path, seq_len=None)
105 |
106 | for l in bert_model.layers:
107 | l.trainable = True
108 |
109 | x1_in = Input(shape=(None,))
110 | x2_in = Input(shape=(None,))
111 |
112 | x = bert_model([x1_in, x2_in])
113 | x = Lambda(lambda x: x[:, 0])(x)
114 | p = Dense(1, activation='sigmoid')(x)
115 |
116 | model = Model([x1_in, x2_in], p)
117 | model.compile(
118 | loss='binary_crossentropy',
119 |     optimizer=Adam(1e-5), # use a sufficiently small learning rate
120 | metrics=['accuracy']
121 | )
122 | model.summary()
123 |
124 |
125 | train_D = data_generator(train_data)
126 | valid_D = data_generator(valid_data)
127 |
128 | model.fit_generator(
129 | train_D.__iter__(),
130 | steps_per_epoch=len(train_D),
131 | epochs=5,
132 | validation_data=valid_D.__iter__(),
133 | validation_steps=len(valid_D)
134 | )
135 |
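136 | # A minimal inference sketch: after fine-tuning, the same tokenizer and model
137 | # can score a new piece of text. The helper name and the example sentence are
138 | # illustrative only.
139 | def predict_sentiment(text):
140 |     """Return the predicted probability that `text` is positive."""
141 |     x1, x2 = tokenizer.encode(first=text[:maxlen])
142 |     return model.predict([np.array([x1]), np.array([x2])])[0, 0]
143 |
144 | # print predict_sentiment(u'东西很好,物流也快')  # e.g. "great product, fast shipping"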
--------------------------------------------------------------------------------
/subject_extract.py:
--------------------------------------------------------------------------------
1 | #! -*- coding: utf-8 -*-
2 |
3 | import json
4 | from tqdm import tqdm
5 | import os, re
6 | import numpy as np
7 | import pandas as pd
8 | from keras_bert import load_trained_model_from_checkpoint, Tokenizer
9 | import codecs
10 |
11 |
12 | mode = 0
13 | maxlen = 128
14 | learning_rate = 5e-5
15 | min_learning_rate = 1e-5
16 |
17 |
18 | config_path = '../../kg/bert/chinese_L-12_H-768_A-12/bert_config.json'
19 | checkpoint_path = '../../kg/bert/chinese_L-12_H-768_A-12/bert_model.ckpt'
20 | dict_path = '../../kg/bert/chinese_L-12_H-768_A-12/vocab.txt'
21 |
22 |
23 | token_dict = {}
24 |
25 | with codecs.open(dict_path, 'r', 'utf8') as reader:
26 | for line in reader:
27 | token = line.strip()
28 | token_dict[token] = len(token_dict)
29 |
30 |
31 | class OurTokenizer(Tokenizer):
32 | def _tokenize(self, text):
33 | R = []
34 | for c in text:
35 | if c in self._token_dict:
36 | R.append(c)
37 | elif self._is_space(c):
38 |             R.append('[unused1]') # spaces are mapped to the untrained [unused1] token
39 |         else:
40 |             R.append('[UNK]') # all remaining characters are mapped to [UNK]
41 | return R
42 |
43 | tokenizer = OurTokenizer(token_dict)
44 |
45 |
46 | D = pd.read_csv('../ccks2019_event_entity_extract/event_type_entity_extract_train.csv', encoding='utf-8', header=None)
47 | D = D[D[2] != u'其他']
48 | classes = set(D[2].unique())
49 |
50 |
51 | train_data = []
52 | for t,c,n in zip(D[1], D[2], D[3]):
53 | train_data.append((t, c, n))
54 |
55 |
56 | if not os.path.exists('../random_order_train.json'):
57 | random_order = range(len(train_data))
58 | np.random.shuffle(random_order)
59 | json.dump(
60 | random_order,
61 | open('../random_order_train.json', 'w'),
62 | indent=4
63 | )
64 | else:
65 | random_order = json.load(open('../random_order_train.json'))
66 |
67 |
68 | dev_data = [train_data[j] for i, j in enumerate(random_order) if i % 9 == mode]
69 | train_data = [train_data[j] for i, j in enumerate(random_order) if i % 9 != mode]
70 | additional_chars = set()
71 | for d in train_data + dev_data:
72 | additional_chars.update(re.findall(u'[^\u4e00-\u9fa5a-zA-Z0-9\*]', d[2]))
73 |
74 | additional_chars.remove(u',')
75 |
76 |
77 | D = pd.read_csv('../ccks2019_event_entity_extract/event_type_entity_extract_eval.csv', encoding='utf-8', header=None)
78 | test_data = []
79 | for id,t,c in zip(D[0], D[1], D[2]):
80 | test_data.append((id, t, c))
81 |
82 |
83 | def seq_padding(X, padding=0):
84 | L = [len(x) for x in X]
85 | ML = max(L)
86 | return np.array([
87 | np.concatenate([x, [padding] * (ML - len(x))]) if len(x) < ML else x for x in X
88 | ])
89 |
90 |
91 | def list_find(list1, list2):
92 |     """Search list1 for the sublist list2; return the index of the first match,
93 |     or -1 if it is not found.
94 |     """
95 | n_list2 = len(list2)
96 | for i in range(len(list1)):
97 | if list1[i: i+n_list2] == list2:
98 | return i
99 | return -1
100 |
101 |
102 | class data_generator:
103 | def __init__(self, data, batch_size=32):
104 | self.data = data
105 | self.batch_size = batch_size
106 | self.steps = len(self.data) // self.batch_size
107 | if len(self.data) % self.batch_size != 0:
108 | self.steps += 1
109 | def __len__(self):
110 | return self.steps
111 | def __iter__(self):
112 | while True:
113 | idxs = range(len(self.data))
114 | np.random.shuffle(idxs)
115 | X1, X2, S1, S2 = [], [], [], []
116 | for i in idxs:
117 | d = self.data[i]
118 | text, c = d[0][:maxlen], d[1]
119 | text = u'___%s___%s' % (c, text)
120 | tokens = tokenizer.tokenize(text)
121 | e = d[2]
122 | e_tokens = tokenizer.tokenize(e)[1:-1]
123 | s1, s2 = np.zeros(len(tokens)), np.zeros(len(tokens))
124 | start = list_find(tokens, e_tokens)
125 | if start != -1:
126 | end = start + len(e_tokens) - 1
127 | s1[start] = 1
128 | s2[end] = 1
129 | x1, x2 = tokenizer.encode(first=text)
130 | X1.append(x1)
131 | X2.append(x2)
132 | S1.append(s1)
133 | S2.append(s2)
134 | if len(X1) == self.batch_size or i == idxs[-1]:
135 | X1 = seq_padding(X1)
136 | X2 = seq_padding(X2)
137 | S1 = seq_padding(S1)
138 | S2 = seq_padding(S2)
139 | yield [X1, X2, S1, S2], None
140 | X1, X2, S1, S2 = [], [], [], []
141 |
142 |
143 | from keras.layers import *
144 | from keras.models import Model
145 | import keras.backend as K
146 | from keras.callbacks import Callback
147 | from keras.optimizers import Adam
148 |
149 |
150 | bert_model = load_trained_model_from_checkpoint(config_path, checkpoint_path, seq_len=None)
151 |
152 | for l in bert_model.layers:
153 | l.trainable = True
154 |
155 |
156 | x1_in = Input(shape=(None,)) # token ids of the sentence to be tagged
157 | x2_in = Input(shape=(None,)) # segment ids of the sentence to be tagged
158 | s1_in = Input(shape=(None,)) # left boundary of the entity (label)
159 | s2_in = Input(shape=(None,)) # right boundary of the entity (label)
160 |
161 | x1, x2, s1, s2 = x1_in, x2_in, s1_in, s2_in
162 | x_mask = Lambda(lambda x: K.cast(K.greater(K.expand_dims(x, 2), 0), 'float32'))(x1)
163 |
164 | x = bert_model([x1, x2])
165 | ps1 = Dense(1, use_bias=False)(x)
166 | ps1 = Lambda(lambda x: x[0][..., 0] - (1 - x[1][..., 0]) * 1e10)([ps1, x_mask])
167 | ps2 = Dense(1, use_bias=False)(x)
168 | ps2 = Lambda(lambda x: x[0][..., 0] - (1 - x[1][..., 0]) * 1e10)([ps2, x_mask])
169 |
170 | model = Model([x1_in, x2_in], [ps1, ps2])
171 |
172 |
173 | train_model = Model([x1_in, x2_in, s1_in, s2_in], [ps1, ps2])
174 |
175 | loss1 = K.mean(K.categorical_crossentropy(s1_in, ps1, from_logits=True))
176 | ps2 -= (1 - K.cumsum(s1, 1)) * 1e10
177 | loss2 = K.mean(K.categorical_crossentropy(s2_in, ps2, from_logits=True))
178 | loss = loss1 + loss2
179 |
180 | train_model.add_loss(loss)
181 | train_model.compile(optimizer=Adam(learning_rate))
182 | train_model.summary()
183 |
184 |
185 | def softmax(x):
186 | x = x - np.max(x)
187 | x = np.exp(x)
188 | return x / np.sum(x)
189 |
190 |
191 | def extract_entity(text_in, c_in):
192 | if c_in not in classes:
193 | return 'NaN'
194 | text_in = u'___%s___%s' % (c_in, text_in)
195 | text_in = text_in[:510]
196 | _tokens = tokenizer.tokenize(text_in)
197 | _x1, _x2 = tokenizer.encode(first=text_in)
198 | _x1, _x2 = np.array([_x1]), np.array([_x2])
199 | _ps1, _ps2 = model.predict([_x1, _x2])
200 | _ps1, _ps2 = softmax(_ps1[0]), softmax(_ps2[0])
201 | for i, _t in enumerate(_tokens):
202 | if len(_t) == 1 and re.findall(u'[^\u4e00-\u9fa5a-zA-Z0-9\*]', _t) and _t not in additional_chars:
203 | _ps1[i] -= 10
204 | start = _ps1.argmax()
205 | for end in range(start, len(_tokens)):
206 | _t = _tokens[end]
207 | if len(_t) == 1 and re.findall(u'[^\u4e00-\u9fa5a-zA-Z0-9\*]', _t) and _t not in additional_chars:
208 | break
209 | end = _ps2[start:end+1].argmax() + start
210 | a = text_in[start-1: end]
211 | return a
212 |
213 |
214 | class Evaluate(Callback):
215 | def __init__(self):
216 | self.ACC = []
217 | self.best = 0.
218 | self.passed = 0
219 | def on_batch_begin(self, batch, logs=None):
220 |         """Use the first epoch for warmup and anneal the learning rate to its minimum over the second epoch.
221 |         """
222 | if self.passed < self.params['steps']:
223 | lr = (self.passed + 1.) / self.params['steps'] * learning_rate
224 | K.set_value(self.model.optimizer.lr, lr)
225 | self.passed += 1
226 | elif self.params['steps'] <= self.passed < self.params['steps'] * 2:
227 | lr = (2 - (self.passed + 1.) / self.params['steps']) * (learning_rate - min_learning_rate)
228 | lr += min_learning_rate
229 | K.set_value(self.model.optimizer.lr, lr)
230 | self.passed += 1
231 | def on_epoch_end(self, epoch, logs=None):
232 | acc = self.evaluate()
233 | self.ACC.append(acc)
234 | if acc > self.best:
235 | self.best = acc
236 | train_model.save_weights('best_model.weights')
237 | print 'acc: %.4f, best acc: %.4f\n' % (acc, self.best)
238 | def evaluate(self):
239 | A = 1e-10
240 | F = open('dev_pred.json', 'w')
241 | for d in tqdm(iter(dev_data)):
242 | R = extract_entity(d[0], d[1])
243 | if R == d[2]:
244 | A += 1
245 | s = ', '.join(d + (R,))
246 | F.write(s.encode('utf-8') + '\n')
247 | F.close()
248 | return A / len(dev_data)
249 |
250 |
251 | def test(test_data):
252 | F = open('result.txt', 'w')
253 | for d in tqdm(iter(test_data)):
254 | s = u'"%s","%s"\n' % (d[0], extract_entity(d[1], d[2]))
255 | s = s.encode('utf-8')
256 | F.write(s)
257 | F.close()
258 |
259 |
260 | evaluator = Evaluate()
261 | train_D = data_generator(train_data)
262 |
263 |
264 | if __name__ == '__main__':
265 | train_model.fit_generator(train_D.__iter__(),
266 | steps_per_epoch=len(train_D),
267 | epochs=10,
268 | callbacks=[evaluator]
269 | )
270 | else:
271 | train_model.load_weights('best_model.weights')
272 |
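273 | # A minimal usage sketch: when this file is imported as a module, the `else`
274 | # branch above loads the best weights, so the submission file can be produced
275 | # directly with the `test` helper defined earlier, e.g.
276 | #
277 | #     test(test_data)  # writes one '"<id>","<entity>"' line per sample to result.txt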
--------------------------------------------------------------------------------
/relation_extract.py:
--------------------------------------------------------------------------------
1 | #! -*- coding:utf-8 -*-
2 |
3 | import json
4 | import numpy as np
5 | from random import choice
6 | from tqdm import tqdm
7 | from keras_bert import load_trained_model_from_checkpoint, Tokenizer
8 | import re, os
9 | import codecs
10 |
11 |
12 | mode = 0
13 | maxlen = 160
14 | learning_rate = 5e-5
15 | min_learning_rate = 1e-5
16 |
17 | config_path = '../bert/chinese_L-12_H-768_A-12/bert_config.json'
18 | checkpoint_path = '../bert/chinese_L-12_H-768_A-12/bert_model.ckpt'
19 | dict_path = '../bert/chinese_L-12_H-768_A-12/vocab.txt'
20 |
21 |
22 | token_dict = {}
23 |
24 | with codecs.open(dict_path, 'r', 'utf8') as reader:
25 | for line in reader:
26 | token = line.strip()
27 | token_dict[token] = len(token_dict)
28 |
29 |
30 | class OurTokenizer(Tokenizer):
31 | def _tokenize(self, text):
32 | R = []
33 | for c in text:
34 | if c in self._token_dict:
35 | R.append(c)
36 | elif self._is_space(c):
37 |             R.append('[unused1]') # spaces are mapped to the untrained [unused1] token
38 |         else:
39 |             R.append('[UNK]') # all remaining characters are mapped to [UNK]
40 | return R
41 |
42 | tokenizer = OurTokenizer(token_dict)
43 |
44 |
45 | train_data = json.load(open('../datasets/train_data_me.json'))
46 | dev_data = json.load(open('../datasets/dev_data_me.json'))
47 | id2predicate, predicate2id = json.load(open('../datasets/all_50_schemas_me.json'))
48 | id2predicate = {int(i):j for i,j in id2predicate.items()}
49 | num_classes = len(id2predicate)
50 |
51 |
52 | total_data = []
53 | total_data.extend(train_data)
54 | total_data.extend(dev_data)
55 |
56 |
57 | if not os.path.exists('../random_order_train_dev.json'):
58 | random_order = range(len(total_data))
59 | np.random.shuffle(random_order)
60 | json.dump(
61 | random_order,
62 | open('../random_order_train_dev.json', 'w'),
63 | indent=4
64 | )
65 | else:
66 | random_order = json.load(open('../random_order_train_dev.json'))
67 |
68 |
69 | train_data = [total_data[j] for i, j in enumerate(random_order) if i % 8 != mode]
70 | dev_data = [total_data[j] for i, j in enumerate(random_order) if i % 8 == mode]
71 |
72 |
73 | predicates = {} # format: {predicate: [(subject, predicate, object)]}
74 |
75 |
76 | def repair(d):
77 | d['text'] = d['text'].lower()
78 | something = re.findall(u'《([^《》]*?)》', d['text'])
79 | something = [s.strip() for s in something]
80 | zhuanji = []
81 | gequ = []
82 | for sp in d['spo_list']:
83 | sp[0] = sp[0].strip(u'《》').strip().lower()
84 | sp[2] = sp[2].strip(u'《》').strip().lower()
85 | for some in something:
86 | if sp[0] in some and d['text'].count(sp[0]) == 1:
87 | sp[0] = some
88 | if sp[1] == u'所属专辑':
89 | zhuanji.append(sp[2])
90 | gequ.append(sp[0])
91 | spo_list = []
92 | for sp in d['spo_list']:
93 | if sp[1] in [u'歌手', u'作词', u'作曲']:
94 | if sp[0] in zhuanji and sp[0] not in gequ:
95 | continue
96 | spo_list.append(tuple(sp))
97 | d['spo_list'] = spo_list
98 |
99 |
100 | for d in train_data:
101 | repair(d)
102 | for sp in d['spo_list']:
103 | if sp[1] not in predicates:
104 | predicates[sp[1]] = []
105 | predicates[sp[1]].append(sp)
106 |
107 |
108 | for d in dev_data:
109 | repair(d)
110 |
111 |
112 | def seq_padding(X, padding=0):
113 | L = [len(x) for x in X]
114 | ML = max(L)
115 | return np.array([
116 | np.concatenate([x, [padding] * (ML - len(x))]) if len(x) < ML else x for x in X
117 | ])
118 |
119 |
120 | def list_find(list1, list2):
121 |     """Search list1 for the sublist list2; return the index of the first match,
122 |     or -1 if it is not found.
123 |     """
124 | n_list2 = len(list2)
125 | for i in range(len(list1)):
126 | if list1[i: i+n_list2] == list2:
127 | return i
128 | return -1
129 |
130 |
131 | class data_generator:
132 | def __init__(self, data, batch_size=32):
133 | self.data = data
134 | self.batch_size = batch_size
135 | self.steps = len(self.data) // self.batch_size
136 | if len(self.data) % self.batch_size != 0:
137 | self.steps += 1
138 | def __len__(self):
139 | return self.steps
140 | def __iter__(self):
141 | while True:
142 | idxs = range(len(self.data))
143 | np.random.shuffle(idxs)
144 | T1, T2, S1, S2, K1, K2, O1, O2 = [], [], [], [], [], [], [], []
145 | for i in idxs:
146 | d = self.data[i]
147 | text = d['text'][:maxlen]
148 | tokens = tokenizer.tokenize(text)
149 | items = {}
150 | for sp in d['spo_list']:
151 | sp = (tokenizer.tokenize(sp[0])[1:-1], sp[1], tokenizer.tokenize(sp[2])[1:-1])
152 | subjectid = list_find(tokens, sp[0])
153 | objectid = list_find(tokens, sp[2])
154 | if subjectid != -1 and objectid != -1:
155 | key = (subjectid, subjectid+len(sp[0]))
156 | if key not in items:
157 | items[key] = []
158 | items[key].append((objectid,
159 | objectid+len(sp[2]),
160 | predicate2id[sp[1]]))
161 | if items:
162 | t1, t2 = tokenizer.encode(first=text)
163 | T1.append(t1)
164 | T2.append(t2)
165 | s1, s2 = np.zeros(len(tokens)), np.zeros(len(tokens))
166 | for j in items:
167 | s1[j[0]] = 1
168 | s2[j[1]-1] = 1
169 | k1, k2 = np.array(items.keys()).T
170 | k1 = choice(k1)
171 | k2 = choice(k2[k2 >= k1])
172 | o1, o2 = np.zeros((len(tokens), num_classes)), np.zeros((len(tokens), num_classes))
173 | for j in items.get((k1, k2), []):
174 | o1[j[0]][j[2]] = 1
175 | o2[j[1]-1][j[2]] = 1
176 | S1.append(s1)
177 | S2.append(s2)
178 | K1.append([k1])
179 | K2.append([k2-1])
180 | O1.append(o1)
181 | O2.append(o2)
182 | if len(T1) == self.batch_size or i == idxs[-1]:
183 | T1 = seq_padding(T1)
184 | T2 = seq_padding(T2)
185 | S1 = seq_padding(S1)
186 | S2 = seq_padding(S2)
187 | O1 = seq_padding(O1, np.zeros(num_classes))
188 | O2 = seq_padding(O2, np.zeros(num_classes))
189 | K1, K2 = np.array(K1), np.array(K2)
190 | yield [T1, T2, S1, S2, K1, K2, O1, O2], None
191 |                     T1, T2, S1, S2, K1, K2, O1, O2 = [], [], [], [], [], [], [], []
192 |
193 |
194 | from keras.layers import *
195 | from keras.models import Model
196 | import keras.backend as K
197 | from keras.callbacks import Callback
198 | from keras.optimizers import Adam
199 |
200 |
201 | def seq_gather(x):
202 |     """seq has shape [None, seq_len, s_size] and idxs has shape [None, 1];
203 |     from the i-th sequence in seq, pick out the idxs[i]-th vector,
204 |     giving an output of shape [None, s_size].
205 |     """
206 | seq, idxs = x
207 | idxs = K.cast(idxs, 'int32')
208 | batch_idxs = K.arange(0, K.shape(seq)[0])
209 | batch_idxs = K.expand_dims(batch_idxs, 1)
210 | idxs = K.concatenate([batch_idxs, idxs], 1)
211 | return K.tf.gather_nd(seq, idxs)
212 |
213 |
214 | bert_model = load_trained_model_from_checkpoint(config_path, checkpoint_path, seq_len=None)
215 |
216 | for l in bert_model.layers:
217 | l.trainable = True
218 |
219 |
220 | t1_in = Input(shape=(None,))
221 | t2_in = Input(shape=(None,))
222 | s1_in = Input(shape=(None,))
223 | s2_in = Input(shape=(None,))
224 | k1_in = Input(shape=(1,))
225 | k2_in = Input(shape=(1,))
226 | o1_in = Input(shape=(None, num_classes))
227 | o2_in = Input(shape=(None, num_classes))
228 |
229 | t1, t2, s1, s2, k1, k2, o1, o2 = t1_in, t2_in, s1_in, s2_in, k1_in, k2_in, o1_in, o2_in
230 | mask = Lambda(lambda x: K.cast(K.greater(K.expand_dims(x, 2), 0), 'float32'))(t1)
231 |
232 | t = bert_model([t1, t2])
233 | ps1 = Dense(1, activation='sigmoid')(t)
234 | ps2 = Dense(1, activation='sigmoid')(t)
235 |
236 | subject_model = Model([t1_in, t2_in], [ps1, ps2]) # model that predicts the subject
237 |
238 |
239 | k1v = Lambda(seq_gather)([t, k1])
240 | k2v = Lambda(seq_gather)([t, k2])
241 | kv = Average()([k1v, k2v])
242 | t = Add()([t, kv])
243 | po1 = Dense(num_classes, activation='sigmoid')(t)
244 | po2 = Dense(num_classes, activation='sigmoid')(t)
245 |
246 | object_model = Model([t1_in, t2_in, k1_in, k2_in], [po1, po2]) # given the text and a subject, predict the object and its relation
247 |
248 |
249 | train_model = Model([t1_in, t2_in, s1_in, s2_in, k1_in, k2_in, o1_in, o2_in],
250 | [ps1, ps2, po1, po2])
251 |
252 | s1 = K.expand_dims(s1, 2)
253 | s2 = K.expand_dims(s2, 2)
254 |
255 | s1_loss = K.binary_crossentropy(s1, ps1)
256 | s1_loss = K.sum(s1_loss * mask) / K.sum(mask)
257 | s2_loss = K.binary_crossentropy(s2, ps2)
258 | s2_loss = K.sum(s2_loss * mask) / K.sum(mask)
259 |
260 | o1_loss = K.sum(K.binary_crossentropy(o1, po1), 2, keepdims=True)
261 | o1_loss = K.sum(o1_loss * mask) / K.sum(mask)
262 | o2_loss = K.sum(K.binary_crossentropy(o2, po2), 2, keepdims=True)
263 | o2_loss = K.sum(o2_loss * mask) / K.sum(mask)
264 |
265 | loss = (s1_loss + s2_loss) + (o1_loss + o2_loss)
266 |
267 | train_model.add_loss(loss)
268 | train_model.compile(optimizer=Adam(learning_rate))
269 | train_model.summary()
270 |
271 |
272 | def extract_items(text_in):
273 | _tokens = tokenizer.tokenize(text_in)
274 | _t1, _t2 = tokenizer.encode(first=text_in)
275 | _t1, _t2 = np.array([_t1]), np.array([_t2])
276 | _k1, _k2 = subject_model.predict([_t1, _t2])
277 | _k1, _k2 = np.where(_k1[0] > 0.5)[0], np.where(_k2[0] > 0.4)[0]
278 | _subjects = []
279 | for i in _k1:
280 | j = _k2[_k2 >= i]
281 | if len(j) > 0:
282 | j = j[0]
283 | _subject = text_in[i-1: j]
284 | _subjects.append((_subject, i, j))
285 | if _subjects:
286 | R = []
287 | _t1 = np.repeat(_t1, len(_subjects), 0)
288 | _t2 = np.repeat(_t2, len(_subjects), 0)
289 | _k1, _k2 = np.array([_s[1:] for _s in _subjects]).T.reshape((2, -1, 1))
290 | _o1, _o2 = object_model.predict([_t1, _t2, _k1, _k2])
291 | for i,_subject in enumerate(_subjects):
292 | _oo1, _oo2 = np.where(_o1[i] > 0.5), np.where(_o2[i] > 0.4)
293 | for _ooo1, _c1 in zip(*_oo1):
294 | for _ooo2, _c2 in zip(*_oo2):
295 | if _ooo1 <= _ooo2 and _c1 == _c2:
296 | _object = text_in[_ooo1-1: _ooo2]
297 | _predicate = id2predicate[_c1]
298 | R.append((_subject[0], _predicate, _object))
299 | break
300 | zhuanji, gequ = [], []
301 | for s, p, o in R[:]:
302 | if p == u'妻子':
303 | R.append((o, u'丈夫', s))
304 | elif p == u'丈夫':
305 | R.append((o, u'妻子', s))
306 | if p == u'所属专辑':
307 | zhuanji.append(o)
308 | gequ.append(s)
309 | spo_list = set()
310 | for s, p, o in R:
311 | if p in [u'歌手', u'作词', u'作曲']:
312 | if s in zhuanji and s not in gequ:
313 | continue
314 | spo_list.add((s, p, o))
315 | return list(spo_list)
316 | else:
317 | return []
318 |
319 |
320 | class Evaluate(Callback):
321 | def __init__(self):
322 | self.F1 = []
323 | self.best = 0.
324 | self.passed = 0
325 | self.stage = 0
326 | def on_batch_begin(self, batch, logs=None):
327 |         """Use the first epoch for warmup and anneal the learning rate to its minimum over the second epoch.
328 |         """
329 | if self.passed < self.params['steps']:
330 | lr = (self.passed + 1.) / self.params['steps'] * learning_rate
331 | K.set_value(self.model.optimizer.lr, lr)
332 | self.passed += 1
333 | elif self.params['steps'] <= self.passed < self.params['steps'] * 2:
334 | lr = (2 - (self.passed + 1.) / self.params['steps']) * (learning_rate - min_learning_rate)
335 | lr += min_learning_rate
336 | K.set_value(self.model.optimizer.lr, lr)
337 | self.passed += 1
338 | def on_epoch_end(self, epoch, logs=None):
339 | f1, precision, recall = self.evaluate()
340 | self.F1.append(f1)
341 | if f1 > self.best:
342 | self.best = f1
343 | train_model.save_weights('best_model.weights')
344 | print 'f1: %.4f, precision: %.4f, recall: %.4f, best f1: %.4f\n' % (f1, precision, recall, self.best)
345 | def evaluate(self):
346 | orders = ['subject', 'predicate', 'object']
347 | A, B, C = 1e-10, 1e-10, 1e-10
348 | F = open('dev_pred.json', 'w')
349 | for d in tqdm(iter(dev_data)):
350 | R = set(extract_items(d['text']))
351 | T = set(d['spo_list'])
352 | A += len(R & T)
353 | B += len(R)
354 | C += len(T)
355 | s = json.dumps({
356 | 'text': d['text'],
357 | 'spo_list': [
358 | dict(zip(orders, spo)) for spo in T
359 | ],
360 | 'spo_list_pred': [
361 | dict(zip(orders, spo)) for spo in R
362 | ],
363 | 'new': [
364 | dict(zip(orders, spo)) for spo in R - T
365 | ],
366 | 'lack': [
367 | dict(zip(orders, spo)) for spo in T - R
368 | ]
369 | }, ensure_ascii=False, indent=4)
370 | F.write(s.encode('utf-8') + '\n')
371 | F.close()
372 | return 2 * A / (B + C), A / B, A / C
373 |
374 |
375 | def test(test_data):
376 |     """Write out the predictions on the test set.
377 |     """
378 | orders = ['subject', 'predicate', 'object', 'object_type', 'subject_type']
379 | F = open('test_pred.json', 'w')
380 | for d in tqdm(iter(test_data)):
381 | R = set(extract_items(d['text']))
382 | s = json.dumps({
383 | 'text': d['text'],
384 | 'spo_list': [
385 | dict(zip(orders, spo + ('', ''))) for spo in R
386 | ]
387 | }, ensure_ascii=False)
388 | F.write(s.encode('utf-8') + '\n')
389 | F.close()
390 |
391 |
392 | train_D = data_generator(train_data)
393 | evaluator = Evaluate()
394 |
395 |
396 | if __name__ == '__main__':
397 | train_model.fit_generator(train_D.__iter__(),
398 | steps_per_epoch=1000,
399 | epochs=30,
400 | callbacks=[evaluator]
401 | )
402 | else:
403 | train_model.load_weights('best_model.weights')
404 |
--------------------------------------------------------------------------------
/nl2sql_baseline.py:
--------------------------------------------------------------------------------
1 | #! -*- coding: utf-8 -*-
2 | # A baseline for the Zhuiyi Technology 2019 NL2SQL challenge (personal work, not an official release, based on BERT)
3 | # Competition page: https://tianchi.aliyun.com/competition/entrance/231716/introduction
4 | # Current exact-match accuracy is roughly 58%
5 |
6 | import json
7 | import uniout  # imported for its side effect: makes unicode repr() readable under Python 2
8 | from keras_bert import load_trained_model_from_checkpoint, Tokenizer
9 | import codecs
10 | from keras.layers import *
11 | from keras.models import Model
12 | import keras.backend as K
13 | from keras.optimizers import Adam
14 | from keras.callbacks import Callback
15 | from tqdm import tqdm
16 | import jieba
17 | import editdistance
18 | import re
19 | import numpy as np
20 |
21 | maxlen = 160
22 | num_agg = 7 # agg_sql_dict = {0:"", 1:"AVG", 2:"MAX", 3:"MIN", 4:"COUNT", 5:"SUM", 6:"not selected"}
23 | num_op = 5 # {0:">", 1:"<", 2:"==", 3:"!=", 4:"not selected"}
24 | num_cond_conn_op = 3 # conn_sql_dict = {0:"", 1:"and", 2:"or"}
25 | learning_rate = 5e-5
26 | min_learning_rate = 1e-5
27 |
28 |
29 | config_path = '../../kg/bert/chinese_wwm_L-12_H-768_A-12/bert_config.json'
30 | checkpoint_path = '../../kg/bert/chinese_wwm_L-12_H-768_A-12/bert_model.ckpt'
31 | dict_path = '../../kg/bert/chinese_wwm_L-12_H-768_A-12/vocab.txt'
32 |
33 |
34 | def read_data(data_file, table_file):
35 | data, tables = [], {}
36 | with open(data_file) as f:
37 | for l in f:
38 | data.append(json.loads(l))
39 | with open(table_file) as f:
40 | for l in f:
41 | l = json.loads(l)
42 | d = {}
43 | d['headers'] = l['header']
44 | d['header2id'] = {j: i for i, j in enumerate(d['headers'])}
45 | d['content'] = {}
46 | d['all_values'] = set()
47 | rows = np.array(l['rows'])
48 | for i, h in enumerate(d['headers']):
49 | d['content'][h] = set(rows[:, i])
50 | d['all_values'].update(d['content'][h])
51 | d['all_values'] = set([i for i in d['all_values'] if hasattr(i, '__len__')])
52 | tables[l['id']] = d
53 | return data, tables
54 |
55 |
56 | train_data, train_tables = read_data('../datasets/train.json', '../datasets/train.tables.json')
57 | valid_data, valid_tables = read_data('../datasets/val.json', '../datasets/val.tables.json')
58 | test_data, test_tables = read_data('../datasets/test.json', '../datasets/test.tables.json')
59 |
60 |
61 | token_dict = {}
62 |
63 | with codecs.open(dict_path, 'r', 'utf8') as reader:
64 | for line in reader:
65 | token = line.strip()
66 | token_dict[token] = len(token_dict)
67 |
68 |
69 | class OurTokenizer(Tokenizer):
70 | def _tokenize(self, text):
71 | R = []
72 | for c in text:
73 | if c in self._token_dict:
74 | R.append(c)
75 | elif self._is_space(c):
76 |             R.append('[unused1]') # spaces are mapped to the untrained [unused1] token
77 |         else:
78 |             R.append('[UNK]') # all remaining characters are mapped to [UNK]
79 | return R
80 |
81 | tokenizer = OurTokenizer(token_dict)
82 |
83 |
84 | def seq_padding(X, padding=0, maxlen=None):
85 | if maxlen is None:
86 | L = [len(x) for x in X]
87 | ML = max(L)
88 | else:
89 | ML = maxlen
90 | return np.array([
91 | np.concatenate([x[:ML], [padding] * (ML - len(x))]) if len(x[:ML]) < ML else x for x in X
92 | ])
93 |
94 |
95 | def most_similar(s, slist):
96 |     """Find the most similar word in the candidate list (used when no exact match exists).
97 |     """
98 | if len(slist) == 0:
99 | return s
100 | scores = [editdistance.eval(s, t) for t in slist]
101 | return slist[np.argmin(scores)]
102 |
103 |
104 | def most_similar_2(w, s):
105 |     """Find the fragment of sentence s that is most similar to w,
106 |     using word segmentation and n-grams to pin down the boundaries as precisely as possible.
107 |     """
108 | sw = jieba.lcut(s)
109 | sl = list(sw)
110 | sl.extend([''.join(i) for i in zip(sw, sw[1:])])
111 | sl.extend([''.join(i) for i in zip(sw, sw[1:], sw[2:])])
112 | return most_similar(w, sl)
113 |
114 |
115 | class data_generator:
116 | def __init__(self, data, tables, batch_size=32):
117 | self.data = data
118 | self.tables = tables
119 | self.batch_size = batch_size
120 | self.steps = len(self.data) // self.batch_size
121 | if len(self.data) % self.batch_size != 0:
122 | self.steps += 1
123 | def __len__(self):
124 | return self.steps
125 | def __iter__(self):
126 | while True:
127 | idxs = range(len(self.data))
128 | np.random.shuffle(idxs)
129 | X1, X2, XM, H, HM, SEL, CONN, CSEL, COP = [], [], [], [], [], [], [], [], []
130 | for i in idxs:
131 | d = self.data[i]
132 | t = self.tables[d['table_id']]['headers']
133 | x1, x2 = tokenizer.encode(d['question'])
134 | xm = [0] + [1] * len(d['question']) + [0]
135 | h = []
136 | for j in t:
137 | _x1, _x2 = tokenizer.encode(j)
138 | h.append(len(x1))
139 | x1.extend(_x1)
140 | x2.extend(_x2)
141 | hm = [1] * len(h)
142 | sel = []
143 | for j in range(len(h)):
144 | if j in d['sql']['sel']:
145 | j = d['sql']['sel'].index(j)
146 | sel.append(d['sql']['agg'][j])
147 | else:
148 |                         sel.append(num_agg - 1) # columns that are not selected are labelled num_agg-1
149 | conn = [d['sql']['cond_conn_op']]
150 |                 csel = np.zeros(len(d['question']) + 2, dtype='int32') # 0 doubles as padding and as the first column; the padding part is masked out during training
151 |                 cop = np.zeros(len(d['question']) + 2, dtype='int32') + num_op - 1 # tokens not belonging to any condition are labelled num_op-1
152 | for j in d['sql']['conds']:
153 | if j[2] not in d['question']:
154 | j[2] = most_similar_2(j[2], d['question'])
155 | if j[2] not in d['question']:
156 | continue
157 | k = d['question'].index(j[2])
158 | csel[k + 1: k + 1 + len(j[2])] = j[0]
159 | cop[k + 1: k + 1 + len(j[2])] = j[1]
160 | if len(x1) > maxlen:
161 | continue
162 |                 X1.append(x1) # BERT input: token ids
163 |                 X2.append(x2) # BERT input: segment ids
164 |                 XM.append(xm) # mask over the question tokens
165 |                 H.append(h) # positions of the column headers
166 |                 HM.append(hm) # header mask
167 |                 SEL.append(sel) # selected columns (aggregation label per column)
168 |                 CONN.append(conn) # condition connector type
169 |                 CSEL.append(csel) # column of each condition
170 |                 COP.append(cop) # operator of each condition (also marks the value span)
171 | if len(X1) == self.batch_size:
172 | X1 = seq_padding(X1)
173 | X2 = seq_padding(X2)
174 | XM = seq_padding(XM, maxlen=X1.shape[1])
175 | H = seq_padding(H)
176 | HM = seq_padding(HM)
177 | SEL = seq_padding(SEL)
178 | CONN = seq_padding(CONN)
179 | CSEL = seq_padding(CSEL, maxlen=X1.shape[1])
180 | COP = seq_padding(COP, maxlen=X1.shape[1])
181 | yield [X1, X2, XM, H, HM, SEL, CONN, CSEL, COP], None
182 | X1, X2, XM, H, HM, SEL, CONN, CSEL, COP = [], [], [], [], [], [], [], [], []
183 |
184 |
185 | def seq_gather(x):
186 |     """seq has shape [None, seq_len, s_size] and idxs has shape [None, n];
187 |     from the i-th sequence in seq, pick out the idxs[i]-th vectors,
188 |     giving an output of shape [None, n, s_size].
189 |     """
190 | seq, idxs = x
191 | idxs = K.cast(idxs, 'int32')
192 | return K.tf.batch_gather(seq, idxs)
193 |
194 |
195 | bert_model = load_trained_model_from_checkpoint(config_path, checkpoint_path, seq_len=None)
196 |
197 | for l in bert_model.layers:
198 | l.trainable = True
199 |
200 |
201 | x1_in = Input(shape=(None,), dtype='int32')
202 | x2_in = Input(shape=(None,))
203 | xm_in = Input(shape=(None,))
204 | h_in = Input(shape=(None,), dtype='int32')
205 | hm_in = Input(shape=(None,))
206 | sel_in = Input(shape=(None,), dtype='int32')
207 | conn_in = Input(shape=(1,), dtype='int32')
208 | csel_in = Input(shape=(None,), dtype='int32')
209 | cop_in = Input(shape=(None,), dtype='int32')
210 |
211 | x1, x2, xm, h, hm, sel, conn, csel, cop = (
212 | x1_in, x2_in, xm_in, h_in, hm_in, sel_in, conn_in, csel_in, cop_in
213 | )
214 |
215 | hm = Lambda(lambda x: K.expand_dims(x, 1))(hm) # header mask, shape=(None, 1, h_len)
216 |
217 | x = bert_model([x1_in, x2_in])
218 | x4conn = Lambda(lambda x: x[:, 0])(x)
219 | pconn = Dense(num_cond_conn_op, activation='softmax')(x4conn)
220 |
221 | x4h = Lambda(seq_gather)([x, h])
222 | psel = Dense(num_agg, activation='softmax')(x4h)
223 |
224 | pcop = Dense(num_op, activation='softmax')(x)
225 |
226 | x = Lambda(lambda x: K.expand_dims(x, 2))(x)
227 | x4h = Lambda(lambda x: K.expand_dims(x, 1))(x4h)
228 | pcsel_1 = Dense(256)(x)
229 | pcsel_2 = Dense(256)(x4h)
230 | pcsel = Lambda(lambda x: x[0] + x[1])([pcsel_1, pcsel_2])
231 | pcsel = Activation('tanh')(pcsel)
232 | pcsel = Dense(1)(pcsel)
233 | pcsel = Lambda(lambda x: x[0][..., 0] - (1 - x[1]) * 1e10)([pcsel, hm])
234 | pcsel = Activation('softmax')(pcsel)
235 |
236 | model = Model(
237 | [x1_in, x2_in, h_in, hm_in],
238 | [psel, pconn, pcop, pcsel]
239 | )
240 |
241 | train_model = Model(
242 | [x1_in, x2_in, xm_in, h_in, hm_in, sel_in, conn_in, csel_in, cop_in],
243 | [psel, pconn, pcop, pcsel]
244 | )
245 |
246 | xm = xm # question mask, shape=(None, x_len)
247 | hm = hm[:, 0] # header mask, shape=(None, h_len)
248 | cm = K.cast(K.not_equal(cop, num_op - 1), 'float32') # condition mask, shape=(None, x_len)
249 |
250 | psel_loss = K.sparse_categorical_crossentropy(sel_in, psel)
251 | psel_loss = K.sum(psel_loss * hm) / K.sum(hm)
252 | pconn_loss = K.sparse_categorical_crossentropy(conn_in, pconn)
253 | pconn_loss = K.mean(pconn_loss)
254 | pcop_loss = K.sparse_categorical_crossentropy(cop_in, pcop)
255 | pcop_loss = K.sum(pcop_loss * xm) / K.sum(xm)
256 | pcsel_loss = K.sparse_categorical_crossentropy(csel_in, pcsel)
257 | pcsel_loss = K.sum(pcsel_loss * xm * cm) / K.sum(xm * cm)
258 | loss = psel_loss + pconn_loss + pcop_loss + pcsel_loss
259 |
260 | train_model.add_loss(loss)
261 | train_model.compile(optimizer=Adam(learning_rate))
262 | train_model.summary()
263 |
264 |
265 | def nl2sql(question, table):
266 |     """Convert a question plus the table headers into SQL.
267 |     """
268 | x1, x2 = tokenizer.encode(question)
269 | h = []
270 | for i in table['headers']:
271 | _x1, _x2 = tokenizer.encode(i)
272 | h.append(len(x1))
273 | x1.extend(_x1)
274 | x2.extend(_x2)
275 | hm = [1] * len(h)
276 | psel, pconn, pcop, pcsel = model.predict([
277 | np.array([x1]),
278 | np.array([x2]),
279 | np.array([h]),
280 | np.array([hm])
281 | ])
282 | R = {'agg': [], 'sel': []}
283 | for i, j in enumerate(psel[0].argmax(1)):
284 |         if j != num_agg - 1: # class num_agg-1 means the column is not selected
285 | R['sel'].append(i)
286 | R['agg'].append(j)
287 | conds = []
288 | v_op = -1
289 | for i, j in enumerate(pcop[0, :len(question)+1].argmax(1)):
290 |         # combine the tagging output (pcop) with the classification outputs to predict the conditions
291 | if j != num_op - 1:
292 | if v_op != j:
293 | if v_op != -1:
294 | v_end = v_start + len(v_str)
295 | csel = pcsel[0][v_start: v_end].mean(0).argmax()
296 | conds.append((csel, v_op, v_str))
297 | v_start = i
298 | v_op = j
299 | v_str = question[i - 1]
300 | else:
301 | v_str += question[i - 1]
302 | elif v_op != -1:
303 | v_end = v_start + len(v_str)
304 | csel = pcsel[0][v_start: v_end].mean(0).argmax()
305 | conds.append((csel, v_op, v_str))
306 | v_op = -1
307 | R['conds'] = set()
308 | for i, j, k in conds:
309 | if re.findall('[^\d\.]', k):
310 |             j = 2 # non-numeric values can only use the equality operator
311 | if j == 2:
312 | if k not in table['all_values']:
313 |                 # a value used with '==' must appear in the table; otherwise fall back to the closest match
314 | k = most_similar(k, list(table['all_values']))
315 | h = table['headers'][i]
316 |             # then check that the predicted column actually contains the value; if not, correct the column
317 | if k not in table['content'][h]:
318 | for r, v in table['content'].items():
319 | if k in v:
320 | i = table['header2id'][r]
321 | break
322 | R['conds'].add((i, j, k))
323 | R['conds'] = list(R['conds'])
324 |     if len(R['conds']) <= 1: # with at most one condition, the connector is simply 0
325 | R['cond_conn_op'] = 0
326 | else:
327 |         R['cond_conn_op'] = 1 + pconn[0, 1:].argmax() # cannot be 0 here
328 | return R
329 |
330 |
331 | def is_equal(R1, R2):
332 |     """Check whether two SQL dicts are an exact match.
333 |     """
334 | return (R1['cond_conn_op'] == R2['cond_conn_op']) &\
335 | (set(zip(R1['sel'], R1['agg'])) == set(zip(R2['sel'], R2['agg']))) &\
336 | (set([tuple(i) for i in R1['conds']]) == set([tuple(i) for i in R2['conds']]))
337 |
338 |
339 | def evaluate(data, tables):
340 | right = 0.
341 | pbar = tqdm()
342 | F = open('evaluate_pred.json', 'w')
343 | for i, d in enumerate(data):
344 | question = d['question']
345 | table = tables[d['table_id']]
346 | R = nl2sql(question, table)
347 | right += float(is_equal(R, d['sql']))
348 | pbar.update(1)
349 | pbar.set_description('< acc: %.5f >' % (right / (i + 1)))
350 | d['sql_pred'] = R
351 | s = json.dumps(d, ensure_ascii=False, indent=4)
352 | F.write(s.encode('utf-8') + '\n')
353 | F.close()
354 | pbar.close()
355 | return right / len(data)
356 |
357 |
358 | def test(data, tables, outfile='result.json'):
359 | pbar = tqdm()
360 | F = open(outfile, 'w')
361 | for i, d in enumerate(data):
362 | question = d['question']
363 | table = tables[d['table_id']]
364 | R = nl2sql(question, table)
365 | pbar.update(1)
366 | s = json.dumps(R, ensure_ascii=False)
367 | F.write(s.encode('utf-8') + '\n')
368 | F.close()
369 | pbar.close()
370 |
371 | # test(test_data, test_tables)
372 |
373 |
374 | class Evaluate(Callback):
375 | def __init__(self):
376 | self.accs = []
377 | self.best = 0.
378 | self.passed = 0
379 | self.stage = 0
380 | def on_batch_begin(self, batch, logs=None):
381 |         """Use the first epoch for warmup and anneal the learning rate to its minimum over the second epoch.
382 |         """
383 | if self.passed < self.params['steps']:
384 | lr = (self.passed + 1.) / self.params['steps'] * learning_rate
385 | K.set_value(self.model.optimizer.lr, lr)
386 | self.passed += 1
387 | elif self.params['steps'] <= self.passed < self.params['steps'] * 2:
388 | lr = (2 - (self.passed + 1.) / self.params['steps']) * (learning_rate - min_learning_rate)
389 | lr += min_learning_rate
390 | K.set_value(self.model.optimizer.lr, lr)
391 | self.passed += 1
392 | def on_epoch_end(self, epoch, logs=None):
393 | acc = self.evaluate()
394 | self.accs.append(acc)
395 | if acc > self.best:
396 | self.best = acc
397 | train_model.save_weights('best_model.weights')
398 | print 'acc: %.5f, best acc: %.5f\n' % (acc, self.best)
399 | def evaluate(self):
400 | return evaluate(valid_data, valid_tables)
401 |
402 |
403 | train_D = data_generator(train_data, train_tables)
404 | evaluator = Evaluate()
405 |
406 | if __name__ == '__main__':
407 | train_model.fit_generator(
408 | train_D.__iter__(),
409 | steps_per_epoch=len(train_D),
410 | epochs=15,
411 | callbacks=[evaluator]
412 | )
413 | else:
414 | train_model.load_weights('best_model.weights')
415 |
--------------------------------------------------------------------------------