├── README.md ├── sentiment.py ├── subject_extract.py ├── relation_extract.py └── nl2sql_baseline.py /README.md: -------------------------------------------------------------------------------- 1 | # bert_in_keras 2 | 用Keras来调用BERT,这可能是最简单的BERT打开姿势。 3 | 4 | ## 文件说明 5 | - sentiment.py:情感分析例子,详细介绍见下方“详细介绍”中的链接。 6 | - relation_extract.py:关系抽取例子,详细介绍见下方“详细介绍”中的链接。 7 | - subject_extract.py:主体抽取例子,详细介绍见下方“详细介绍”中的链接。 8 | - nl2sql_baseline.py:NL2SQL例子,详细介绍见下方“详细介绍”中的链接。 9 | 10 | ## 详细介绍 11 | - https://kexue.fm/archives/6736 12 | - https://kexue.fm/archives/6771 13 | 14 | ## 测试环境 15 | Python 2.7 + TensorFlow 1.13 + Keras 2.2.4 16 | 17 | ## keras_bert 18 | - https://github.com/CyberZHG/keras-bert 19 | 20 | ## 中文版权重 21 | - 官方版: https://github.com/google-research/bert 22 | - 哈工大版: https://github.com/ymcui/Chinese-BERT-wwm 23 | 24 | ## 严正声明 25 | - 不欢迎任何NLP和Keras文盲来跑此代码!你都要玩BERT了,我认为你学习NLP的时间好歹要在半年以上,你学习Keras的时间好歹要一周以上。别想着一蹴而就,不欢迎只想调包跑通的人,不要用任何“我时间紧”的借口。 26 | - Keras是简单,但不代表不需要学NLP,不代表不需要学Keras,更不代表可以不经大脑。一句话,请尊重你自己的智商。 27 | 28 | ## 在线交流 29 | QQ交流群:67729435,微信群请加机器人微信号spaces_ac_cn 30 | -------------------------------------------------------------------------------- /sentiment.py: -------------------------------------------------------------------------------- 1 | #!
-*- coding:utf-8 -*- 2 | 3 | import json 4 | import numpy as np 5 | import pandas as pd 6 | from random import choice 7 | from keras_bert import load_trained_model_from_checkpoint, Tokenizer 8 | import re, os 9 | import codecs 10 | 11 | 12 | maxlen = 100 13 | config_path = '../bert/chinese_L-12_H-768_A-12/bert_config.json' 14 | checkpoint_path = '../bert/chinese_L-12_H-768_A-12/bert_model.ckpt' 15 | dict_path = '../bert/chinese_L-12_H-768_A-12/vocab.txt' 16 | 17 | 18 | token_dict = {} 19 | 20 | with codecs.open(dict_path, 'r', 'utf8') as reader: 21 | for line in reader: 22 | token = line.strip() 23 | token_dict[token] = len(token_dict) 24 | 25 | 26 | class OurTokenizer(Tokenizer): 27 | def _tokenize(self, text): 28 | R = [] 29 | for c in text: 30 | if c in self._token_dict: 31 | R.append(c) 32 | elif self._is_space(c): 33 | R.append('[unused1]') # space类用未经训练的[unused1]表示 34 | else: 35 | R.append('[UNK]') # 剩余的字符是[UNK] 36 | return R 37 | 38 | tokenizer = OurTokenizer(token_dict) 39 | 40 | 41 | neg = pd.read_excel('neg.xls', header=None) 42 | pos = pd.read_excel('pos.xls', header=None) 43 | 44 | data = [] 45 | 46 | for d in neg[0]: 47 | data.append((d, 0)) 48 | 49 | for d in pos[0]: 50 | data.append((d, 1)) 51 | 52 | 53 | # 按照9:1的比例划分训练集和验证集 54 | random_order = range(len(data)) 55 | np.random.shuffle(random_order) 56 | train_data = [data[j] for i, j in enumerate(random_order) if i % 10 != 0] 57 | valid_data = [data[j] for i, j in enumerate(random_order) if i % 10 == 0] 58 | 59 | 60 | def seq_padding(X, padding=0): 61 | L = [len(x) for x in X] 62 | ML = max(L) 63 | return np.array([ 64 | np.concatenate([x, [padding] * (ML - len(x))]) if len(x) < ML else x for x in X 65 | ]) 66 | 67 | 68 | class data_generator: 69 | def __init__(self, data, batch_size=32): 70 | self.data = data 71 | self.batch_size = batch_size 72 | self.steps = len(self.data) // self.batch_size 73 | if len(self.data) % self.batch_size != 0: 74 | self.steps += 1 75 | def __len__(self): 76 | return 
self.steps 77 | def __iter__(self): 78 | while True: 79 | idxs = range(len(self.data)) 80 | np.random.shuffle(idxs) 81 | X1, X2, Y = [], [], [] 82 | for i in idxs: 83 | d = self.data[i] 84 | text = d[0][:maxlen] 85 | x1, x2 = tokenizer.encode(first=text) 86 | y = d[1] 87 | X1.append(x1) 88 | X2.append(x2) 89 | Y.append([y]) 90 | if len(X1) == self.batch_size or i == idxs[-1]: 91 | X1 = seq_padding(X1) 92 | X2 = seq_padding(X2) 93 | Y = seq_padding(Y) 94 | yield [X1, X2], Y 95 | [X1, X2, Y] = [], [], [] 96 | 97 | 98 | from keras.layers import * 99 | from keras.models import Model 100 | import keras.backend as K 101 | from keras.optimizers import Adam 102 | 103 | 104 | bert_model = load_trained_model_from_checkpoint(config_path, checkpoint_path, seq_len=None) 105 | 106 | for l in bert_model.layers: 107 | l.trainable = True 108 | 109 | x1_in = Input(shape=(None,)) 110 | x2_in = Input(shape=(None,)) 111 | 112 | x = bert_model([x1_in, x2_in]) 113 | x = Lambda(lambda x: x[:, 0])(x) 114 | p = Dense(1, activation='sigmoid')(x) 115 | 116 | model = Model([x1_in, x2_in], p) 117 | model.compile( 118 | loss='binary_crossentropy', 119 | optimizer=Adam(1e-5), # 用足够小的学习率 120 | metrics=['accuracy'] 121 | ) 122 | model.summary() 123 | 124 | 125 | train_D = data_generator(train_data) 126 | valid_D = data_generator(valid_data) 127 | 128 | model.fit_generator( 129 | train_D.__iter__(), 130 | steps_per_epoch=len(train_D), 131 | epochs=5, 132 | validation_data=valid_D.__iter__(), 133 | validation_steps=len(valid_D) 134 | ) 135 | -------------------------------------------------------------------------------- /subject_extract.py: -------------------------------------------------------------------------------- 1 | #! 
-*- coding: utf-8 -*- 2 | 3 | import json 4 | from tqdm import tqdm 5 | import os, re 6 | import numpy as np 7 | import pandas as pd 8 | from keras_bert import load_trained_model_from_checkpoint, Tokenizer 9 | import codecs 10 | 11 | 12 | mode = 0 13 | maxlen = 128 14 | learning_rate = 5e-5 15 | min_learning_rate = 1e-5 16 | 17 | 18 | config_path = '../../kg/bert/chinese_L-12_H-768_A-12/bert_config.json' 19 | checkpoint_path = '../../kg/bert/chinese_L-12_H-768_A-12/bert_model.ckpt' 20 | dict_path = '../../kg/bert/chinese_L-12_H-768_A-12/vocab.txt' 21 | 22 | 23 | token_dict = {} 24 | 25 | with codecs.open(dict_path, 'r', 'utf8') as reader: 26 | for line in reader: 27 | token = line.strip() 28 | token_dict[token] = len(token_dict) 29 | 30 | 31 | class OurTokenizer(Tokenizer): 32 | def _tokenize(self, text): 33 | R = [] 34 | for c in text: 35 | if c in self._token_dict: 36 | R.append(c) 37 | elif self._is_space(c): 38 | R.append('[unused1]') # space类用未经训练的[unused1]表示 39 | else: 40 | R.append('[UNK]') # 剩余的字符是[UNK] 41 | return R 42 | 43 | tokenizer = OurTokenizer(token_dict) 44 | 45 | 46 | D = pd.read_csv('../ccks2019_event_entity_extract/event_type_entity_extract_train.csv', encoding='utf-8', header=None) 47 | D = D[D[2] != u'其他'] 48 | classes = set(D[2].unique()) 49 | 50 | 51 | train_data = [] 52 | for t,c,n in zip(D[1], D[2], D[3]): 53 | train_data.append((t, c, n)) 54 | 55 | 56 | if not os.path.exists('../random_order_train.json'): 57 | random_order = range(len(train_data)) 58 | np.random.shuffle(random_order) 59 | json.dump( 60 | random_order, 61 | open('../random_order_train.json', 'w'), 62 | indent=4 63 | ) 64 | else: 65 | random_order = json.load(open('../random_order_train.json')) 66 | 67 | 68 | dev_data = [train_data[j] for i, j in enumerate(random_order) if i % 9 == mode] 69 | train_data = [train_data[j] for i, j in enumerate(random_order) if i % 9 != mode] 70 | additional_chars = set() 71 | for d in train_data + dev_data: 72 | 
additional_chars.update(re.findall(u'[^\u4e00-\u9fa5a-zA-Z0-9\*]', d[2])) 73 | 74 | additional_chars.remove(u',') 75 | 76 | 77 | D = pd.read_csv('../ccks2019_event_entity_extract/event_type_entity_extract_eval.csv', encoding='utf-8', header=None) 78 | test_data = [] 79 | for id,t,c in zip(D[0], D[1], D[2]): 80 | test_data.append((id, t, c)) 81 | 82 | 83 | def seq_padding(X, padding=0): 84 | L = [len(x) for x in X] 85 | ML = max(L) 86 | return np.array([ 87 | np.concatenate([x, [padding] * (ML - len(x))]) if len(x) < ML else x for x in X 88 | ]) 89 | 90 | 91 | def list_find(list1, list2): 92 | """在list1中寻找子串list2,如果找到,返回第一个下标; 93 | 如果找不到,返回-1。 94 | """ 95 | n_list2 = len(list2) 96 | for i in range(len(list1)): 97 | if list1[i: i+n_list2] == list2: 98 | return i 99 | return -1 100 | 101 | 102 | class data_generator: 103 | def __init__(self, data, batch_size=32): 104 | self.data = data 105 | self.batch_size = batch_size 106 | self.steps = len(self.data) // self.batch_size 107 | if len(self.data) % self.batch_size != 0: 108 | self.steps += 1 109 | def __len__(self): 110 | return self.steps 111 | def __iter__(self): 112 | while True: 113 | idxs = range(len(self.data)) 114 | np.random.shuffle(idxs) 115 | X1, X2, S1, S2 = [], [], [], [] 116 | for i in idxs: 117 | d = self.data[i] 118 | text, c = d[0][:maxlen], d[1] 119 | text = u'___%s___%s' % (c, text) 120 | tokens = tokenizer.tokenize(text) 121 | e = d[2] 122 | e_tokens = tokenizer.tokenize(e)[1:-1] 123 | s1, s2 = np.zeros(len(tokens)), np.zeros(len(tokens)) 124 | start = list_find(tokens, e_tokens) 125 | if start != -1: 126 | end = start + len(e_tokens) - 1 127 | s1[start] = 1 128 | s2[end] = 1 129 | x1, x2 = tokenizer.encode(first=text) 130 | X1.append(x1) 131 | X2.append(x2) 132 | S1.append(s1) 133 | S2.append(s2) 134 | if len(X1) == self.batch_size or i == idxs[-1]: 135 | X1 = seq_padding(X1) 136 | X2 = seq_padding(X2) 137 | S1 = seq_padding(S1) 138 | S2 = seq_padding(S2) 139 | yield [X1, X2, S1, S2], None 140 | X1, 
X2, S1, S2 = [], [], [], [] 141 | 142 | 143 | from keras.layers import * 144 | from keras.models import Model 145 | import keras.backend as K 146 | from keras.callbacks import Callback 147 | from keras.optimizers import Adam 148 | 149 | 150 | bert_model = load_trained_model_from_checkpoint(config_path, checkpoint_path, seq_len=None) 151 | 152 | for l in bert_model.layers: 153 | l.trainable = True 154 | 155 | 156 | x1_in = Input(shape=(None,)) # 待识别句子输入 157 | x2_in = Input(shape=(None,)) # 待识别句子输入 158 | s1_in = Input(shape=(None,)) # 实体左边界(标签) 159 | s2_in = Input(shape=(None,)) # 实体右边界(标签) 160 | 161 | x1, x2, s1, s2 = x1_in, x2_in, s1_in, s2_in 162 | x_mask = Lambda(lambda x: K.cast(K.greater(K.expand_dims(x, 2), 0), 'float32'))(x1) 163 | 164 | x = bert_model([x1, x2]) 165 | ps1 = Dense(1, use_bias=False)(x) 166 | ps1 = Lambda(lambda x: x[0][..., 0] - (1 - x[1][..., 0]) * 1e10)([ps1, x_mask]) 167 | ps2 = Dense(1, use_bias=False)(x) 168 | ps2 = Lambda(lambda x: x[0][..., 0] - (1 - x[1][..., 0]) * 1e10)([ps2, x_mask]) 169 | 170 | model = Model([x1_in, x2_in], [ps1, ps2]) 171 | 172 | 173 | train_model = Model([x1_in, x2_in, s1_in, s2_in], [ps1, ps2]) 174 | 175 | loss1 = K.mean(K.categorical_crossentropy(s1_in, ps1, from_logits=True)) 176 | ps2 -= (1 - K.cumsum(s1, 1)) * 1e10 177 | loss2 = K.mean(K.categorical_crossentropy(s2_in, ps2, from_logits=True)) 178 | loss = loss1 + loss2 179 | 180 | train_model.add_loss(loss) 181 | train_model.compile(optimizer=Adam(learning_rate)) 182 | train_model.summary() 183 | 184 | 185 | def softmax(x): 186 | x = x - np.max(x) 187 | x = np.exp(x) 188 | return x / np.sum(x) 189 | 190 | 191 | def extract_entity(text_in, c_in): 192 | if c_in not in classes: 193 | return 'NaN' 194 | text_in = u'___%s___%s' % (c_in, text_in) 195 | text_in = text_in[:510] 196 | _tokens = tokenizer.tokenize(text_in) 197 | _x1, _x2 = tokenizer.encode(first=text_in) 198 | _x1, _x2 = np.array([_x1]), np.array([_x2]) 199 | _ps1, _ps2 = model.predict([_x1, _x2]) 200 
| _ps1, _ps2 = softmax(_ps1[0]), softmax(_ps2[0]) 201 | for i, _t in enumerate(_tokens): 202 | if len(_t) == 1 and re.findall(u'[^\u4e00-\u9fa5a-zA-Z0-9\*]', _t) and _t not in additional_chars: 203 | _ps1[i] -= 10 204 | start = _ps1.argmax() 205 | for end in range(start, len(_tokens)): 206 | _t = _tokens[end] 207 | if len(_t) == 1 and re.findall(u'[^\u4e00-\u9fa5a-zA-Z0-9\*]', _t) and _t not in additional_chars: 208 | break 209 | end = _ps2[start:end+1].argmax() + start 210 | a = text_in[start-1: end] 211 | return a 212 | 213 | 214 | class Evaluate(Callback): 215 | def __init__(self): 216 | self.ACC = [] 217 | self.best = 0. 218 | self.passed = 0 219 | def on_batch_begin(self, batch, logs=None): 220 | """第一个epoch用来warmup,第二个epoch把学习率降到最低 221 | """ 222 | if self.passed < self.params['steps']: 223 | lr = (self.passed + 1.) / self.params['steps'] * learning_rate 224 | K.set_value(self.model.optimizer.lr, lr) 225 | self.passed += 1 226 | elif self.params['steps'] <= self.passed < self.params['steps'] * 2: 227 | lr = (2 - (self.passed + 1.) 
/ self.params['steps']) * (learning_rate - min_learning_rate) 228 | lr += min_learning_rate 229 | K.set_value(self.model.optimizer.lr, lr) 230 | self.passed += 1 231 | def on_epoch_end(self, epoch, logs=None): 232 | acc = self.evaluate() 233 | self.ACC.append(acc) 234 | if acc > self.best: 235 | self.best = acc 236 | train_model.save_weights('best_model.weights') 237 | print 'acc: %.4f, best acc: %.4f\n' % (acc, self.best) 238 | def evaluate(self): 239 | A = 1e-10 240 | F = open('dev_pred.json', 'w') 241 | for d in tqdm(iter(dev_data)): 242 | R = extract_entity(d[0], d[1]) 243 | if R == d[2]: 244 | A += 1 245 | s = ', '.join(d + (R,)) 246 | F.write(s.encode('utf-8') + '\n') 247 | F.close() 248 | return A / len(dev_data) 249 | 250 | 251 | def test(test_data): 252 | F = open('result.txt', 'w') 253 | for d in tqdm(iter(test_data)): 254 | s = u'"%s","%s"\n' % (d[0], extract_entity(d[1], d[2])) 255 | s = s.encode('utf-8') 256 | F.write(s) 257 | F.close() 258 | 259 | 260 | evaluator = Evaluate() 261 | train_D = data_generator(train_data) 262 | 263 | 264 | if __name__ == '__main__': 265 | train_model.fit_generator(train_D.__iter__(), 266 | steps_per_epoch=len(train_D), 267 | epochs=10, 268 | callbacks=[evaluator] 269 | ) 270 | else: 271 | train_model.load_weights('best_model.weights') 272 | -------------------------------------------------------------------------------- /relation_extract.py: -------------------------------------------------------------------------------- 1 | #! 
-*- coding:utf-8 -*- 2 | 3 | import json 4 | import numpy as np 5 | from random import choice 6 | from tqdm import tqdm 7 | from keras_bert import load_trained_model_from_checkpoint, Tokenizer 8 | import re, os 9 | import codecs 10 | 11 | 12 | mode = 0 13 | maxlen = 160 14 | learning_rate = 5e-5 15 | min_learning_rate = 1e-5 16 | 17 | config_path = '../bert/chinese_L-12_H-768_A-12/bert_config.json' 18 | checkpoint_path = '../bert/chinese_L-12_H-768_A-12/bert_model.ckpt' 19 | dict_path = '../bert/chinese_L-12_H-768_A-12/vocab.txt' 20 | 21 | 22 | token_dict = {} 23 | 24 | with codecs.open(dict_path, 'r', 'utf8') as reader: 25 | for line in reader: 26 | token = line.strip() 27 | token_dict[token] = len(token_dict) 28 | 29 | 30 | class OurTokenizer(Tokenizer): 31 | def _tokenize(self, text): 32 | R = [] 33 | for c in text: 34 | if c in self._token_dict: 35 | R.append(c) 36 | elif self._is_space(c): 37 | R.append('[unused1]') # space类用未经训练的[unused1]表示 38 | else: 39 | R.append('[UNK]') # 剩余的字符是[UNK] 40 | return R 41 | 42 | tokenizer = OurTokenizer(token_dict) 43 | 44 | 45 | train_data = json.load(open('../datasets/train_data_me.json')) 46 | dev_data = json.load(open('../datasets/dev_data_me.json')) 47 | id2predicate, predicate2id = json.load(open('../datasets/all_50_schemas_me.json')) 48 | id2predicate = {int(i):j for i,j in id2predicate.items()} 49 | num_classes = len(id2predicate) 50 | 51 | 52 | total_data = [] 53 | total_data.extend(train_data) 54 | total_data.extend(dev_data) 55 | 56 | 57 | if not os.path.exists('../random_order_train_dev.json'): 58 | random_order = range(len(total_data)) 59 | np.random.shuffle(random_order) 60 | json.dump( 61 | random_order, 62 | open('../random_order_train_dev.json', 'w'), 63 | indent=4 64 | ) 65 | else: 66 | random_order = json.load(open('../random_order_train_dev.json')) 67 | 68 | 69 | train_data = [total_data[j] for i, j in enumerate(random_order) if i % 8 != mode] 70 | dev_data = [total_data[j] for i, j in 
enumerate(random_order) if i % 8 == mode] 71 | 72 | 73 | predicates = {} # 格式:{predicate: [(subject, predicate, object)]} 74 | 75 | 76 | def repair(d): 77 | d['text'] = d['text'].lower() 78 | something = re.findall(u'《([^《》]*?)》', d['text']) 79 | something = [s.strip() for s in something] 80 | zhuanji = [] 81 | gequ = [] 82 | for sp in d['spo_list']: 83 | sp[0] = sp[0].strip(u'《》').strip().lower() 84 | sp[2] = sp[2].strip(u'《》').strip().lower() 85 | for some in something: 86 | if sp[0] in some and d['text'].count(sp[0]) == 1: 87 | sp[0] = some 88 | if sp[1] == u'所属专辑': 89 | zhuanji.append(sp[2]) 90 | gequ.append(sp[0]) 91 | spo_list = [] 92 | for sp in d['spo_list']: 93 | if sp[1] in [u'歌手', u'作词', u'作曲']: 94 | if sp[0] in zhuanji and sp[0] not in gequ: 95 | continue 96 | spo_list.append(tuple(sp)) 97 | d['spo_list'] = spo_list 98 | 99 | 100 | for d in train_data: 101 | repair(d) 102 | for sp in d['spo_list']: 103 | if sp[1] not in predicates: 104 | predicates[sp[1]] = [] 105 | predicates[sp[1]].append(sp) 106 | 107 | 108 | for d in dev_data: 109 | repair(d) 110 | 111 | 112 | def seq_padding(X, padding=0): 113 | L = [len(x) for x in X] 114 | ML = max(L) 115 | return np.array([ 116 | np.concatenate([x, [padding] * (ML - len(x))]) if len(x) < ML else x for x in X 117 | ]) 118 | 119 | 120 | def list_find(list1, list2): 121 | """在list1中寻找子串list2,如果找到,返回第一个下标; 122 | 如果找不到,返回-1。 123 | """ 124 | n_list2 = len(list2) 125 | for i in range(len(list1)): 126 | if list1[i: i+n_list2] == list2: 127 | return i 128 | return -1 129 | 130 | 131 | class data_generator: 132 | def __init__(self, data, batch_size=32): 133 | self.data = data 134 | self.batch_size = batch_size 135 | self.steps = len(self.data) // self.batch_size 136 | if len(self.data) % self.batch_size != 0: 137 | self.steps += 1 138 | def __len__(self): 139 | return self.steps 140 | def __iter__(self): 141 | while True: 142 | idxs = range(len(self.data)) 143 | np.random.shuffle(idxs) 144 | T1, T2, S1, S2, K1, K2, O1, O2 
= [], [], [], [], [], [], [], [] 145 | for i in idxs: 146 | d = self.data[i] 147 | text = d['text'][:maxlen] 148 | tokens = tokenizer.tokenize(text) 149 | items = {} 150 | for sp in d['spo_list']: 151 | sp = (tokenizer.tokenize(sp[0])[1:-1], sp[1], tokenizer.tokenize(sp[2])[1:-1]) 152 | subjectid = list_find(tokens, sp[0]) 153 | objectid = list_find(tokens, sp[2]) 154 | if subjectid != -1 and objectid != -1: 155 | key = (subjectid, subjectid+len(sp[0])) 156 | if key not in items: 157 | items[key] = [] 158 | items[key].append((objectid, 159 | objectid+len(sp[2]), 160 | predicate2id[sp[1]])) 161 | if items: 162 | t1, t2 = tokenizer.encode(first=text) 163 | T1.append(t1) 164 | T2.append(t2) 165 | s1, s2 = np.zeros(len(tokens)), np.zeros(len(tokens)) 166 | for j in items: 167 | s1[j[0]] = 1 168 | s2[j[1]-1] = 1 169 | k1, k2 = np.array(items.keys()).T 170 | k1 = choice(k1) 171 | k2 = choice(k2[k2 >= k1]) 172 | o1, o2 = np.zeros((len(tokens), num_classes)), np.zeros((len(tokens), num_classes)) 173 | for j in items.get((k1, k2), []): 174 | o1[j[0]][j[2]] = 1 175 | o2[j[1]-1][j[2]] = 1 176 | S1.append(s1) 177 | S2.append(s2) 178 | K1.append([k1]) 179 | K2.append([k2-1]) 180 | O1.append(o1) 181 | O2.append(o2) 182 | if len(T1) == self.batch_size or i == idxs[-1]: 183 | T1 = seq_padding(T1) 184 | T2 = seq_padding(T2) 185 | S1 = seq_padding(S1) 186 | S2 = seq_padding(S2) 187 | O1 = seq_padding(O1, np.zeros(num_classes)) 188 | O2 = seq_padding(O2, np.zeros(num_classes)) 189 | K1, K2 = np.array(K1), np.array(K2) 190 | yield [T1, T2, S1, S2, K1, K2, O1, O2], None 191 | T1, T2, S1, S2, K1, K2, O1, O2, = [], [], [], [], [], [], [], [] 192 | 193 | 194 | from keras.layers import * 195 | from keras.models import Model 196 | import keras.backend as K 197 | from keras.callbacks import Callback 198 | from keras.optimizers import Adam 199 | 200 | 201 | def seq_gather(x): 202 | """seq是[None, seq_len, s_size]的格式, 203 | idxs是[None, 1]的格式,在seq的第i个序列中选出第idxs[i]个向量, 204 | 最终输出[None, 
s_size]的向量。 205 | """ 206 | seq, idxs = x 207 | idxs = K.cast(idxs, 'int32') 208 | batch_idxs = K.arange(0, K.shape(seq)[0]) 209 | batch_idxs = K.expand_dims(batch_idxs, 1) 210 | idxs = K.concatenate([batch_idxs, idxs], 1) 211 | return K.tf.gather_nd(seq, idxs) 212 | 213 | 214 | bert_model = load_trained_model_from_checkpoint(config_path, checkpoint_path, seq_len=None) 215 | 216 | for l in bert_model.layers: 217 | l.trainable = True 218 | 219 | 220 | t1_in = Input(shape=(None,)) 221 | t2_in = Input(shape=(None,)) 222 | s1_in = Input(shape=(None,)) 223 | s2_in = Input(shape=(None,)) 224 | k1_in = Input(shape=(1,)) 225 | k2_in = Input(shape=(1,)) 226 | o1_in = Input(shape=(None, num_classes)) 227 | o2_in = Input(shape=(None, num_classes)) 228 | 229 | t1, t2, s1, s2, k1, k2, o1, o2 = t1_in, t2_in, s1_in, s2_in, k1_in, k2_in, o1_in, o2_in 230 | mask = Lambda(lambda x: K.cast(K.greater(K.expand_dims(x, 2), 0), 'float32'))(t1) 231 | 232 | t = bert_model([t1, t2]) 233 | ps1 = Dense(1, activation='sigmoid')(t) 234 | ps2 = Dense(1, activation='sigmoid')(t) 235 | 236 | subject_model = Model([t1_in, t2_in], [ps1, ps2]) # 预测subject的模型 237 | 238 | 239 | k1v = Lambda(seq_gather)([t, k1]) 240 | k2v = Lambda(seq_gather)([t, k2]) 241 | kv = Average()([k1v, k2v]) 242 | t = Add()([t, kv]) 243 | po1 = Dense(num_classes, activation='sigmoid')(t) 244 | po2 = Dense(num_classes, activation='sigmoid')(t) 245 | 246 | object_model = Model([t1_in, t2_in, k1_in, k2_in], [po1, po2]) # 输入text和subject,预测object及其关系 247 | 248 | 249 | train_model = Model([t1_in, t2_in, s1_in, s2_in, k1_in, k2_in, o1_in, o2_in], 250 | [ps1, ps2, po1, po2]) 251 | 252 | s1 = K.expand_dims(s1, 2) 253 | s2 = K.expand_dims(s2, 2) 254 | 255 | s1_loss = K.binary_crossentropy(s1, ps1) 256 | s1_loss = K.sum(s1_loss * mask) / K.sum(mask) 257 | s2_loss = K.binary_crossentropy(s2, ps2) 258 | s2_loss = K.sum(s2_loss * mask) / K.sum(mask) 259 | 260 | o1_loss = K.sum(K.binary_crossentropy(o1, po1), 2, keepdims=True) 261 | o1_loss = 
K.sum(o1_loss * mask) / K.sum(mask) 262 | o2_loss = K.sum(K.binary_crossentropy(o2, po2), 2, keepdims=True) 263 | o2_loss = K.sum(o2_loss * mask) / K.sum(mask) 264 | 265 | loss = (s1_loss + s2_loss) + (o1_loss + o2_loss) 266 | 267 | train_model.add_loss(loss) 268 | train_model.compile(optimizer=Adam(learning_rate)) 269 | train_model.summary() 270 | 271 | 272 | def extract_items(text_in): 273 | _tokens = tokenizer.tokenize(text_in) 274 | _t1, _t2 = tokenizer.encode(first=text_in) 275 | _t1, _t2 = np.array([_t1]), np.array([_t2]) 276 | _k1, _k2 = subject_model.predict([_t1, _t2]) 277 | _k1, _k2 = np.where(_k1[0] > 0.5)[0], np.where(_k2[0] > 0.4)[0] 278 | _subjects = [] 279 | for i in _k1: 280 | j = _k2[_k2 >= i] 281 | if len(j) > 0: 282 | j = j[0] 283 | _subject = text_in[i-1: j] 284 | _subjects.append((_subject, i, j)) 285 | if _subjects: 286 | R = [] 287 | _t1 = np.repeat(_t1, len(_subjects), 0) 288 | _t2 = np.repeat(_t2, len(_subjects), 0) 289 | _k1, _k2 = np.array([_s[1:] for _s in _subjects]).T.reshape((2, -1, 1)) 290 | _o1, _o2 = object_model.predict([_t1, _t2, _k1, _k2]) 291 | for i,_subject in enumerate(_subjects): 292 | _oo1, _oo2 = np.where(_o1[i] > 0.5), np.where(_o2[i] > 0.4) 293 | for _ooo1, _c1 in zip(*_oo1): 294 | for _ooo2, _c2 in zip(*_oo2): 295 | if _ooo1 <= _ooo2 and _c1 == _c2: 296 | _object = text_in[_ooo1-1: _ooo2] 297 | _predicate = id2predicate[_c1] 298 | R.append((_subject[0], _predicate, _object)) 299 | break 300 | zhuanji, gequ = [], [] 301 | for s, p, o in R[:]: 302 | if p == u'妻子': 303 | R.append((o, u'丈夫', s)) 304 | elif p == u'丈夫': 305 | R.append((o, u'妻子', s)) 306 | if p == u'所属专辑': 307 | zhuanji.append(o) 308 | gequ.append(s) 309 | spo_list = set() 310 | for s, p, o in R: 311 | if p in [u'歌手', u'作词', u'作曲']: 312 | if s in zhuanji and s not in gequ: 313 | continue 314 | spo_list.add((s, p, o)) 315 | return list(spo_list) 316 | else: 317 | return [] 318 | 319 | 320 | class Evaluate(Callback): 321 | def __init__(self): 322 | self.F1 = [] 
323 | self.best = 0. 324 | self.passed = 0 325 | self.stage = 0 326 | def on_batch_begin(self, batch, logs=None): 327 | """第一个epoch用来warmup,第二个epoch把学习率降到最低 328 | """ 329 | if self.passed < self.params['steps']: 330 | lr = (self.passed + 1.) / self.params['steps'] * learning_rate 331 | K.set_value(self.model.optimizer.lr, lr) 332 | self.passed += 1 333 | elif self.params['steps'] <= self.passed < self.params['steps'] * 2: 334 | lr = (2 - (self.passed + 1.) / self.params['steps']) * (learning_rate - min_learning_rate) 335 | lr += min_learning_rate 336 | K.set_value(self.model.optimizer.lr, lr) 337 | self.passed += 1 338 | def on_epoch_end(self, epoch, logs=None): 339 | f1, precision, recall = self.evaluate() 340 | self.F1.append(f1) 341 | if f1 > self.best: 342 | self.best = f1 343 | train_model.save_weights('best_model.weights') 344 | print 'f1: %.4f, precision: %.4f, recall: %.4f, best f1: %.4f\n' % (f1, precision, recall, self.best) 345 | def evaluate(self): 346 | orders = ['subject', 'predicate', 'object'] 347 | A, B, C = 1e-10, 1e-10, 1e-10 348 | F = open('dev_pred.json', 'w') 349 | for d in tqdm(iter(dev_data)): 350 | R = set(extract_items(d['text'])) 351 | T = set(d['spo_list']) 352 | A += len(R & T) 353 | B += len(R) 354 | C += len(T) 355 | s = json.dumps({ 356 | 'text': d['text'], 357 | 'spo_list': [ 358 | dict(zip(orders, spo)) for spo in T 359 | ], 360 | 'spo_list_pred': [ 361 | dict(zip(orders, spo)) for spo in R 362 | ], 363 | 'new': [ 364 | dict(zip(orders, spo)) for spo in R - T 365 | ], 366 | 'lack': [ 367 | dict(zip(orders, spo)) for spo in T - R 368 | ] 369 | }, ensure_ascii=False, indent=4) 370 | F.write(s.encode('utf-8') + '\n') 371 | F.close() 372 | return 2 * A / (B + C), A / B, A / C 373 | 374 | 375 | def test(test_data): 376 | """输出测试结果 377 | """ 378 | orders = ['subject', 'predicate', 'object', 'object_type', 'subject_type'] 379 | F = open('test_pred.json', 'w') 380 | for d in tqdm(iter(test_data)): 381 | R = set(extract_items(d['text'])) 
382 | s = json.dumps({ 383 | 'text': d['text'], 384 | 'spo_list': [ 385 | dict(zip(orders, spo + ('', ''))) for spo in R 386 | ] 387 | }, ensure_ascii=False) 388 | F.write(s.encode('utf-8') + '\n') 389 | F.close() 390 | 391 | 392 | train_D = data_generator(train_data) 393 | evaluator = Evaluate() 394 | 395 | 396 | if __name__ == '__main__': 397 | train_model.fit_generator(train_D.__iter__(), 398 | steps_per_epoch=1000, 399 | epochs=30, 400 | callbacks=[evaluator] 401 | ) 402 | else: 403 | train_model.load_weights('best_model.weights') 404 | -------------------------------------------------------------------------------- /nl2sql_baseline.py: -------------------------------------------------------------------------------- 1 | #! -*- coding: utf-8 -*- 2 | # 追一科技2019年NL2SQL挑战赛的一个Baseline(个人作品,非官方发布,基于Bert) 3 | # 比赛地址:https://tianchi.aliyun.com/competition/entrance/231716/introduction 4 | # 目前全匹配率大概是58%左右 5 | 6 | import json 7 | import uniout 8 | from keras_bert import load_trained_model_from_checkpoint, Tokenizer 9 | import codecs 10 | from keras.layers import * 11 | from keras.models import Model 12 | import keras.backend as K 13 | from keras.optimizers import Adam 14 | from keras.callbacks import Callback 15 | from tqdm import tqdm 16 | import jieba 17 | import editdistance 18 | import re 19 | 20 | 21 | maxlen = 160 22 | num_agg = 7 # agg_sql_dict = {0:"", 1:"AVG", 2:"MAX", 3:"MIN", 4:"COUNT", 5:"SUM", 6:"不被select"} 23 | num_op = 5 # {0:">", 1:"<", 2:"==", 3:"!=", 4:"不被select"} 24 | num_cond_conn_op = 3 # conn_sql_dict = {0:"", 1:"and", 2:"or"} 25 | learning_rate = 5e-5 26 | min_learning_rate = 1e-5 27 | 28 | 29 | config_path = '../../kg/bert/chinese_wwm_L-12_H-768_A-12/bert_config.json' 30 | checkpoint_path = '../../kg/bert/chinese_wwm_L-12_H-768_A-12/bert_model.ckpt' 31 | dict_path = '../../kg/bert/chinese_wwm_L-12_H-768_A-12/vocab.txt' 32 | 33 | 34 | def read_data(data_file, table_file): 35 | data, tables = [], {} 36 | with open(data_file) as f: 37 | for l in f: 
38 | data.append(json.loads(l)) 39 | with open(table_file) as f: 40 | for l in f: 41 | l = json.loads(l) 42 | d = {} 43 | d['headers'] = l['header'] 44 | d['header2id'] = {j: i for i, j in enumerate(d['headers'])} 45 | d['content'] = {} 46 | d['all_values'] = set() 47 | rows = np.array(l['rows']) 48 | for i, h in enumerate(d['headers']): 49 | d['content'][h] = set(rows[:, i]) 50 | d['all_values'].update(d['content'][h]) 51 | d['all_values'] = set([i for i in d['all_values'] if hasattr(i, '__len__')]) 52 | tables[l['id']] = d 53 | return data, tables 54 | 55 | 56 | train_data, train_tables = read_data('../datasets/train.json', '../datasets/train.tables.json') 57 | valid_data, valid_tables = read_data('../datasets/val.json', '../datasets/val.tables.json') 58 | test_data, test_tables = read_data('../datasets/test.json', '../datasets/test.tables.json') 59 | 60 | 61 | token_dict = {} 62 | 63 | with codecs.open(dict_path, 'r', 'utf8') as reader: 64 | for line in reader: 65 | token = line.strip() 66 | token_dict[token] = len(token_dict) 67 | 68 | 69 | class OurTokenizer(Tokenizer): 70 | def _tokenize(self, text): 71 | R = [] 72 | for c in text: 73 | if c in self._token_dict: 74 | R.append(c) 75 | elif self._is_space(c): 76 | R.append('[unused1]') # space类用未经训练的[unused1]表示 77 | else: 78 | R.append('[UNK]') # 剩余的字符是[UNK] 79 | return R 80 | 81 | tokenizer = OurTokenizer(token_dict) 82 | 83 | 84 | def seq_padding(X, padding=0, maxlen=None): 85 | if maxlen is None: 86 | L = [len(x) for x in X] 87 | ML = max(L) 88 | else: 89 | ML = maxlen 90 | return np.array([ 91 | np.concatenate([x[:ML], [padding] * (ML - len(x))]) if len(x[:ML]) < ML else x for x in X 92 | ]) 93 | 94 | 95 | def most_similar(s, slist): 96 | """从词表中找最相近的词(当无法全匹配的时候) 97 | """ 98 | if len(slist) == 0: 99 | return s 100 | scores = [editdistance.eval(s, t) for t in slist] 101 | return slist[np.argmin(scores)] 102 | 103 | 104 | def most_similar_2(w, s): 105 | """从句子s中找与w最相近的片段, 106 | 借助分词工具和ngram的方式尽量精确地确定边界。 107 | 
""" 108 | sw = jieba.lcut(s) 109 | sl = list(sw) 110 | sl.extend([''.join(i) for i in zip(sw, sw[1:])]) 111 | sl.extend([''.join(i) for i in zip(sw, sw[1:], sw[2:])]) 112 | return most_similar(w, sl) 113 | 114 | 115 | class data_generator: 116 | def __init__(self, data, tables, batch_size=32): 117 | self.data = data 118 | self.tables = tables 119 | self.batch_size = batch_size 120 | self.steps = len(self.data) // self.batch_size 121 | if len(self.data) % self.batch_size != 0: 122 | self.steps += 1 123 | def __len__(self): 124 | return self.steps 125 | def __iter__(self): 126 | while True: 127 | idxs = range(len(self.data)) 128 | np.random.shuffle(idxs) 129 | X1, X2, XM, H, HM, SEL, CONN, CSEL, COP = [], [], [], [], [], [], [], [], [] 130 | for i in idxs: 131 | d = self.data[i] 132 | t = self.tables[d['table_id']]['headers'] 133 | x1, x2 = tokenizer.encode(d['question']) 134 | xm = [0] + [1] * len(d['question']) + [0] 135 | h = [] 136 | for j in t: 137 | _x1, _x2 = tokenizer.encode(j) 138 | h.append(len(x1)) 139 | x1.extend(_x1) 140 | x2.extend(_x2) 141 | hm = [1] * len(h) 142 | sel = [] 143 | for j in range(len(h)): 144 | if j in d['sql']['sel']: 145 | j = d['sql']['sel'].index(j) 146 | sel.append(d['sql']['agg'][j]) 147 | else: 148 | sel.append(num_agg - 1) # 不被select则被标记为num_agg-1 149 | conn = [d['sql']['cond_conn_op']] 150 | csel = np.zeros(len(d['question']) + 2, dtype='int32') # 这里的0既表示padding,又表示第一列,padding部分训练时会被mask 151 | cop = np.zeros(len(d['question']) + 2, dtype='int32') + num_op - 1 # 不被select则被标记为num_op-1 152 | for j in d['sql']['conds']: 153 | if j[2] not in d['question']: 154 | j[2] = most_similar_2(j[2], d['question']) 155 | if j[2] not in d['question']: 156 | continue 157 | k = d['question'].index(j[2]) 158 | csel[k + 1: k + 1 + len(j[2])] = j[0] 159 | cop[k + 1: k + 1 + len(j[2])] = j[1] 160 | if len(x1) > maxlen: 161 | continue 162 | X1.append(x1) # bert的输入 163 | X2.append(x2) # bert的输入 164 | XM.append(xm) # 输入序列的mask 165 | H.append(h) # 列名所在位置 
166 | HM.append(hm) # 列名mask 167 | SEL.append(sel) # 被select的列 168 | CONN.append(conn) # 连接类型 169 | CSEL.append(csel) # 条件中的列 170 | COP.append(cop) # 条件中的运算符(同时也是值的标记) 171 | if len(X1) == self.batch_size: 172 | X1 = seq_padding(X1) 173 | X2 = seq_padding(X2) 174 | XM = seq_padding(XM, maxlen=X1.shape[1]) 175 | H = seq_padding(H) 176 | HM = seq_padding(HM) 177 | SEL = seq_padding(SEL) 178 | CONN = seq_padding(CONN) 179 | CSEL = seq_padding(CSEL, maxlen=X1.shape[1]) 180 | COP = seq_padding(COP, maxlen=X1.shape[1]) 181 | yield [X1, X2, XM, H, HM, SEL, CONN, CSEL, COP], None 182 | X1, X2, XM, H, HM, SEL, CONN, CSEL, COP = [], [], [], [], [], [], [], [], [] 183 | 184 | 185 | def seq_gather(x): 186 | """seq是[None, seq_len, s_size]的格式, 187 | idxs是[None, n]的格式,在seq的第i个序列中选出第idxs[i]个向量, 188 | 最终输出[None, n, s_size]的向量。 189 | """ 190 | seq, idxs = x 191 | idxs = K.cast(idxs, 'int32') 192 | return K.tf.batch_gather(seq, idxs) 193 | 194 | 195 | bert_model = load_trained_model_from_checkpoint(config_path, checkpoint_path, seq_len=None) 196 | 197 | for l in bert_model.layers: 198 | l.trainable = True 199 | 200 | 201 | x1_in = Input(shape=(None,), dtype='int32') 202 | x2_in = Input(shape=(None,)) 203 | xm_in = Input(shape=(None,)) 204 | h_in = Input(shape=(None,), dtype='int32') 205 | hm_in = Input(shape=(None,)) 206 | sel_in = Input(shape=(None,), dtype='int32') 207 | conn_in = Input(shape=(1,), dtype='int32') 208 | csel_in = Input(shape=(None,), dtype='int32') 209 | cop_in = Input(shape=(None,), dtype='int32') 210 | 211 | x1, x2, xm, h, hm, sel, conn, csel, cop = ( 212 | x1_in, x2_in, xm_in, h_in, hm_in, sel_in, conn_in, csel_in, cop_in 213 | ) 214 | 215 | hm = Lambda(lambda x: K.expand_dims(x, 1))(hm) # header的mask.shape=(None, 1, h_len) 216 | 217 | x = bert_model([x1_in, x2_in]) 218 | x4conn = Lambda(lambda x: x[:, 0])(x) 219 | pconn = Dense(num_cond_conn_op, activation='softmax')(x4conn) 220 | 221 | x4h = Lambda(seq_gather)([x, h]) 222 | psel = Dense(num_agg, 
activation='softmax')(x4h) 223 | 224 | pcop = Dense(num_op, activation='softmax')(x) 225 | 226 | x = Lambda(lambda x: K.expand_dims(x, 2))(x) 227 | x4h = Lambda(lambda x: K.expand_dims(x, 1))(x4h) 228 | pcsel_1 = Dense(256)(x) 229 | pcsel_2 = Dense(256)(x4h) 230 | pcsel = Lambda(lambda x: x[0] + x[1])([pcsel_1, pcsel_2]) 231 | pcsel = Activation('tanh')(pcsel) 232 | pcsel = Dense(1)(pcsel) 233 | pcsel = Lambda(lambda x: x[0][..., 0] - (1 - x[1]) * 1e10)([pcsel, hm]) 234 | pcsel = Activation('softmax')(pcsel) 235 | 236 | model = Model( 237 | [x1_in, x2_in, h_in, hm_in], 238 | [psel, pconn, pcop, pcsel] 239 | ) 240 | 241 | train_model = Model( 242 | [x1_in, x2_in, xm_in, h_in, hm_in, sel_in, conn_in, csel_in, cop_in], 243 | [psel, pconn, pcop, pcsel] 244 | ) 245 | 246 | xm = xm # question的mask.shape=(None, x_len) 247 | hm = hm[:, 0] # header的mask.shape=(None, h_len) 248 | cm = K.cast(K.not_equal(cop, num_op - 1), 'float32') # conds的mask.shape=(None, x_len) 249 | 250 | psel_loss = K.sparse_categorical_crossentropy(sel_in, psel) 251 | psel_loss = K.sum(psel_loss * hm) / K.sum(hm) 252 | pconn_loss = K.sparse_categorical_crossentropy(conn_in, pconn) 253 | pconn_loss = K.mean(pconn_loss) 254 | pcop_loss = K.sparse_categorical_crossentropy(cop_in, pcop) 255 | pcop_loss = K.sum(pcop_loss * xm) / K.sum(xm) 256 | pcsel_loss = K.sparse_categorical_crossentropy(csel_in, pcsel) 257 | pcsel_loss = K.sum(pcsel_loss * xm * cm) / K.sum(xm * cm) 258 | loss = psel_loss + pconn_loss + pcop_loss + pcsel_loss 259 | 260 | train_model.add_loss(loss) 261 | train_model.compile(optimizer=Adam(learning_rate)) 262 | train_model.summary() 263 | 264 | 265 | def nl2sql(question, table): 266 | """输入question和headers,转SQL 267 | """ 268 | x1, x2 = tokenizer.encode(question) 269 | h = [] 270 | for i in table['headers']: 271 | _x1, _x2 = tokenizer.encode(i) 272 | h.append(len(x1)) 273 | x1.extend(_x1) 274 | x2.extend(_x2) 275 | hm = [1] * len(h) 276 | psel, pconn, pcop, pcsel = model.predict([ 277 | 
np.array([x1]), 278 | np.array([x2]), 279 | np.array([h]), 280 | np.array([hm]) 281 | ]) 282 | R = {'agg': [], 'sel': []} 283 | for i, j in enumerate(psel[0].argmax(1)): 284 | if j != num_agg - 1: # num_agg-1类是不被select的意思 285 | R['sel'].append(i) 286 | R['agg'].append(j) 287 | conds = [] 288 | v_op = -1 289 | for i, j in enumerate(pcop[0, :len(question)+1].argmax(1)): 290 | # 这里结合标注和分类来预测条件 291 | if j != num_op - 1: 292 | if v_op != j: 293 | if v_op != -1: 294 | v_end = v_start + len(v_str) 295 | csel = pcsel[0][v_start: v_end].mean(0).argmax() 296 | conds.append((csel, v_op, v_str)) 297 | v_start = i 298 | v_op = j 299 | v_str = question[i - 1] 300 | else: 301 | v_str += question[i - 1] 302 | elif v_op != -1: 303 | v_end = v_start + len(v_str) 304 | csel = pcsel[0][v_start: v_end].mean(0).argmax() 305 | conds.append((csel, v_op, v_str)) 306 | v_op = -1 307 | R['conds'] = set() 308 | for i, j, k in conds: 309 | if re.findall('[^\d\.]', k): 310 | j = 2 # 非数字只能用等号 311 | if j == 2: 312 | if k not in table['all_values']: 313 | # 等号的值必须在table出现过,否则找一个最相近的 314 | k = most_similar(k, list(table['all_values'])) 315 | h = table['headers'][i] 316 | # 然后检查值对应的列是否正确,如果不正确,直接修正列名 317 | if k not in table['content'][h]: 318 | for r, v in table['content'].items(): 319 | if k in v: 320 | i = table['header2id'][r] 321 | break 322 | R['conds'].add((i, j, k)) 323 | R['conds'] = list(R['conds']) 324 | if len(R['conds']) <= 1: # 条件数少于等于1时,条件连接符直接为0 325 | R['cond_conn_op'] = 0 326 | else: 327 | R['cond_conn_op'] = 1 + pconn[0, 1:].argmax() # 不能是0 328 | return R 329 | 330 | 331 | def is_equal(R1, R2): 332 | """判断两个SQL字典是否全匹配 333 | """ 334 | return (R1['cond_conn_op'] == R2['cond_conn_op']) &\ 335 | (set(zip(R1['sel'], R1['agg'])) == set(zip(R2['sel'], R2['agg']))) &\ 336 | (set([tuple(i) for i in R1['conds']]) == set([tuple(i) for i in R2['conds']])) 337 | 338 | 339 | def evaluate(data, tables): 340 | right = 0. 
341 | pbar = tqdm() 342 | F = open('evaluate_pred.json', 'w') 343 | for i, d in enumerate(data): 344 | question = d['question'] 345 | table = tables[d['table_id']] 346 | R = nl2sql(question, table) 347 | right += float(is_equal(R, d['sql'])) 348 | pbar.update(1) 349 | pbar.set_description('< acc: %.5f >' % (right / (i + 1))) 350 | d['sql_pred'] = R 351 | s = json.dumps(d, ensure_ascii=False, indent=4) 352 | F.write(s.encode('utf-8') + '\n') 353 | F.close() 354 | pbar.close() 355 | return right / len(data) 356 | 357 | 358 | def test(data, tables, outfile='result.json'): 359 | pbar = tqdm() 360 | F = open(outfile, 'w') 361 | for i, d in enumerate(data): 362 | question = d['question'] 363 | table = tables[d['table_id']] 364 | R = nl2sql(question, table) 365 | pbar.update(1) 366 | s = json.dumps(R, ensure_ascii=False) 367 | F.write(s.encode('utf-8') + '\n') 368 | F.close() 369 | pbar.close() 370 | 371 | # test(test_data, test_tables) 372 | 373 | 374 | class Evaluate(Callback): 375 | def __init__(self): 376 | self.accs = [] 377 | self.best = 0. 378 | self.passed = 0 379 | self.stage = 0 380 | def on_batch_begin(self, batch, logs=None): 381 | """第一个epoch用来warmup,第二个epoch把学习率降到最低 382 | """ 383 | if self.passed < self.params['steps']: 384 | lr = (self.passed + 1.) / self.params['steps'] * learning_rate 385 | K.set_value(self.model.optimizer.lr, lr) 386 | self.passed += 1 387 | elif self.params['steps'] <= self.passed < self.params['steps'] * 2: 388 | lr = (2 - (self.passed + 1.) 
/ self.params['steps']) * (learning_rate - min_learning_rate) 389 | lr += min_learning_rate 390 | K.set_value(self.model.optimizer.lr, lr) 391 | self.passed += 1 392 | def on_epoch_end(self, epoch, logs=None): 393 | acc = self.evaluate() 394 | self.accs.append(acc) 395 | if acc > self.best: 396 | self.best = acc 397 | train_model.save_weights('best_model.weights') 398 | print 'acc: %.5f, best acc: %.5f\n' % (acc, self.best) 399 | def evaluate(self): 400 | return evaluate(valid_data, valid_tables) 401 | 402 | 403 | train_D = data_generator(train_data, train_tables) 404 | evaluator = Evaluate() 405 | 406 | if __name__ == '__main__': 407 | train_model.fit_generator( 408 | train_D.__iter__(), 409 | steps_per_epoch=len(train_D), 410 | epochs=15, 411 | callbacks=[evaluator] 412 | ) 413 | else: 414 | train_model.load_weights('best_model.weights') 415 | --------------------------------------------------------------------------------